|
- #!/usr/bin/python
-
- import os
- import sys
- import collections
- import argparse
- import numpy as np
- import matplotlib
- matplotlib.use("AGG")
- import matplotlib.pyplot as plt
- import caffe
- import cv2
- import random
-
-
-
def safe_mkdir(_dir):
    """Create directory `_dir` (and any missing parents).

    Succeeds silently when the directory already exists. Unlike the
    previous bare ``except: pass``, real failures (e.g. permission
    denied) now propagate instead of being silently swallowed.
    """
    os.makedirs(_dir, exist_ok=True)
-
-
def dump_debug(out_dir, data, dump_images=False):
    """Write every prediction mask in `data` as a PNG under out_dir/pred_images.

    Masks are scaled from {0, 1} to {0, 255} so they are visible as images.
    Filenames have path separators flattened to '_' and the 3-character
    extension stripped.

    NOTE(review): `dump_images` is accepted but never used — confirm intent.
    """
    image_dir = os.path.join(out_dir, 'pred_images')
    safe_mkdir(image_dir)

    for name, mask in zip(data['filenames'], data['predictions']):
        base = name.replace('/', '_')[:-4]  # assumes a 3-char extension — TODO confirm
        cv2.imwrite(os.path.join(image_dir, base + ".png"), 255 * mask)
-
-
def predict(network, im, output_blob, args):
    """Run one image through `network` and return its argmax label map.

    Color images arrive HWC and are transposed to CHW to match the blob
    layout. Returns a 2-D array of per-pixel class indices taken from
    the `output_blob` response of the first batch slot.
    """
    blob_input = im if im.ndim <= 2 else np.transpose(im, axes=(2, 0, 1))
    network.blobs["data"].data[0, :, :, :] = blob_input
    network.forward()

    scores = network.blobs[output_blob].data[0, :].copy()
    return np.argmax(scores, axis=0)
-
-
def iou(im1, im2):
    """Return the intersection-over-union of two binary masks.

    Any nonzero pixel counts as foreground. When both masks are empty
    the union is 0 and the masks agree perfectly, so 1.0 is returned
    instead of raising ZeroDivisionError.
    """
    num_intersect = np.sum(np.logical_and(im1, im2))
    num_union = num_intersect + np.sum(np.logical_xor(im1, im2))
    if num_union == 0:
        return 1.0  # both masks empty -> identical
    return float(num_intersect) / num_union
-
-
def prf(im1, im2):
    """Return (precision, recall, f-measure) of prediction `im1` vs truth `im2`.

    Nonzero pixels count as foreground. Every ratio is guarded against a
    zero denominator (empty prediction or empty ground truth) and yields
    0 for that component instead of raising ZeroDivisionError; only the
    f-measure was guarded before.
    """
    num_intersect = np.sum(np.logical_and(im1, im2))
    num_1 = np.sum(im1)
    num_2 = np.sum(im2)
    p = num_intersect / float(num_1) if num_1 else 0
    r = num_intersect / float(num_2) if num_2 else 0
    f = (2 * p * r) / (p + r) if (p + r) else 0
    return p, r, f
-
-
def update_predictions(net, data, args):
    """Run `net` over every image in `data` and score the results.

    Stores each argmax label map in data['predictions'] and returns the
    dataset-averaged (iou, precision, recall, f-measure). Predictions
    are resized back to each image's original resolution with nearest
    neighbour (keeps labels discrete) before scoring against the
    full-size ground truth.
    """
    print("Starting Predictions")

    n = len(data['images'])
    sums = [0.0, 0.0, 0.0, 0.0]  # running iou, p, r, f
    for i in range(n):
        resized = cv2.resize(data['images'][i], (args.image_size, args.image_size))

        labels = predict(net, resized, 'out', args)
        data['predictions'][i] = labels.copy()

        w, h = data['original_size'][i]  # stored as (width, height)
        full = cv2.resize(labels, (w, h), interpolation=cv2.INTER_NEAREST)
        truth = data['original_gt'][i]

        sums[0] += iou(full, truth)
        p, r, f = prf(full, truth)
        sums[1] += p
        sums[2] += r
        sums[3] += f

        if i and i % args.print_count == 0:
            print("\tPredicted %d/%d" % (i, n))

    return tuple(total / n for total in sums)
