Page boundary detection in historical documents
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

135 lines
3.6KB

  1. #from grabCutCrop import cropMask
  2. from multiprocessing.dummy import Pool as ThreadPool
  3. import cv2
  4. import numpy as np
  5. import sys
  6. def cropMask(imageDir,saveDir,imagePath, border):
  7. interName = saveDir+'grabCut_'+imagePath.replace('/','_')+'.png'
  8. mask=cv2.imread(interName,0)
  9. if mask is None or mask.shape[0]==0:
  10. img = cv2.imread(imageDir+imagePath)
  11. if img is None:
  12. print "Could not read: "+imageDir+imagePath
  13. mask = np.zeros(img.shape[:2],np.uint8)
  14. bgdModel = np.zeros((1,65),np.float64)
  15. fgdModel = np.zeros((1,65),np.float64)
  16. rect = (border,border,img.shape[1]-2*border,img.shape[0]-2*border) #(x,y,w,h)
  17. cv2.grabCut(img,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_RECT)
  18. cv2.imwrite(interName,mask)
  19. mask2 = np.where((mask==2)|(mask==0),0,1).astype('uint8')
  20. revMask = 1-mask2
  21. #fix the mask so that there is only one CC
  22. num, labels, stats, centiods = cv2.connectedComponentsWithStats(revMask, 8, cv2.CV_32S)
  23. maxLabel=None
  24. maxSize=0
  25. for l in range(1,num):
  26. if stats[l,cv2.CC_STAT_AREA] > maxSize:
  27. maxSize = stats[l,cv2.CC_STAT_AREA]
  28. maxLabel = l
  29. return (labels!=maxLabel).astype('uint8')
  30. if len(sys.argv)<5:
  31. print 'This evalutes using grab cut'
  32. print 'usage: '+sys.argv[0]+' gtFile.csv imageDir saveIntermDir numThreads' # [reverse]'
  33. exit(0)
  34. gtFile=sys.argv[1]
  35. imageDir=sys.argv[2]
  36. if imageDir[-1]!='/':
  37. imageDir+='/'
  38. saveDir=sys.argv[3]
  39. if saveDir[-1]!='/':
  40. saveDir+='/'
  41. numThreads = int(sys.argv[4])
  42. reverse=False
  43. #if len(sys.argv)>5 and sys.argv[5][0]=='r':
  44. # reverse=True
  45. sumIOU=0
  46. countIOU=0
  47. scale=1
  48. print 'eval on '+gtFile
  49. outFile = gtFile+'_fullgrab.res'
  50. #numLines=0
  51. #try:
  52. # with open(outFile,'r') as f:
  53. # numLines = len(f.readlines())
  54. #except IOError:
  55. # numLines=0
  56. try:
  57. out = open(outFile,'a')
  58. except IOError:
  59. out = open(outFile,'w')
  60. def worker(line):
  61. global imageDir, saveDir, reverse
  62. try:
  63. p = line.split(',')
  64. imagePath = p[0]
  65. x1 = float(p[1])
  66. y1 = float(p[2])
  67. x2 = float(p[3])
  68. y2 = float(p[4])
  69. x3 = float(p[5])
  70. y3 = float(p[6])
  71. x4 = float(p[7])
  72. y4 = float(p[8])
  73. if reverse:
  74. tmpX=x4
  75. tmpY=y4
  76. x4=x3
  77. y4=y3
  78. x3=tmpX
  79. y4=tmpY
  80. #type = p[9]
  81. if x1<0:
  82. return None,None
  83. #cc+=1
  84. #if cc<=numLines:
  85. # return None, None
  86. #image = cv2.imread(imageDir+imagePath)
  87. #if image.shape[0]==0:
  88. # print 'failed to open '+imageDir+imagePath
  89. #image = cv2.resize(image,(0,0),None,scale,scale)
  90. mask = cropMask(imageDir,saveDir,imagePath,5)
  91. gtMask = np.zeros(mask.shape,np.uint8)
  92. cv2.fillConvexPoly(gtMask, np.array([[x1,y1], [x2,y2], [x3,y3], [x4,y4]], np.int32), 1, 8)
  93. intersection = np.sum(mask&gtMask)
  94. union = np.sum(mask|gtMask)
  95. if float(intersection)/union < 0.6:
  96. cv2.imwrite(saveDir+"ERR_"+imagePath.replace('/','_')+'.png',mask*255);
  97. return imagePath, float(intersection)/union
  98. except:
  99. print 'Error: '
  100. print sys.exc_info()
  101. with open(gtFile) as f:
  102. pool = ThreadPool(numThreads)
  103. results = pool.map(worker, f.readlines())
  104. for (imagePath, iou) in results:
  105. if imagePath is not None:
  106. sumIOU += iou
  107. out.write(imagePath+' '+str(iou)+'\n')
  108. countIOU += 1
  109. out.write('mean IOU for '+gtFile+': '+str(sumIOU/countIOU)+'\n')
  110. out.close()
  111. print 'mean IOU for '+gtFile+': '+str(sumIOU/countIOU)