Page boundary detection in historical documents
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

307 lines
8.6KB

  1. #!/usr/bin/python
  2. import os
  3. import sys
  4. import collections
  5. import argparse
  6. import numpy as np
  7. import matplotlib
  8. matplotlib.use("AGG")
  9. import matplotlib.pyplot as plt
  10. import caffe
  11. import cv2
  12. import random
  13. def safe_mkdir(_dir):
  14. try:
  15. os.makedirs(_dir)
  16. except:
  17. pass
  18. def dump_debug(out_dir, data, dump_images=False):
  19. pred_image_dir = os.path.join(out_dir, 'pred_images')
  20. safe_mkdir(pred_image_dir)
  21. for idx in range(len(data['images'])):
  22. fn = data['filenames'][idx]
  23. preds = data['predictions'][idx]
  24. fn_base = fn.replace('/', '_')[:-4]
  25. out_fn = os.path.join(pred_image_dir, fn_base + ".png")
  26. cv2.imwrite(out_fn, 255 * preds)
  27. def predict(network, im, output_blob, args):
  28. if im.ndim > 2:
  29. im = np.transpose(im, axes=(2, 0, 1))
  30. network.blobs["data"].data[0,:,:,:] = im
  31. network.forward()
  32. response = network.blobs[output_blob].data[0,:].copy()
  33. return np.argmax(response, axis=0)
  34. def iou(im1, im2):
  35. num_intersect = np.sum(np.logical_and(im1, im2))
  36. num_union = num_intersect + np.sum(np.logical_xor(im1, im2))
  37. return float(num_intersect) / num_union
  38. def prf(im1, im2):
  39. num_intersect = np.sum(np.logical_and(im1, im2))
  40. num_1 = np.sum(im1)
  41. num_2 = np.sum(im2)
  42. p = num_intersect / float(num_1)
  43. r = num_intersect / float(num_2)
  44. f = (2 * p * r) / (p + r) if (p + r) else 0
  45. return p, r, f
  46. def update_predictions(net, data, args):
  47. print("Starting Predictions")
  48. total_iou = 0
  49. total_p = 0
  50. total_r = 0
  51. total_f = 0
  52. for idx in range(len(data['images'])):
  53. im = cv2.resize(data['images'][idx], (args.image_size, args.image_size))
  54. outputs = predict(net, im, 'out', args)
  55. data['predictions'][idx] = outputs.copy()
  56. width, height = data['original_size'][idx]
  57. outputs = cv2.resize(outputs, (width, height), interpolation=cv2.INTER_NEAREST)
  58. total_iou += iou(outputs, data['original_gt'][idx])
  59. p, r, f = prf(outputs, data['original_gt'][idx])
  60. total_p += p
  61. total_r += r
  62. total_f += f
  63. if idx and idx % args.print_count == 0:
  64. print("\tPredicted %d/%d" % (idx, len(data['images'])))
  65. avg_iou = total_iou / len(data['images'])
  66. avg_p = total_p / len(data['images'])
  67. avg_r = total_r / len(data['images'])
  68. avg_f = total_f / len(data['images'])
  69. return avg_iou, avg_p, avg_r, avg_f
  70. def load_data(manifest, _dir, size, color=False):
  71. dataset = collections.defaultdict(list)
  72. file_list = [s.strip() for s in open(manifest, 'r').readlines()]
  73. for line in file_list:
  74. tokens = line.split(',')
  75. f = tokens[0]
  76. coords = list(map(float, tokens[1:9]))
  77. dataset['filenames'].append(f)
  78. resolved = os.path.join(_dir, f)
  79. im = cv2.imread(resolved, 1 if color else 0)
  80. gt = np.zeros(im.shape[:2], dtype=np.uint8)
  81. cv2.fillPoly(gt, np.array(coords).reshape((4, 2)).astype(np.int32)[np.newaxis,:,:], 1)
  82. if im is None:
  83. raise Exception("Error loading %s" % resolved)
  84. height, width = im.shape[:2]
  85. im = cv2.resize(im, (size, size))
  86. dataset['original_gt'].append(gt)
  87. gt = cv2.resize(gt, (size, size), interpolation=cv2.INTER_NEAREST)
  88. dataset['images'].append(im)
  89. dataset['original_size'].append( (width, height) ) # opencv does (w,h)
  90. dataset['gt'].append(gt)
  91. return dataset
  92. def preprocess_data(data, args):
  93. for idx in range(len(data['images'])):
  94. im = data['images'][idx]
  95. im = args.scale * (im - args.mean)
  96. data['images'][idx] = im
  97. gt = data['gt'][idx]
  98. data['predictions'].append(gt.copy())
  99. def get_solver_params(f):
  100. max_iters = 0
  101. snapshot = 0
  102. for line in open(f).readlines():
  103. tokens = line.split()
  104. if tokens[0] == 'max_iter:':
  105. max_iters = int(tokens[1])
  106. if tokens[0] == 'snapshot:':
  107. snapshot = int(tokens[1])
  108. return max_iters, snapshot
  109. def presolve(net, args):
  110. net.blobs["data"].reshape(args.batch_size, 3 if args.color else 1, args.image_size, args.image_size)
  111. net.blobs["gt"].reshape(args.batch_size, 1, args.image_size, args.image_size)
  112. def set_input_data(net, data, args):
  113. for batch_idx in range(args.batch_size):
  114. im_idx = random.randint(0, len(data['images']) - 1)
  115. im = data['images'][im_idx]
  116. gt = data['gt'][im_idx]
  117. if im.ndim > 2:
  118. im = np.transpose(im, (2, 0, 1))
  119. net.blobs["data"].data[batch_idx,:,:,:] = im
  120. net.blobs["gt"].data[batch_idx,0,:,:] = gt
