From 6ef66973ac35b213405d7d6b9f11292460e0a6b8 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 12 Sep 2017 11:02:03 -0600 Subject: [PATCH] optional out_dir --- test_pretrained.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/test_pretrained.py b/test_pretrained.py index 967a388..1e3d61c 100644 --- a/test_pretrained.py +++ b/test_pretrained.py @@ -63,16 +63,18 @@ def main(args): _input = np.transpose(_input, (2, 0, 1)) raw = (255 * predict(net, _input, 'out', args)).astype(np.uint8) - out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png") - cv2.imwrite(out_fn, raw) + if args.out_dir: + out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_raw.png") + cv2.imwrite(out_fn, raw) post, coords = post_process(raw) for idx2 in [1, 2, 3, 0]: fd.write('%d,%d,' % (width * coords[idx2][0] / 256., height * coords[idx2][1] / 256.)) fd.write('\n') - out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png") - cv2.imwrite(out_fn, post) + if args.out_dir: + out_fn = os.path.join(args.out_dir, f.replace('/','_')[:-4] + "_post.png") + cv2.imwrite(out_fn, post) def get_args(): @@ -87,7 +89,7 @@ def get_args(): parser.add_argument("out_file", type=str, help="Output file") - parser.add_argument("--out-dir", type=str, default='out', + parser.add_argument("--out-dir", type=str, default='', help="") parser.add_argument("--gpu", type=int, default=0, help="GPU to use for running the network") @@ -102,7 +104,8 @@ def get_args(): if __name__ == "__main__": args = get_args() - safe_mkdir(args.out_dir) + if args.out_dir: + safe_mkdir(args.out_dir) if args.model == 'ohio': NET_FILE = './models/ohio_train_val.prototxt'