| @@ -0,0 +1,114 @@ | |||||
| #!/usr/bin/python | |||||
| import os | |||||
| import sys | |||||
| import argparse | |||||
| import numpy as np | |||||
| import caffe | |||||
| import cv2 | |||||
| import random | |||||
| def safe_mkdir(_dir): | |||||
| try: | |||||
| os.makedirs(_dir) | |||||
| except: | |||||
| pass | |||||
| def predict(network, im, output_blob, args): | |||||
| network.blobs["data"].data[0,:,:,:] = im | |||||
| network.forward() | |||||
| response = network.blobs[output_blob].data[0,:].copy() | |||||
| return np.argmax(response, axis=0) | |||||
| def presolve(net, args): | |||||
| net.blobs["data"].reshape(1, 3 if args.color else 1, args.image_size, args.image_size) | |||||
| net.blobs["gt"].reshape(1, 1, args.image_size, args.image_size) | |||||
| def main(args): | |||||
| net = caffe.Net(args.net_file, args.weight_file, caffe.TEST) | |||||
| presolve(net, args) | |||||
| file_list = map(lambda s: s.strip(), open(args.test_manifest, 'r').readlines()) | |||||
| fd = open(args.out_file, 'w') | |||||
| for idx, line in enumerate(file_list): | |||||
| if idx % args.print_count == 0: | |||||
| print "Processed %d/%d Images" % (idx, len(file_list)) | |||||
| tokens = line.split(',') | |||||
| f = tokens[0] | |||||
| resolved = os.path.join(args.dataset_dir, f) | |||||
| im = cv2.imread(resolved, 1 if args.color else 0) | |||||
| _input = args.scale * (cv2.resize(im, (args.image_size, args.image_size)) - args.mean) | |||||
| if _input.ndim > 2: | |||||
| _input = np.transpose(_input, (2, 0, 1)) | |||||
| raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8) | |||||
| if args.out_dir: | |||||
| out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png") | |||||
| cv2.imwrite(out_fn, raw) | |||||
| post, coords = post_process(raw) | |||||
| for idx2 in [1, 2, 3, 0]: | |||||
| fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.)) | |||||
| fd.write('\n') | |||||
| if args.out_dir: | |||||
| out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png") | |||||
| cv2.imwrite(out_fn, post) | |||||
| def get_args(): | |||||
| parser = argparse.ArgumentParser(description="Outputs binary predictions") | |||||
| parser.add_argument("net_file", | |||||
| help="The deploy.prototxt") | |||||
| parser.add_argument("weight_file", | |||||
| help="The .caffemodel") | |||||
| parser.add_argument("dataset_dir", | |||||
| help="The dataset to be evaluated") | |||||
| parser.add_argument("test_manifest", | |||||
| help="Images to predict") | |||||
| parser.add_argument("out_file", | |||||
| help="output file listing quad regions") | |||||
| parser.add_argument("--out-dir", default='', type=str, | |||||
| help="Dump images") | |||||
| parser.add_argument("--gpu", type=int, default=0, | |||||
| help="GPU to use for running the network") | |||||
| parser.add_argument("-c", "--color", default=False, action='store_true', | |||||
| help="Training batch size") | |||||
| parser.add_argument("-m", "--mean", type=float, default=127., | |||||
| help="Mean value for data preprocessing") | |||||
| parser.add_argument("-s", "--scale", type=float, default=0.0039, | |||||
| help="Optional pixel scale factor") | |||||
| parser.add_argument("--image-size", default=256, type=int, | |||||
| help="Size of images for input to prediction") | |||||
| parser.add_argument("--print-count", default=10, type=int, | |||||
| help="Print interval") | |||||
| args = parser.parse_args() | |||||
| print args | |||||
| return args | |||||
| if __name__ == "__main__": | |||||
| args = get_args() | |||||
| safe_mkdir(args.out_dir) | |||||
| if args.gpu >= 0: | |||||
| caffe.set_device(args.gpu) | |||||
| caffe.set_mode_gpu() | |||||
| else: | |||||
| caffe.set_mode_cpu() | |||||
| main(args) | |||||