Page boundary detection in historical documents
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
3.8KB

  1. import cv2
  2. import os
  3. import numpy as np
  4. import sys
  5. def draw_poly(img, bounding_poly):
  6. pts = np.array(bounding_poly, np.int32)
  7. #http://stackoverflow.com/a/15343106/3479446
  8. mask = np.zeros(img.shape[:2], dtype=np.uint8)
  9. roi_corners = np.array([pts], dtype=np.int32)
  10. ignore_mask_color = (255,)
  11. cv2.fillPoly(mask, roi_corners, ignore_mask_color, lineType=cv2.LINE_8)
  12. return mask
  13. def post_process(img):
  14. # img = open_close(img)
  15. img = get_largest_cc(img)
  16. img = fill_holes(img)
  17. # img = min_area_rectangle(img)
  18. img, coords = improve_min_area_rectangle(img)
  19. return img, coords
  20. def open_close(img):
  21. kernel = np.ones((3,3),np.uint8)
  22. erosion = cv2.erode(img,kernel,iterations = 15)
  23. dilation = cv2.dilate(erosion,kernel,iterations = 15)
  24. return dilation
  25. def get_largest_cc(img):
  26. img = img.copy()
  27. ret, th = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
  28. connectivity = 4
  29. output= cv2.connectedComponentsWithStats(th, connectivity, cv2.CV_32S)
  30. cnts = output[2][1:,4]
  31. largest = cnts.argmax() + 1
  32. img[output[1] != largest] = 0
  33. return img
  34. def get_iou(gt_img, pred_img):
  35. inter = gt_img & pred_img
  36. union = gt_img | pred_img
  37. iou = np.count_nonzero(inter) / float(np.count_nonzero(union))
  38. return iou
  39. def draw_box(img, box):
  40. box = np.int0(box)
  41. draw = np.zeros_like(img)
  42. cv2.drawContours(draw,[box],0,(255),-1)
  43. return draw
  44. def compute_iou(img, box):
  45. # box = np.int0(box)
  46. # draw = np.zeros_like(img)
  47. # cv2.drawContours(draw,[box],0,(255),-1)
  48. draw = draw_box(img, box)
  49. v = get_iou(img, draw)
  50. return v
  51. def step_box(img, box, step_size=1):
  52. best_val = -1
  53. best_box = None
  54. for index, x in np.ndenumerate(box):
  55. for d in [-step_size, step_size]:
  56. alt_box = box.copy()
  57. alt_box[index] = x + d
  58. v = compute_iou(img, alt_box)
  59. if best_val < v:
  60. best_val = v
  61. best_box = alt_box
  62. return best_val, best_box
  63. def improve_min_area_rectangle(img):
  64. img = img.copy()
  65. _, contours,_ = cv2.findContours(img, 1, 2)
  66. cnt = contours[0]
  67. rect = cv2.minAreaRect(cnt)
  68. box = cv2.boxPoints(rect)
  69. best_val = compute_iou(img, box)
  70. best_box = box
  71. while True:
  72. new_val, new_box = step_box(img, best_box, step_size=1)
  73. # print new_val
  74. if new_val <= best_val:
  75. break
  76. best_val = new_val
  77. best_box = new_box
  78. return draw_box(img, best_box), best_box
  79. def min_area_rectangle(img):
  80. img = img.copy()
  81. _, contours,_ = cv2.findContours(img, 1, 2)
  82. cnt = contours[0]
  83. rect = cv2.minAreaRect(cnt)
  84. box = cv2.boxPoints(rect)
  85. box = np.int0(box)
  86. draw = np.zeros_like(img)
  87. cv2.drawContours(draw,[box],0,(255),-1)
  88. return draw
  89. def fill_holes(img):
  90. im_th = img.copy()
  91. # Copy the thresholded image.
  92. im_floodfill = im_th.copy()
  93. # Mask used to flood filling.
  94. # Notice the size needs to be 2 pixels than the image.
  95. h, w = im_th.shape[:2]
  96. mask = np.zeros((h+2, w+2), np.uint8)
  97. # Floodfill from point (0, 0)
  98. if img[0,0] != 0:
  99. print "WARNING: Filling something you shouldn't"
  100. cv2.floodFill(im_floodfill, mask, (0,0), 255);
  101. # Invert floodfilled image
  102. im_floodfill_inv = cv2.bitwise_not(im_floodfill)
  103. # Combine the two images to get the foreground.
  104. im_out = im_th | im_floodfill_inv
  105. return im_out
  106. if __name__ == "__main__":
  107. pred_folder = sys.argv[1]
  108. out_folder = sys.argv[2]
  109. pred_imgs = {}
  110. for root, folders, files in os.walk(pred_folder):
  111. for f in files:
  112. if f.endswith(".png"):
  113. pred_imgs[f] = os.path.join(root, f)
  114. for k in pred_imgs:
  115. pred_img = cv2.imread(pred_imgs[k], 0)
  116. post_img = post_process(pred_img)
  117. cv2.imwrite(os.path.join(out_folder, k), post_img)