Page boundary detection in historical documents
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

120 lignes
2.8KB

  1. #!/usr/bin/python
import argparse
import os
import sys

import numpy as np
import caffe
import cv2

from process_pixel_labels import post_process
  9. NET_FILE = './models/cbad_train_val.prototxt'
  10. WEIGHT_FILE = './models/cbad_weights.caffemodel'
  11. def safe_mkdir(_dir):
  12. try:
  13. os.makedirs(_dir)
  14. except:
  15. pass
  16. def predict(network, im, output_blob, args):
  17. network.blobs["data"].data[0,:,:,:] = im
  18. network.forward()
  19. if args.model == 'ohio':
  20. # sigmoid
  21. response = network.blobs[output_blob].data[0,0,:,:].copy()
  22. response[response >= 0.5] = 1
  23. response[response <= 0.5] = 0
  24. return response
  25. else:
  26. # softmax
  27. response = network.blobs[output_blob].data[0,:].copy()
  28. return np.argmax(response, axis=0)
  29. def presolve(net, args):
  30. net.blobs["data"].reshape(1, 3, 256, 256)
  31. net.blobs["gt"].reshape(1, 1, 256, 256)
  32. def main(args):
  33. net = caffe.Net(NET_FILE, WEIGHT_FILE, caffe.TEST)
  34. presolve(net, args)
  35. file_list = map(lambda s: s.strip(), open(args.manifest, 'r').readlines())
  36. fd = open(args.out_file, 'w')
  37. for idx, line in enumerate(file_list):
  38. if idx % args.print_count == 0:
  39. print "Processed %d/%d Images" % (idx, len(file_list))
  40. tokens = line.split(',')
  41. f = tokens[0]
  42. resolved = os.path.join(args.image_dir, f)
  43. im = cv2.imread(resolved, 1)
  44. height = im.shape[0]
  45. width = im.shape[1]
  46. fd.write('%s,' % f)
  47. _input = 0.0039 * (cv2.resize(im, (256, 256)) - 127.)
  48. _input = np.transpose(_input, (2, 0, 1))
  49. raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8)
  50. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png")
  51. cv2.imwrite(out_fn, raw)
  52. post, coords = post_process(raw)
  53. for idx2 in [1, 2, 3, 0]:
  54. fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.))
  55. fd.write('\n')
  56. out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png")
  57. cv2.imwrite(out_fn, post)
  58. def get_args():
  59. parser = argparse.ArgumentParser(description="Outputs binary predictions")
  60. parser.add_argument("image_dir",
  61. help="The directory where images are stored")
  62. parser.add_argument("manifest",
  63. help="txt file listing images relative to image_dir")
  64. parser.add_argument("model",
  65. help="[cbad|ohio]")
  66. parser.add_argument("out_file", type=str,
  67. help="Output file")
  68. parser.add_argument("--out-dir", type=str, default='out',
  69. help="")
  70. parser.add_argument("--gpu", type=int, default=0,
  71. help="GPU to use for running the network")
  72. parser.add_argument("--print-count", default=10, type=int,
  73. help="Print interval")
  74. args = parser.parse_args()
  75. print args
  76. return args
  77. if __name__ == "__main__":
  78. args = get_args()
  79. safe_mkdir(args.out_dir)
  80. if args.model == 'ohio':
  81. NET_FILE = './models/ohio_train_val.prototxt'
  82. WEIGHT_FILE = './models/ohio_weights.caffemodel'
  83. if args.gpu >= 0:
  84. caffe.set_device(args.gpu)
  85. caffe.set_mode_gpu()
  86. else:
  87. caffe.set_mode_cpu()
  88. main(args)