Page boundary detection in historical documents
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123 lines
2.9KB

  1. #!/usr/bin/python
  2. import os
  3. import sys
  4. import argparse
  5. import numpy as np
  6. import caffe
  7. import cv2
  8. from process_pixel_labels import post_process
# Default Caffe network definition and trained weights (the "cbad" model).
# Both names are rebound in the ``__main__`` section when the 'ohio' model
# is requested on the command line; main() reads them when building the net.
NET_FILE = './models/cbad_train_val.prototxt'
WEIGHT_FILE = './models/cbad_weights.caffemodel'
  11. def safe_mkdir(_dir):
  12. try:
  13. os.makedirs(_dir)
  14. except:
  15. pass
  16. def predict(network, im, output_blob, args):
  17. network.blobs["data"].data[0,:,:,:] = im
  18. network.forward()
  19. if args.model == 'ohio':
  20. # sigmoid
  21. response = network.blobs[output_blob].data[0,0,:,:].copy()
  22. response[response >= 0.5] = 1
  23. response[response <= 0.5] = 0
  24. return response
  25. else:
  26. # softmax
  27. response = network.blobs[output_blob].data[0,:].copy()
  28. return np.argmax(response, axis=0)
  29. def presolve(net, args):
  30. net.blobs["data"].reshape(1, 3, 256, 256)
  31. net.blobs["gt"].reshape(1, 1, 256, 256)
  32. def main(args):
  33. net = caffe.Net(NET_FILE, WEIGHT_FILE, caffe.TEST)
  34. presolve(net, args)
  35. file_list = [s.strip() for s in open(args.manifest, 'r').readlines()]
  36. fd = open(args.out_file, 'w')
  37. for idx, line in enumerate(file_list):
  38. if idx % args.print_count == 0:
  39. print("Processed %d/%d Images" % (idx, len(file_list)))
  40. tokens = line.split(',')
  41. f = tokens[0]
  42. resolved = os.path.join(args.image_dir, f)
  43. im = cv2.imread(resolved, 1)
  44. height = im.shape[0]
  45. width = im.shape[1]
  46. fd.write('%s,' % f)
  47. _input = 0.0039 * (cv2.resize(im, (256, 256)) - 127.)
  48. _input = np.transpose(_input, (2, 0, 1))
  49. raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8)
  50. if args.out_dir:
  51. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png")
  52. cv2.imwrite(out_fn, raw)
  53. post, coords = post_process(raw)
  54. for idx2 in [1, 2, 3, 0]:
  55. fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.))
  56. fd.write('\n')
  57. if args.out_dir:
  58. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png")
  59. cv2.imwrite(out_fn, post)
  60. def get_args():
  61. parser = argparse.ArgumentParser(description="Outputs binary predictions")
  62. parser.add_argument("image_dir",
  63. help="The directory where images are stored")
  64. parser.add_argument("manifest",
  65. help="txt file listing images relative to image_dir")
  66. parser.add_argument("model",
  67. help="[cbad|ohio]")
  68. parser.add_argument("out_file", type=str,
  69. help="Output file")
  70. parser.add_argument("--out-dir", type=str, default='',
  71. help="")
  72. parser.add_argument("--gpu", type=int, default=0,
  73. help="GPU to use for running the network")
  74. parser.add_argument("--print-count", default=10, type=int,
  75. help="Print interval")
  76. args = parser.parse_args()
  77. print(args)
  78. return args
  79. if __name__ == "__main__":
  80. args = get_args()
  81. if args.out_dir:
  82. safe_mkdir(args.out_dir)
  83. if args.model == 'ohio':
  84. NET_FILE = './models/ohio_train_val.prototxt'
  85. WEIGHT_FILE = './models/ohio_weights.caffemodel'
  86. if args.gpu >= 0:
  87. caffe.set_device(args.gpu)
  88. caffe.set_mode_gpu()
  89. else:
  90. caffe.set_mode_cpu()
  91. main(args)