Page boundary detection in historical documents
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

115 lines
3.0KB

  1. #!/usr/bin/python
  2. import os
  3. import sys
  4. import argparse
  5. import numpy as np
  6. import caffe
  7. import cv2
  8. import random
  9. def safe_mkdir(_dir):
  10. try:
  11. os.makedirs(_dir)
  12. except:
  13. pass
  14. def predict(network, im, output_blob, args):
  15. network.blobs["data"].data[0,:,:,:] = im
  16. network.forward()
  17. response = network.blobs[output_blob].data[0,:].copy()
  18. return np.argmax(response, axis=0)
  19. def presolve(net, args):
  20. net.blobs["data"].reshape(1, 3 if args.color else 1, args.image_size, args.image_size)
  21. net.blobs["gt"].reshape(1, 1, args.image_size, args.image_size)
  22. def main(args):
  23. net = caffe.Net(args.net_file, args.weight_file, caffe.TEST)
  24. presolve(net, args)
  25. file_list = map(lambda s: s.strip(), open(args.test_manifest, 'r').readlines())
  26. fd = open(args.out_file, 'w')
  27. for idx, line in enumerate(file_list):
  28. if idx % args.print_count == 0:
  29. print "Processed %d/%d Images" % (idx, len(file_list))
  30. tokens = line.split(',')
  31. f = tokens[0]
  32. resolved = os.path.join(args.dataset_dir, f)
  33. im = cv2.imread(resolved, 1 if args.color else 0)
  34. _input = args.scale * (cv2.resize(im, (args.image_size, args.image_size)) - args.mean)
  35. if _input.ndim > 2:
  36. _input = np.transpose(_input, (2, 0, 1))
  37. raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8)
  38. if args.out_dir:
  39. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png")
  40. cv2.imwrite(out_fn, raw)
  41. post, coords = post_process(raw)
  42. for idx2 in [1, 2, 3, 0]:
  43. fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.))
  44. fd.write('\n')
  45. if args.out_dir:
  46. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png")
  47. cv2.imwrite(out_fn, post)
  48. def get_args():
  49. parser = argparse.ArgumentParser(description="Outputs binary predictions")
  50. parser.add_argument("net_file",
  51. help="The deploy.prototxt")
  52. parser.add_argument("weight_file",
  53. help="The .caffemodel")
  54. parser.add_argument("dataset_dir",
  55. help="The dataset to be evaluated")
  56. parser.add_argument("test_manifest",
  57. help="Images to predict")
  58. parser.add_argument("out_file",
  59. help="output file listing quad regions")
  60. parser.add_argument("--out-dir", default='', type=str,
  61. help="Dump images")
  62. parser.add_argument("--gpu", type=int, default=0,
  63. help="GPU to use for running the network")
  64. parser.add_argument("-c", "--color", default=False, action='store_true',
  65. help="Training batch size")
  66. parser.add_argument("-m", "--mean", type=float, default=127.,
  67. help="Mean value for data preprocessing")
  68. parser.add_argument("-s", "--scale", type=float, default=0.0039,
  69. help="Optional pixel scale factor")
  70. parser.add_argument("--image-size", default=256, type=int,
  71. help="Size of images for input to prediction")
  72. parser.add_argument("--print-count", default=10, type=int,
  73. help="Print interval")
  74. args = parser.parse_args()
  75. print args
  76. return args
  77. if __name__ == "__main__":
  78. args = get_args()
  79. safe_mkdir(args.out_dir)
  80. if args.gpu >= 0:
  81. caffe.set_device(args.gpu)
  82. caffe.set_mode_gpu()
  83. else:
  84. caffe.set_mode_cpu()
  85. main(args)