| @@ -63,16 +63,18 @@ def main(args): | |||||
| _input = np.transpose(_input, (2, 0, 1)) | _input = np.transpose(_input, (2, 0, 1)) | ||||
| raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8) | raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8) | ||||
| out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png") | |||||
| cv2.imwrite(out_fn, raw) | |||||
| if args.out_dir: | |||||
| out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png") | |||||
| cv2.imwrite(out_fn, raw) | |||||
| post, coords = post_process(raw) | post, coords = post_process(raw) | ||||
| for idx2 in [1, 2, 3, 0]: | for idx2 in [1, 2, 3, 0]: | ||||
| fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.)) | fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.)) | ||||
| fd.write('\n') | fd.write('\n') | ||||
| out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png") | |||||
| cv2.imwrite(out_fn, post) | |||||
| if args.out_dir: | |||||
| out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png") | |||||
| cv2.imwrite(out_fn, post) | |||||
| def get_args(): | def get_args(): | ||||
| @@ -87,7 +89,7 @@ def get_args(): | |||||
| parser.add_argument("out_file", type=str, | parser.add_argument("out_file", type=str, | ||||
| help="Output file") | help="Output file") | ||||
| parser.add_argument("--out-dir", type=str, default='out', | |||||
| parser.add_argument("--out-dir", type=str, default='', | |||||
| help="") | help="") | ||||
| parser.add_argument("--gpu", type=int, default=0, | parser.add_argument("--gpu", type=int, default=0, | ||||
| help="GPU to use for running the network") | help="GPU to use for running the network") | ||||
| @@ -102,7 +104,8 @@ def get_args(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| args = get_args() | args = get_args() | ||||
| safe_mkdir(args.out_dir) | |||||
| if args.out_dir: | |||||
| safe_mkdir(args.out_dir) | |||||
| if args.model == 'ohio': | if args.model == 'ohio': | ||||
| NET_FILE = './models/ohio_train_val.prototxt' | NET_FILE = './models/ohio_train_val.prototxt' | ||||