Page boundary detection in historical documents

#!/usr/bin/python
import os
import sys
import collections
import argparse
import numpy as np
import matplotlib
matplotlib.use("AGG")
import matplotlib.pyplot as plt
import caffe
import cv2
import random
import scipy.ndimage as nd


def safe_mkdir(_dir):
    # Create the output directory if it does not already exist.
    try:
        os.makedirs(_dir)
    except OSError:
        pass


def predict(network, im, output_blob, args):
    # Run one preprocessed image through the network and threshold the
    # single-channel response into a binary page/background mask.
    network.blobs["data"].data[0, :, :, :] = im
    network.forward()
    # Multi-class alternative (kept from the original):
    # response = network.blobs[output_blob].data[0, :].copy()
    # return np.argmax(response, axis=0)
    response = network.blobs[output_blob].data[0, 0, :, :].copy()
    response[response >= 0.5] = 1
    response[response < 0.5] = 0
    return response


def presolve(net, args):
    # Reshape the input blobs to the batch size, channel count, and
    # spatial size used for prediction.
    net.blobs["data"].reshape(args.batch_size, 3 if args.color else 1,
                              args.image_size, args.image_size)
    net.blobs["gt"].reshape(args.batch_size, 1, args.image_size, args.image_size)


def main(args):
    net = caffe.Net(args.net_file, args.weight_file, caffe.TEST)
    presolve(net, args)

    file_list = [s.strip() for s in open(args.test_manifest, 'r').readlines()]
    for idx, line in enumerate(file_list):
        if idx % args.print_count == 0:
            print("Processed %d/%d Images" % (idx, len(file_list)))

        # Each manifest line is comma separated; the first token is the
        # image path relative to dataset_dir.
        tokens = line.split(',')
        f = tokens[0]
        resolved = os.path.join(args.dataset_dir, f)
        im = cv2.imread(resolved, 1 if args.color else 0)

        # Resize, subtract the mean, and scale to match the training preprocessing.
        _input = args.scale * (cv2.resize(im, (args.image_size, args.image_size)) - args.mean)
        if _input.ndim > 2:
            # HWC -> CHW for color images
            _input = np.transpose(_input, (2, 0, 1))

        output = predict(net, _input, 'out', args)

        out_fn = os.path.join(args.out_dir, f.replace('/', '_')[:-4] + ".png")
        cv2.imwrite(out_fn, (255 * output).astype(np.uint8))


def get_args():
    parser = argparse.ArgumentParser(description="Outputs binary predictions")

    parser.add_argument("net_file",
                        help="The deploy.prototxt")
    parser.add_argument("weight_file",
                        help="The .caffemodel")
    parser.add_argument("dataset_dir",
                        help="The dataset to be evaluated")
    parser.add_argument("test_manifest",
                        help="txt file listing images to evaluate")
    parser.add_argument("--out-dir", default='out', type=str,
                        help="Directory where prediction images are written")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU to use for running the network (negative for CPU)")
    parser.add_argument("-c", "--color", default=False, action='store_true',
                        help="Treat input images as 3-channel color")
    parser.add_argument("-m", "--mean", type=float, default=127.,
                        help="Mean value for data preprocessing")
    parser.add_argument("-s", "--scale", type=float, default=0.0039,
                        help="Optional pixel scale factor")
    parser.add_argument("-b", "--batch-size", default=2, type=int,
                        help="Batch size used to shape the input blobs")
    parser.add_argument("--image-size", default=256, type=int,
                        help="Size of images for input to prediction")
    parser.add_argument("--print-count", default=10, type=int,
                        help="Print progress every N images")

    args = parser.parse_args()
    print(args)
    return args


if __name__ == "__main__":
    args = get_args()
    safe_mkdir(args.out_dir)
    if args.gpu >= 0:
        caffe.set_device(args.gpu)
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()
    main(args)
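
As a rough usage sketch: the flags below come from the argparse definition above, but the script name and the prototxt, caffemodel, dataset, and manifest paths are placeholders, not files provided here. The manifest is expected to contain one comma-separated line per image, whose first field is a path relative to dataset_dir.

    python predict.py deploy.prototxt weights.caffemodel /path/to/dataset test_manifest.txt --out-dir out --gpu 0

With the default --mean 127 and --scale 0.0039 (roughly 1/255), pixel values are shifted and scaled to approximately the [-0.5, 0.5] range before being fed to the network; the thresholded output masks are written to --out-dir as 0/255 PNG images.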