diff --git a/test_pretrained.py b/test_pretrained.py index e20a265..7588a28 100644 --- a/test_pretrained.py +++ b/test_pretrained.py @@ -37,8 +37,8 @@ def predict(network, im, output_blob, args): def presolve(net, args): - net.blobs["data"].reshape(1, 3, args.image_size, args.image_size) - net.blobs["gt"].reshape(1, 1, args.image_size, args.image_size) + net.blobs["data"].reshape(1, 3, 256, 256) + net.blobs["gt"].reshape(1, 1, 256, 256) def main(args):