Page boundary detection in historical documents
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

135 lignes
3.6KB

  1. #from grabCutCrop import cropMask
  2. from multiprocessing.dummy import Pool as ThreadPool
  3. import cv2
  4. import numpy as np
  5. import sys
  6. def cropMask(imageDir,saveDir,imagePath, border):
  7. interName = saveDir+'grabCut_'+imagePath.replace('/','_')+'.png'
  8. mask=cv2.imread(interName,0)
  9. if mask is None or mask.shape[0]==0:
  10. img = cv2.imread(imageDir+imagePath)
  11. if img is None:
  12. print "Could not read: "+imageDir+imagePath
  13. mask = np.zeros(img.shape[:2],np.uint8)
  14. bgdModel = np.zeros((1,65),np.float64)
  15. fgdModel = np.zeros((1,65),np.float64)
  16. rect = (border,border,img.shape[1]-2*border,img.shape[0]-2*border) #(x,y,w,h)
  17. cv2.grabCut(img,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_RECT)
  18. cv2.imwrite(interName,mask)
  19. mask2 = np.where((mask==2)|(mask==0),0,1).astype('uint8')
  20. revMask = 1-mask2
  21. #fix the mask so that there is only one CC
  22. num, labels, stats, centiods = cv2.connectedComponentsWithStats(revMask, 8, cv2.CV_32S)
  23. maxLabel=None
  24. maxSize=0
  25. for l in range(1,num):
  26. if stats[l,cv2.CC_STAT_AREA] > maxSize:
  27. maxSize = stats[l,cv2.CC_STAT_AREA]
  28. maxLabel = l
  29. return (labels!=maxLabel).astype('uint8')
  30. if len(sys.argv)<5:
  31. print 'This evalutes using grab cut'
  32. print 'usage: '+sys.argv[0]+' gtFile.csv imageDir saveIntermDir numThreads' # [reverse]'
  33. exit(0)
  34. gtFile=sys.argv[1]
  35. imageDir=sys.argv[2]
  36. if imageDir[-1]!='/':
  37. imageDir+='/'
  38. saveDir=sys.argv[3]
  39. if saveDir[-1]!='/':
  40. saveDir+='/'
  41. numThreads = int(sys.argv[4])
  42. reverse=False
  43. #if len(sys.argv)>5 and sys.argv[5][0]=='r':
  44. # reverse=True
  45. sumIOU=0
  46. countIOU=0
  47. scale=1
  48. print 'eval on '+gtFile
  49. outFile = gtFile+'_fullgrab.res'
  50. #numLines=0
  51. #try:
  52. # with open(outFile,'r') as f:
  53. # numLines = len(f.readlines())
  54. #except IOError:
  55. # numLines=0
  56. try:
  57. out = open(outFile,'a')
  58. except IOError:
  59. out = open(outFile,'w')
  60. def worker(line):
  61. global imageDir, saveDir, reverse
  62. try:
  63. p = line.split(',')
  64. imagePath = p[0]
  65. x1 = float(p[1])
  66. y1 = float(p[2])
  67. x2 = float(p[3])
  68. y2 = float(p[4])
  69. x3 = float(p[5])
  70. y3 = float(p[6])
  71. x4 = float(p[7])
  72. y4 = float(p[8])
  73. if reverse:
  74. tmpX=x4
  75. tmpY=y4
  76. x4=x3
  77. y4=y3
  78. x3=tmpX
  79. y4=tmpY
  80. #type = p[9]
  81. if x1<0:
  82. return None,None
  83. #cc+=1
  84. #if cc<=numLines:
  85. # return None, None
  86. #image = cv2.imread(imageDir+imagePath)
  87. #if image.shape[0]==0:
  88. # print 'failed to open '+imageDir+imagePath
  89. #image = cv2.resize(image,(0,0),None,scale,scale)
  90. mask = cropMask(imageDir,saveDir,imagePath,5)
  91. gtMask = np.zeros(mask.shape,np.uint8)
  92. cv2.fillConvexPoly(gtMask, np.array([[x1,y1], [x2,y2], [x3,y3], [x4,y4]], np.int32), 1, 8)
  93. intersection = np.sum(mask&gtMask)
  94. union = np.sum(mask|gtMask)
  95. if float(intersection)/union < 0.6:
  96. cv2.imwrite(saveDir+"ERR_"+imagePath.replace('/','_')+'.png',mask*255);
  97. return imagePath, float(intersection)/union
  98. except:
  99. print 'Error: '
  100. print sys.exc_info()
  101. with open(gtFile) as f:
  102. pool = ThreadPool(numThreads)
  103. results = pool.map(worker, f.readlines())
  104. for (imagePath, iou) in results:
  105. if imagePath is not None:
  106. sumIOU += iou
  107. out.write(imagePath+' '+str(iou)+'\n')
  108. countIOU += 1
  109. out.write('mean IOU for '+gtFile+': '+str(sumIOU/countIOU)+'\n')
  110. out.close()
  111. print 'mean IOU for '+gtFile+': '+str(sumIOU/countIOU)