-
-
def load_data(manifest, _dir, size, color=False):
    """Load images and quad ground-truth polygons listed in `manifest`.

    Each manifest line is "<relative path>,x1,y1,x2,y2,x3,y3,x4,y4".
    For every entry this records: the filename, the image resized to
    (size, size), a rasterized binary ground-truth mask at both the
    original and resized resolutions, and the original (width, height).

    Raises:
        Exception: if an image file cannot be read.
    """
    dataset = collections.defaultdict(list)
    file_list = [s.strip() for s in open(manifest, 'r').readlines()]
    for line in file_list:
        tokens = line.split(',')
        f = tokens[0]
        coords = list(map(float, tokens[1:9]))

        dataset['filenames'].append(f)

        resolved = os.path.join(_dir, f)
        im = cv2.imread(resolved, 1 if color else 0)
        # BUG FIX: cv2.imread returns None on failure; check BEFORE
        # touching im.shape, which previously raised AttributeError
        # instead of the intended Exception below.
        if im is None:
            raise Exception("Error loading %s" % resolved)

        # Rasterize the 4-point polygon into a full-resolution binary mask.
        gt = np.zeros(im.shape[:2], dtype=np.uint8)
        cv2.fillPoly(gt, np.array(coords).reshape((4, 2)).astype(np.int32)[np.newaxis, :, :], 1)

        height, width = im.shape[:2]
        im = cv2.resize(im, (size, size))
        dataset['original_gt'].append(gt)
        # Nearest neighbour keeps the mask binary when resizing.
        gt = cv2.resize(gt, (size, size), interpolation=cv2.INTER_NEAREST)
        dataset['images'].append(im)
        dataset['original_size'].append( (width, height) ) # opencv does (w,h)
        dataset['gt'].append(gt)

    return dataset
-
-
def preprocess_data(data, args):
    """Normalize every image in place: im <- args.scale * (im - args.mean).

    Also seeds data['predictions'] with a copy of each ground-truth mask
    so a prediction slot exists for every image before the first
    evaluation pass.
    """
    for i, im in enumerate(data['images']):
        data['images'][i] = args.scale * (im - args.mean)
        data['predictions'].append(data['gt'][i].copy())
-
-
def get_solver_params(f):
    """Parse `max_iter` and `snapshot` out of a solver prototxt file.

    Returns (max_iters, snapshot); a value is 0 if its key is absent.
    """
    max_iters = 0
    snapshot = 0

    for line in open(f).readlines():
        tokens = line.split()
        # BUG FIX: skip blank lines -- tokens[0] raised IndexError on them.
        if not tokens:
            continue
        if tokens[0] == 'max_iter:':
            max_iters = int(tokens[1])
        if tokens[0] == 'snapshot:':
            snapshot = int(tokens[1])
    return max_iters, snapshot
-
-
def presolve(net, args):
    """Reshape the network's input blobs to the configured batch geometry."""
    channels = 3 if args.color else 1
    size = args.image_size
    net.blobs["data"].reshape(args.batch_size, channels, size, size)
    net.blobs["gt"].reshape(args.batch_size, 1, size, size)
-
-
def set_input_data(net, data, args):
    """Fill the network's data/gt blobs with a randomly sampled batch.

    Images are sampled with replacement; color images are transposed
    from HWC to CHW to match the blob layout.
    """
    n = len(data['images'])
    for slot in range(args.batch_size):
        pick = random.randint(0, n - 1)
        image = data['images'][pick]
        mask = data['gt'][pick]

        if image.ndim > 2:
            image = np.transpose(image, (2, 0, 1))

        net.blobs["data"].data[slot, :, :, :] = image
        net.blobs["gt"].data[slot, 0, :, :] = mask
