Page boundary detection in historical documents
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

115 lignes
3.0KB

  1. #!/usr/bin/python
  2. import os
  3. import sys
  4. import argparse
  5. import numpy as np
  6. import caffe
  7. import cv2
  8. import random
  9. def safe_mkdir(_dir):
  10. try:
  11. os.makedirs(_dir)
  12. except:
  13. pass
  14. def predict(network, im, output_blob, args):
  15. network.blobs["data"].data[0,:,:,:] = im
  16. network.forward()
  17. response = network.blobs[output_blob].data[0,:].copy()
  18. return np.argmax(response, axis=0)
  19. def presolve(net, args):
  20. net.blobs["data"].reshape(1, 3 if args.color else 1, args.image_size, args.image_size)
  21. net.blobs["gt"].reshape(1, 1, args.image_size, args.image_size)
  22. def main(args):
  23. net = caffe.Net(args.net_file, args.weight_file, caffe.TEST)
  24. presolve(net, args)
  25. file_list = [s.strip() for s in open(args.test_manifest, 'r').readlines()]
  26. fd = open(args.out_file, 'w')
  27. for idx, line in enumerate(file_list):
  28. if idx % args.print_count == 0:
  29. print("Processed %d/%d Images" % (idx, len(file_list)))
  30. tokens = line.split(',')
  31. f = tokens[0]
  32. resolved = os.path.join(args.dataset_dir, f)
  33. im = cv2.imread(resolved, 1 if args.color else 0)
  34. _input = args.scale * (cv2.resize(im, (args.image_size, args.image_size)) - args.mean)
  35. if _input.ndim > 2:
  36. _input = np.transpose(_input, (2, 0, 1))
  37. raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8)
  38. if args.out_dir:
  39. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png")
  40. cv2.imwrite(out_fn, raw)
  41. post, coords = post_process(raw)
  42. for idx2 in [1, 2, 3, 0]:
  43. fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.))
  44. fd.write('\n')
  45. if args.out_dir:
  46. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png")
  47. cv2.imwrite(out_fn, post)
  48. def get_args():
  49. parser = argparse.ArgumentParser(description="Outputs binary predictions")
  50. parser.add_argument("net_file",
  51. help="The deploy.prototxt")
  52. parser.add_argument("weight_file",
  53. help="The .caffemodel")
  54. parser.add_argument("dataset_dir",
  55. help="The dataset to be evaluated")
  56. parser.add_argument("test_manifest",
  57. help="Images to predict")
  58. parser.add_argument("out_file",
  59. help="output file listing quad regions")
  60. parser.add_argument("--out-dir", default='', type=str,
  61. help="Dump images")
  62. parser.add_argument("--gpu", type=int, default=0,
  63. help="GPU to use for running the network")
  64. parser.add_argument("-c", "--color", default=False, action='store_true',
  65. help="Training batch size")
  66. parser.add_argument("-m", "--mean", type=float, default=127.,
  67. help="Mean value for data preprocessing")
  68. parser.add_argument("-s", "--scale", type=float, default=0.0039,
  69. help="Optional pixel scale factor")
  70. parser.add_argument("--image-size", default=256, type=int,
  71. help="Size of images for input to prediction")
  72. parser.add_argument("--print-count", default=10, type=int,
  73. help="Print interval")
  74. args = parser.parse_args()
  75. print(args)
  76. return args
  77. if __name__ == "__main__":
  78. args = get_args()
  79. safe_mkdir(args.out_dir)
  80. if args.gpu >= 0:
  81. caffe.set_device(args.gpu)
  82. caffe.set_mode_gpu()
  83. else:
  84. caffe.set_mode_cpu()
  85. main(args)