Page boundary detection in historical documents


#!/usr/bin/python
import os
import sys
import collections
import argparse
import random

import numpy as np
import matplotlib
matplotlib.use("AGG")  # file-based backend; no display required
import matplotlib.pyplot as plt
import caffe
import cv2


def safe_mkdir(_dir):
    """Create a directory, ignoring the error if it already exists."""
    try:
        os.makedirs(_dir)
    except OSError:
        pass

def dump_debug(out_dir, data, dump_images=False):
    """Write each prediction mask to out_dir as a binary PNG.

    NOTE: dump_images is accepted for call-site compatibility but is
    currently unused; masks are always written.
    """
    pred_image_dir = os.path.join(out_dir, 'pred_images')
    safe_mkdir(pred_image_dir)
    for idx in range(len(data['images'])):
        fn = data['filenames'][idx]
        preds = data['predictions'][idx]
        fn_base = os.path.splitext(fn.replace('/', '_'))[0]
        out_fn = os.path.join(pred_image_dir, fn_base + ".png")
        cv2.imwrite(out_fn, 255 * preds)  # scale {0,1} labels to {0,255}

def predict(network, im, output_blob, args):
    """Run one image through the network and return the per-pixel argmax."""
    if im.ndim > 2:
        im = np.transpose(im, axes=(2, 0, 1))  # HWC -> CHW for Caffe
    network.blobs["data"].data[0, :, :, :] = im
    network.forward()
    response = network.blobs[output_blob].data[0, :].copy()
    # cast to uint8 so cv2.resize/cv2.imwrite accept the mask downstream
    return np.argmax(response, axis=0).astype(np.uint8)

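# Shape sketch (assuming a two-class network and the default 256x256 input):
# the "data" blob is (batch, channels, 256, 256) and the "out" blob is
# (batch, 2, 256, 256), so the argmax over axis 0 collapses the class axis
# into a (256, 256) mask of {0, 1} labels.
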
def iou(im1, im2):
    """Intersection-over-union of two binary masks."""
    num_intersect = np.sum(np.logical_and(im1, im2))
    num_union = num_intersect + np.sum(np.logical_xor(im1, im2))
    return float(num_intersect) / num_union

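# Worked example (hypothetical masks): im1 = [[1, 1], [0, 0]] and
# im2 = [[1, 0], [1, 0]] agree on one foreground pixel and disagree on two,
# so iou(im1, im2) = 1 / (1 + 2) = 0.333...
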
def prf(im1, im2):
    """Precision, recall, and F-measure of prediction im1 against truth im2."""
    num_intersect = np.sum(np.logical_and(im1, im2))
    num_1 = np.sum(im1)
    num_2 = np.sum(im2)
    p = num_intersect / float(num_1) if num_1 else 0  # guard empty prediction
    r = num_intersect / float(num_2) if num_2 else 0  # guard empty truth
    f = (2 * p * r) / (p + r) if (p + r) else 0
    return p, r, f

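# Worked example with the masks above: intersection = 1, sum(im1) = 2,
# sum(im2) = 2, so p = r = 0.5 and f = (2 * 0.5 * 0.5) / 1.0 = 0.5.
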
def update_predictions(net, data, args):
    """Predict every image in data and average IoU/P/R/F at full resolution."""
    print("Starting Predictions")
    total_iou = 0
    total_p = 0
    total_r = 0
    total_f = 0
    for idx in range(len(data['images'])):
        im = cv2.resize(data['images'][idx], (args.image_size, args.image_size))
        outputs = predict(net, im, 'out', args)
        data['predictions'][idx] = outputs.copy()
        # score against the ground truth at the original image resolution
        width, height = data['original_size'][idx]
        outputs = cv2.resize(outputs, (width, height), interpolation=cv2.INTER_NEAREST)
        total_iou += iou(outputs, data['original_gt'][idx])
        p, r, f = prf(outputs, data['original_gt'][idx])
        total_p += p
        total_r += r
        total_f += f
        if idx and idx % args.print_count == 0:
            print("\tPredicted %d/%d" % (idx, len(data['images'])))
    avg_iou = total_iou / len(data['images'])
    avg_p = total_p / len(data['images'])
    avg_r = total_r / len(data['images'])
    avg_f = total_f / len(data['images'])
    return avg_iou, avg_p, avg_r, avg_f

def load_data(manifest, _dir, size, color=False):
    """Load images and quadrilateral ground truth listed in a manifest.

    Each manifest line is: filename,x1,y1,x2,y2,x3,y3,x4,y4
    """
    dataset = collections.defaultdict(list)
    file_list = [s.strip() for s in open(manifest, 'r').readlines()]
    for line in file_list:
        tokens = line.split(',')
        f = tokens[0]
        coords = list(map(float, tokens[1:9]))
        dataset['filenames'].append(f)
        resolved = os.path.join(_dir, f)
        im = cv2.imread(resolved, 1 if color else 0)
        if im is None:  # check before touching im.shape
            raise Exception("Error loading %s" % resolved)
        # rasterize the four page corners into a binary mask
        gt = np.zeros(im.shape[:2], dtype=np.uint8)
        cv2.fillPoly(gt, np.array(coords).reshape((4, 2)).astype(np.int32)[np.newaxis, :, :], 1)
        height, width = im.shape[:2]
        im = cv2.resize(im, (size, size))
        dataset['original_gt'].append(gt)
        gt = cv2.resize(gt, (size, size), interpolation=cv2.INTER_NEAREST)
        dataset['images'].append(im)
        dataset['original_size'].append((width, height))  # opencv does (w, h)
        dataset['gt'].append(gt)
    return dataset

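# A manifest line might look like (hypothetical path and corner coordinates):
#   scans/book1/page_042.jpg,10.0,12.0,950.0,8.0,955.0,1400.0,12.0,1404.0
# i.e. an image path relative to dataset_dir followed by the four page
# corners as x,y pairs.
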
def preprocess_data(data, args):
    """Mean-shift and scale the pixels; seed predictions with GT-shaped copies."""
    for idx in range(len(data['images'])):
        im = data['images'][idx]
        im = args.scale * (im - args.mean)
        data['images'][idx] = im
        gt = data['gt'][idx]
        data['predictions'].append(gt.copy())

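# With the defaults (mean=127, scale=1) a pixel value of 255 maps to
# 1 * (255 - 127) = 128 and 0 maps to -127, roughly centering inputs on zero.
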
def get_solver_params(f):
    """Parse max_iter and the snapshot interval out of a solver.prototxt."""
    max_iters = 0
    snapshot = 0
    for line in open(f).readlines():
        tokens = line.split()
        if not tokens:
            continue  # skip blank lines
        if tokens[0] == 'max_iter:':
            max_iters = int(tokens[1])
        if tokens[0] == 'snapshot:':
            snapshot = int(tokens[1])
    return max_iters, snapshot

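# The parser expects Caffe solver lines of this form (hypothetical values):
#   max_iter: 100000
#   snapshot: 5000
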
def presolve(net, args):
    """Reshape the input blobs to the configured batch and image size."""
    net.blobs["data"].reshape(args.batch_size, 3 if args.color else 1,
                              args.image_size, args.image_size)
    net.blobs["gt"].reshape(args.batch_size, 1, args.image_size, args.image_size)


def set_input_data(net, data, args):
    """Fill the input blobs with a randomly sampled training batch."""
    for batch_idx in range(args.batch_size):
        im_idx = random.randint(0, len(data['images']) - 1)
        im = data['images'][im_idx]
        gt = data['gt'][im_idx]
        if im.ndim > 2:
            im = np.transpose(im, (2, 0, 1))  # HWC -> CHW
        net.blobs["data"].data[batch_idx, :, :, :] = im
        net.blobs["gt"].data[batch_idx, 0, :, :] = gt

def main(args):
    train_data = load_data(args.train_manifest, args.dataset_dir, args.image_size, args.color)
    val_data = load_data(args.val_manifest, args.dataset_dir, args.image_size, args.color)
    preprocess_data(train_data, args)
    preprocess_data(val_data, args)
    print("Done loading data")

    solver = caffe.SGDSolver(args.solver_file)
    max_iters, snapshot_interval = get_solver_params(args.solver_file)
    presolve(solver.net, args)

    train_iou, val_iou = [], []
    train_p, val_p = [], []
    train_r, val_r = [], []
    train_f, val_f = [], []
    for iter_num in range(max_iters + 1):
        set_input_data(solver.net, train_data, args)
        solver.step(1)
        # evaluate on the validation set at every snapshot
        if iter_num and iter_num % snapshot_interval == 0:
            print("Validation Prediction: %d" % iter_num)
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, val_data, args)
            val_iou.append((iter_num, avg_iou))
            val_p.append((iter_num, avg_p))
            val_r.append((iter_num, avg_r))
            val_f.append((iter_num, avg_f))
            if args.debug_dir:
                print("Dumping images")
                out_dir = os.path.join(args.debug_dir, 'val_%d' % iter_num)
                dump_debug(out_dir, val_data)
        # periodically evaluate on the training set as well
        if iter_num >= args.min_interval and iter_num % args.gt_interval == 0:
            print("Train Prediction: %d" % iter_num)
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, train_data, args)
            train_iou.append((iter_num, avg_iou))
            train_p.append((iter_num, avg_p))
            train_r.append((iter_num, avg_r))
            train_f.append((iter_num, avg_f))

    print("Train IOU: ", train_iou)
    print()
    print("Val IOU: ", val_iou)
    if args.debug_dir:
        # plot metric curves over training iterations
        plt.plot(*zip(*train_iou), label='train')
        plt.plot(*zip(*val_iou), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'iou.png'))
        plt.clf()
        plt.plot(*zip(*train_p), label='train')
        plt.plot(*zip(*val_p), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'precision.png'))
        plt.clf()
        plt.plot(*zip(*train_r), label='train')
        plt.plot(*zip(*val_r), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'recall.png'))
        plt.clf()
        plt.plot(*zip(*train_f), label='train')
        plt.plot(*zip(*val_f), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'fmeasure.png'))

        # dump final predictions for both splits
        _ = update_predictions(solver.net, train_data, args)
        out_dir = os.path.join(args.debug_dir, 'train_final')
        dump_debug(out_dir, train_data, True)
        _ = update_predictions(solver.net, val_data, args)
        out_dir = os.path.join(args.debug_dir, 'val_final')
        dump_debug(out_dir, val_data, True)

        # record the raw metric series as text files
        for name, vals in zip(['train_iou', 'val_iou', 'train_p', 'val_p',
                               'train_r', 'val_r', 'train_f', 'val_f'],
                              [train_iou, val_iou, train_p, val_p,
                               train_r, val_r, train_f, val_f]):
            with open(os.path.join(args.debug_dir, "%s.txt" % name), 'w') as fd:
                fd.write('%r\n' % vals)

def get_args():
    parser = argparse.ArgumentParser(description="Outputs binary predictions")
    parser.add_argument("solver_file",
                        help="The solver.prototxt")
    parser.add_argument("dataset_dir",
                        help="The dataset to be evaluated")
    parser.add_argument("train_manifest",
                        help="txt file listing images to train on")
    parser.add_argument("val_manifest",
                        help="txt file listing images for validation")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU to use for running the network")
    parser.add_argument("-m", "--mean", type=float, default=127.,
                        help="Mean value for data preprocessing")
    parser.add_argument("-s", "--scale", type=float, default=1.,
                        help="Optional pixel scale factor")
    parser.add_argument("-b", "--batch-size", default=2, type=int,
                        help="Training batch size")
    parser.add_argument("-c", "--color", default=False, action='store_true',
                        help="Load images in color instead of grayscale")
    parser.add_argument("--image-size", default=256, type=int,
                        help="Size of images for input to training/prediction")
    parser.add_argument("--gt-interval", default=5000, type=int,
                        help="Interval between training-set evaluations")
    parser.add_argument("--min-interval", default=5000, type=int,
                        help="Minimum iteration before training-set evaluation")
    parser.add_argument("--debug-dir", default='debug', type=str,
                        help="Directory for debug images and metric dumps")
    parser.add_argument("--print-count", default=10, type=int,
                        help="How often to print progress")
    args = parser.parse_args()
    print(args)
    return args

if __name__ == "__main__":
    args = get_args()
    if args.gpu >= 0:
        caffe.set_device(args.gpu)
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()
    main(args)
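
# Typical invocation (hypothetical file names; pass a negative --gpu to run
# on the CPU):
#   python train.py solver.prototxt data/ train.txt val.txt \
#       --image-size 256 --batch-size 2 --debug-dir debug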