Page boundary detection in historical documents

#!/usr/bin/python
import os
import sys
import collections
import argparse
import numpy as np
import matplotlib
matplotlib.use("AGG")
import matplotlib.pyplot as plt
import caffe
import cv2
import random
import scipy.ndimage as nd

def safe_mkdir(_dir):
    try:
        os.makedirs(_dir)
    except OSError:
        # Directory already exists (or cannot be created); ignore.
        pass

def dump_debug(out_dir, data, dump_images=False):  # NOTE: dump_images is currently unused
    pred_image_dir = os.path.join(out_dir, 'pred_images')
    safe_mkdir(pred_image_dir)
    for idx in xrange(len(data['images'])):
        fn = data['filenames'][idx]
        preds = data['predictions'][idx]
        # Flatten the relative path into a single filename and drop the extension
        fn_base = os.path.splitext(fn.replace('/', '_'))[0]
        out_fn = os.path.join(pred_image_dir, fn_base + ".png")
        # Predictions are binary {0,1}; scale to {0,255} so they are visible
        cv2.imwrite(out_fn, 255 * preds)

def predict(network, im, output_blob, args):
    if im.ndim > 2:
        # HWC -> CHW, as caffe expects
        im = np.transpose(im, axes=(2, 0, 1))
    network.blobs["data"].data[0,:,:,:] = im
    network.forward()
    response = network.blobs[output_blob].data[0,:].copy()
    # Per-pixel argmax over classes; cast to uint8 so cv2.resize/imwrite accept it
    return np.argmax(response, axis=0).astype(np.uint8)

def iou(im1, im2):
    num_intersect = np.sum(np.logical_and(im1, im2))
    num_union = num_intersect + np.sum(np.logical_xor(im1, im2))
    return float(num_intersect) / num_union

def prf(im1, im2):
    num_intersect = np.sum(np.logical_and(im1, im2))
    num_1 = np.sum(im1)
    num_2 = np.sum(im2)
    # Guard against empty masks (e.g. an all-background prediction early in training)
    p = num_intersect / float(num_1) if num_1 else 0
    r = num_intersect / float(num_2) if num_2 else 0
    f = (2 * p * r) / (p + r) if (p + r) else 0
    return p, r, f
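
# Sanity check of the metrics on toy masks (hypothetical example, not from the dataset):
#   a = np.array([[1, 1], [0, 0]]); b = np.array([[1, 0], [1, 0]])
#   iou(a, b) -> 1/3  (1 pixel in common, 3 pixels in the union)
#   prf(a, b) -> (0.5, 0.5, 0.5)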
def update_predictions(net, data, args):
    print "Starting Predictions"
    total_iou = 0
    total_p = 0
    total_r = 0
    total_f = 0
    for idx in xrange(len(data['images'])):
        im = cv2.resize(data['images'][idx], (args.image_size, args.image_size))
        outputs = predict(net, im, 'out', args)
        data['predictions'][idx] = outputs.copy()
        # Score against the ground truth at the original image resolution
        width, height = data['original_size'][idx]
        outputs = cv2.resize(outputs, (width, height), interpolation=cv2.INTER_NEAREST)
        total_iou += iou(outputs, data['original_gt'][idx])
        p, r, f = prf(outputs, data['original_gt'][idx])
        total_p += p
        total_r += r
        total_f += f
        if idx and idx % args.print_count == 0:
            print "\tPredicted %d/%d" % (idx, len(data['images']))
    avg_iou = total_iou / len(data['images'])
    avg_p = total_p / len(data['images'])
    avg_r = total_r / len(data['images'])
    avg_f = total_f / len(data['images'])
    return avg_iou, avg_p, avg_r, avg_f

def load_data(manifest, _dir, size, color=False):
    dataset = collections.defaultdict(list)
    file_list = map(lambda s: s.strip(), open(manifest, 'r').readlines())
    for line in file_list:
        tokens = line.split(',')
        f = tokens[0]
        coords = map(float, tokens[1:9])
        dataset['filenames'].append(f)
        resolved = os.path.join(_dir, f)
        im = cv2.imread(resolved, 1 if color else 0)
        # Check the read before touching im.shape
        if im is None:
            raise Exception("Error loading %s" % resolved)
        # Rasterize the 4-corner annotation into a binary mask
        gt = np.zeros(im.shape[:2], dtype=np.uint8)
        cv2.fillPoly(gt, np.array(coords).reshape((4, 2)).astype(np.int32)[np.newaxis,:,:], 1)
        height, width = im.shape[:2]
        im = cv2.resize(im, (size, size))
        dataset['original_gt'].append(gt)
        gt = cv2.resize(gt, (size, size), interpolation=cv2.INTER_NEAREST)
        dataset['images'].append(im)
        dataset['original_size'].append((width, height))  # opencv uses (w, h)
        dataset['gt'].append(gt)
    return dataset
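
# Each manifest line is expected to be CSV of the form (values hypothetical):
#   images/page_0001.jpg,12.0,8.5,980.0,10.0,975.5,1430.0,15.0,1425.0
# i.e. an image path followed by the (x, y) coordinates of the page's four
# corners, which load_data rasterizes into a binary ground-truth mask.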
def preprocess_data(data, args):
    for idx in xrange(len(data['images'])):
        im = data['images'][idx]
        # Mean-subtract, then scale
        im = args.scale * (im - args.mean)
        data['images'][idx] = im
        gt = data['gt'][idx]
        # Seed 'predictions' with the ground truth so every image has an
        # entry before the first call to update_predictions
        data['predictions'].append(gt.copy())

def get_solver_params(f):
    max_iters = 0
    snapshot = 0
    for line in open(f).readlines():
        tokens = line.split()
        if not tokens:
            continue  # skip blank lines
        if tokens[0] == 'max_iter:':
            max_iters = int(tokens[1])
        if tokens[0] == 'snapshot:':
            snapshot = int(tokens[1])
    return max_iters, snapshot
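
# get_solver_params only reads two fields of solver.prototxt, e.g.
# (values hypothetical):
#   max_iter: 20000
#   snapshot: 5000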
def presolve(net, args):
    net.blobs["data"].reshape(args.batch_size, 3 if args.color else 1, args.image_size, args.image_size)
    net.blobs["gt"].reshape(args.batch_size, 1, args.image_size, args.image_size)

def set_input_data(net, data, args):
    for batch_idx in xrange(args.batch_size):
        im_idx = random.randint(0, len(data['images']) - 1)
        im = data['images'][im_idx]
        gt = data['gt'][im_idx]
        if im.ndim > 2:
            im = np.transpose(im, (2, 0, 1))
        net.blobs["data"].data[batch_idx,:,:,:] = im
        net.blobs["gt"].data[batch_idx,0,:,:] = gt

def main(args):
    train_data = load_data(args.train_manifest, args.dataset_dir, args.image_size, args.color)
    val_data = load_data(args.val_manifest, args.dataset_dir, args.image_size, args.color)
    preprocess_data(train_data, args)
    preprocess_data(val_data, args)
    print "Done loading data"
    solver = caffe.SGDSolver(args.solver_file)
    max_iters, snapshot_interval = get_solver_params(args.solver_file)
    presolve(solver.net, args)
    train_iou, val_iou = [], []
    train_p, val_p = [], []
    train_r, val_r = [], []
    train_f, val_f = [], []
    for iter_num in xrange(max_iters + 1):
        set_input_data(solver.net, train_data, args)
        solver.step(1)
        # Evaluate on the validation set at every snapshot interval
        # (guard against snapshot_interval == 0 if the solver file omits it)
        if iter_num and snapshot_interval and iter_num % snapshot_interval == 0:
            print "Validation Prediction: %d" % iter_num
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, val_data, args)
            val_iou.append((iter_num, avg_iou))
            val_p.append((iter_num, avg_p))
            val_r.append((iter_num, avg_r))
            val_f.append((iter_num, avg_f))
            if args.debug_dir:
                print "Dumping images"
                out_dir = os.path.join(args.debug_dir, 'val_%d' % iter_num)
                dump_debug(out_dir, val_data)
        # Periodically evaluate on the training set as well
        if iter_num >= args.min_interval and iter_num % args.gt_interval == 0:
            print "Train Prediction: %d" % iter_num
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, train_data, args)
            train_iou.append((iter_num, avg_iou))
            train_p.append((iter_num, avg_p))
            train_r.append((iter_num, avg_r))
            train_f.append((iter_num, avg_f))
    print "Train IOU: ", train_iou
    print
    print "Val IOU: ", val_iou
    if args.debug_dir:
        # Plot train/val curves for each metric
        for metric, train_vals, val_vals in [('iou', train_iou, val_iou),
                                             ('precision', train_p, val_p),
                                             ('recall', train_r, val_r),
                                             ('fmeasure', train_f, val_f)]:
            plt.plot(*zip(*train_vals), label='train')
            plt.plot(*zip(*val_vals), label='val')
            plt.legend()
            plt.savefig(os.path.join(args.debug_dir, '%s.png' % metric))
            plt.clf()
        # Dump final predictions for both splits
        _ = update_predictions(solver.net, train_data, args)
        out_dir = os.path.join(args.debug_dir, 'train_final')
        dump_debug(out_dir, train_data, True)
        _ = update_predictions(solver.net, val_data, args)
        out_dir = os.path.join(args.debug_dir, 'val_final')
        dump_debug(out_dir, val_data, True)
        # Save the raw metric curves as text files
        for name, vals in zip(['train_iou', 'val_iou', 'train_p', 'val_p',
                               'train_r', 'val_r', 'train_f', 'val_f'],
                              [train_iou, val_iou, train_p, val_p,
                               train_r, val_r, train_f, val_f]):
            fd = open(os.path.join(args.debug_dir, "%s.txt" % name), 'w')
            fd.write('%r\n' % vals)
            fd.close()

def get_args():
    parser = argparse.ArgumentParser(description="Trains a page boundary detector and outputs binary predictions")
    parser.add_argument("solver_file",
                        help="The solver.prototxt")
    parser.add_argument("dataset_dir",
                        help="The dataset to be evaluated")
    parser.add_argument("train_manifest",
                        help="txt file listing images to train on")
    parser.add_argument("val_manifest",
                        help="txt file listing images for validation")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU to use for running the network")
    parser.add_argument("-m", "--mean", type=float, default=127.,
                        help="Mean value for data preprocessing")
    parser.add_argument("-s", "--scale", type=float, default=1.,
                        help="Optional pixel scale factor")
    parser.add_argument("-b", "--batch-size", default=2, type=int,
                        help="Training batch size")
    parser.add_argument("-c", "--color", default=False, action='store_true',
                        help="Load images in color (3 channels) instead of grayscale")
    parser.add_argument("--image-size", default=256, type=int,
                        help="Size of images for input to training/prediction")
    parser.add_argument("--gt-interval", default=5000, type=int,
                        help="Iteration interval between evaluations on the training set")
    parser.add_argument("--min-interval", default=5000, type=int,
                        help="Minimum iteration before training-set evaluations start")
    parser.add_argument("--debug-dir", default='debug', type=str,
                        help="Directory for debug images and metric plots")
    parser.add_argument("--print-count", default=10, type=int,
                        help="Print progress every N images during prediction")
    args = parser.parse_args()
    print args
    return args

if __name__ == "__main__":
    args = get_args()
    if args.gpu >= 0:
        caffe.set_device(args.gpu)
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()
    main(args)
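
A typical invocation (the script filename here is hypothetical; the positional arguments are the solver file, the dataset root, and the train/val manifests):

    python train_page_boundary.py solver.prototxt data/ train.txt val.txt --gpu 0 --color --debug-dir debug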