-
-
def main(args):
    """Train the segmentation network and periodically evaluate it.

    Loads train/val datasets, steps the caffe solver for max_iter
    iterations, evaluates IOU/precision/recall/f-measure on the
    validation set every snapshot interval (and on the training set
    every gt_interval once past min_interval), then plots the metric
    curves, dumps final predictions, and writes the raw metric series
    under args.debug_dir.
    """

    train_data = load_data(args.train_manifest, args.dataset_dir, args.image_size, args.color)
    val_data = load_data(args.val_manifest, args.dataset_dir, args.image_size, args.color)

    # Mean-shift/scale images in place and seed the 'predictions' slots.
    preprocess_data(train_data, args)
    preprocess_data(val_data, args)

    print("Done loading data")

    solver = caffe.SGDSolver(args.solver_file)
    # Reuse the solver prototxt for scheduling: the snapshot interval
    # doubles as the validation-evaluation interval.
    max_iters, snapshot_interval = get_solver_params(args.solver_file)

    presolve(solver.net, args)
    # Lists of (iteration, metric) pairs collected for plotting below.
    train_iou, val_iou = [], []
    train_p, val_p = [], []
    train_r, val_r = [], []
    train_f, val_f = [], []

    for iter_num in range(max_iters + 1):
        set_input_data(solver.net, train_data, args)
        solver.step(1)

        # Validation pass (skipped at iteration 0).
        if iter_num and iter_num % snapshot_interval == 0:
            print("Validation Prediction: %d" % iter_num)
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, val_data, args)
            val_iou.append((iter_num, avg_iou))
            val_p.append((iter_num, avg_p))
            val_r.append((iter_num, avg_r))
            val_f.append((iter_num, avg_f))
            if args.debug_dir:
                print("Dumping images")
                out_dir = os.path.join(args.debug_dir, 'val_%d' % iter_num)
                dump_debug(out_dir, val_data)

        # Training-set evaluation, only after min_interval iterations
        # (early metrics on the training set are not informative).
        if iter_num >= args.min_interval and iter_num % args.gt_interval == 0:

            print("Train Prediction: %d" % iter_num)
            avg_iou, avg_p, avg_r, avg_f = update_predictions(solver.net, train_data, args)
            train_iou.append((iter_num, avg_iou))
            train_p.append((iter_num, avg_p))
            train_r.append((iter_num, avg_r))
            train_f.append((iter_num, avg_f))

    print("Train IOU: ", train_iou)
    print()
    print("Val IOU: ", val_iou)
    if args.debug_dir:
        # zip(*pairs) splits [(iter, val), ...] into the x and y series.
        plt.plot(*list(zip(*train_iou)), label='train')
        plt.plot(*list(zip(*val_iou)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'iou.png'))


        plt.clf()
        plt.plot(*list(zip(*train_p)), label='train')
        plt.plot(*list(zip(*val_p)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'precision.png'))

        plt.clf()
        plt.plot(*list(zip(*train_r)), label='train')
        plt.plot(*list(zip(*val_r)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'recall.png'))

        plt.clf()
        plt.plot(*list(zip(*train_f)), label='train')
        plt.plot(*list(zip(*val_f)), label='val')
        plt.legend()
        plt.savefig(os.path.join(args.debug_dir, 'fmeasure.png'))

        # Final full-dataset predictions, dumped as images for inspection.
        _ = update_predictions(solver.net, train_data, args)
        out_dir = os.path.join(args.debug_dir, 'train_final')
        dump_debug(out_dir, train_data, True)

        _ = update_predictions(solver.net, val_data, args)
        out_dir = os.path.join(args.debug_dir, 'val_final')
        dump_debug(out_dir, val_data, True)

        # Persist the raw metric series (as Python reprs) for later analysis.
        for name, vals in zip(['train_iou', 'val_iou', 'train_p', 'val_p',
                               'train_r', 'val_r', 'train_f', 'val_f'],
                              [train_iou, val_iou, train_p, val_p,
                               train_r, val_r, train_f, val_f]):
            fd = open(os.path.join(args.debug_dir, "%s.txt" % name), 'w')
            fd.write('%r\n' % vals)
            fd.close()
-
-
def get_args():
    """Parse and return command-line arguments for training/evaluation.

    Also prints the parsed namespace so runs are self-documenting in logs.
    """
    parser = argparse.ArgumentParser(description="Outputs binary predictions")

    parser.add_argument("solver_file",
                        help="The solver.prototxt")
    parser.add_argument("dataset_dir",
                        help="The dataset to be evaluated")
    parser.add_argument("train_manifest",
                        help="txt file listing images to train on")
    parser.add_argument("val_manifest",
                        help="txt file listing images for validation")

    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU to use for running the network")

    parser.add_argument("-m", "--mean", type=float, default=127.,
                        help="Mean value for data preprocessing")
    parser.add_argument("-s", "--scale", type=float, default=1.,
                        help="Optional pixel scale factor")
    parser.add_argument("-b", "--batch-size", default=2, type=int,
                        help="Training batch size")
    # BUG FIX: help text was a copy-paste of the batch-size help.
    parser.add_argument("-c", "--color", default=False, action='store_true',
                        help="Load images in color (3 channels) instead of grayscale")

    parser.add_argument("--image-size", default=256, type=int,
                        help="Size of images for input to training/prediction")

    parser.add_argument("--gt-interval", default=5000, type=int,
                        help="Interval for Debug")
    parser.add_argument("--min-interval", default=5000, type=int,
                        help="Minimum iteration for Debug")

    parser.add_argument("--debug-dir", default='debug', type=str,
                        help="Dump images for debugging")
    parser.add_argument("--print-count", default=10, type=int,
                        help="How often to print progress")

    args = parser.parse_args()
    print(args)

    return args
-
-
- if __name__ == "__main__":
- args = get_args()
-
- if args.gpu >= 0:
- caffe.set_device(args.gpu)
- caffe.set_mode_gpu()
- else:
- caffe.set_mode_cpu()
-
- main(args)
-
|