def main(args):
    """Train the segmentation network and periodically evaluate it.

    Loads train/val datasets, runs SGD for max_iter steps (read from the
    solver prototxt), evaluates on the validation set every snapshot
    interval and on the training set at gt_interval steps, then plots and
    dumps the collected IOU/P/R/F curves under args.debug_dir.

    NOTE(review): the source indentation was lost; the nesting below is a
    reconstruction -- confirm against the original repository.
    """
    train_data = load_data(args.train_manifest, args.dataset_dir, args.image_size, args.color)
    val_data = load_data(args.val_manifest, args.dataset_dir, args.image_size, args.color)
    preprocess_data(train_data, args)
    preprocess_data(val_data, args)
    print("Done loading data")
    solver = caffe.SGDSolver(args.solver_file)
    max_iters, snapshot_interval = get_solver_params(args.solver_file)
    presolve(solver.net, args)
    # (iter_num, value) pairs collected for plotting after training.
    train_iou, val_iou = [], []
    train_p, val_p = [], []
    train_r, val_r = [], []
    train_f, val_f = [], []
    for iter_num in range(max_iters + 1):
        set_input_data(solver.net, train_data, args)
        solver.step(1)
        # Evaluate on the validation set at every snapshot interval
        # (skipping iteration 0).
        if iter_num and iter_num % snapshot_interval == 0:
            print("Validation Prediction: %d" % iter_num)
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, val_data, args)
            val_iou.append((iter_num, avg_iou))
            val_p.append((iter_num, avg_p))
            val_r.append((iter_num, avg_r))
            val_f.append((iter_num, avg_f))
            if args.debug_dir:
                print("Dumping images")
                out_dir = os.path.join(args.debug_dir, 'val_%d' % iter_num)
                dump_debug(out_dir, val_data)
        # Evaluate on the training set, less frequently and only after
        # min_interval iterations. Assumed to be at loop level (not nested
        # in the validation branch) -- TODO confirm.
        if iter_num >= args.min_interval and iter_num % args.gt_interval == 0:
            print("Train Prediction: %d" % iter_num)
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, train_data, args)
            train_iou.append((iter_num, avg_iou))
            train_p.append((iter_num, avg_p))
            train_r.append((iter_num, avg_r))
            train_f.append((iter_num, avg_f))
    print("Train IOU: ", train_iou)
    print()
    print("Val IOU: ", val_iou)
    if args.debug_dir:
        # Plot train/val curves for each metric. NOTE(review): if a metric
        # list is empty (max_iters below the intervals), zip(*[]) raises.
        plt.plot(*list(zip(*train_iou)), label='train')
        plt.plot(*list(zip(*val_iou)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'iou.png'))
        plt.clf()
        plt.plot(*list(zip(*train_p)), label='train')
        plt.plot(*list(zip(*val_p)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'precision.png'))
        plt.clf()
        plt.plot(*list(zip(*train_r)), label='train')
        plt.plot(*list(zip(*val_r)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'recall.png'))
        plt.clf()
        plt.plot(*list(zip(*train_f)), label='train')
        plt.plot(*list(zip(*val_f)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'fmeasure.png'))
        # Final full-dataset predictions dumped as images.
        _ = update_predictions(solver.net, train_data, args)
        out_dir = os.path.join(args.debug_dir, 'train_final')
        dump_debug(out_dir, train_data, True)
        _ = update_predictions(solver.net, val_data, args)
        out_dir = os.path.join(args.debug_dir, 'val_final')
        dump_debug(out_dir, val_data, True)
        # Persist raw metric series as text. Assumed inside the debug_dir
        # branch since the paths are rooted there -- TODO confirm.
        for name, vals in zip(['train_iou', 'val_iou', 'train_p', 'val_p',
                               'train_r', 'val_r', 'train_f', 'val_f'],
                              [train_iou, val_iou, train_p, val_p,
                               train_r, val_r, train_f, val_f]):
            fd = open(os.path.join(args.debug_dir, "%s.txt" % name), 'w')
            fd.write('%r\n' % vals)
            fd.close()
  191. def get_args():
  192. parser = argparse.ArgumentParser(description="Outputs binary predictions")
  193. parser.add_argument("solver_file",
  194. help="The solver.prototxt")
  195. parser.add_argument("dataset_dir",
  196. help="The dataset to be evaluated")
  197. parser.add_argument("train_manifest",
  198. help="txt file listing images to train on")
  199. parser.add_argument("val_manifest",
  200. help="txt file listing images for validation")
  201. parser.add_argument("--gpu", type=int, default=0,
  202. help="GPU to use for running the network")
  203. parser.add_argument("-m", "--mean", type=float, default=127.,
  204. help="Mean value for data preprocessing")
  205. parser.add_argument("-s", "--scale", type=float, default=1.,
  206. help="Optional pixel scale factor")
  207. parser.add_argument("-b", "--batch-size", default=2, type=int,
  208. help="Training batch size")
  209. parser.add_argument("-c", "--color", default=False, action='store_true',
  210. help="Training batch size")
  211. parser.add_argument("--image-size", default=256, type=int,
  212. help="Size of images for input to training/prediction")
  213. parser.add_argument("--gt-interval", default=5000, type=int,
  214. help="Interval for Debug")
  215. parser.add_argument("--min-interval", default=5000, type=int,
  216. help="Miniumum iteration for Debug")
  217. parser.add_argument("--debug-dir", default='debug', type=str,
  218. help="Dump images for debugging")
  219. parser.add_argument("--print-count", default=10, type=int,
  220. help="How often to print progress")
  221. args = parser.parse_args()
  222. print(args)
  223. return args
if __name__ == "__main__":
    args = get_args()
    # A negative --gpu selects CPU mode; otherwise bind caffe to the
    # requested GPU device before training.
    if args.gpu >= 0:
        caffe.set_device(args.gpu)
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()
    main(args)