From 9681f0d829a95805db279afa104e6e36bda05239 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 18 Nov 2024 09:22:06 -0500 Subject: [PATCH] fix label augmentation and make_network parameter issues --- training/train_multi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/training/train_multi.py b/training/train_multi.py index e9ae831..01de562 100644 --- a/training/train_multi.py +++ b/training/train_multi.py @@ -175,8 +175,8 @@ def train( moving_image, fixed_image, moving_label, fixed_label = moving_image.cuda(), fixed_image.cuda(), moving_label.cuda(), fixed_label.cuda() if data_augmenter is not None: with torch.no_grad(): - moving_image, fixed_image, moving_image, fixed_image = data_augmenter(moving_image, fixed_image, moving_label, fixed_label) - train_kernel(optimizer, net, moving_image, fixed_image, moving_image, fixed_image, writer, iteration) + moving_image, fixed_image, moving_label, fixed_label = data_augmenter(moving_image, fixed_image, moving_label, fixed_label) + train_kernel(optimizer, net, moving_image, fixed_image, moving_label, fixed_label, writer, iteration) iteration += 1 step_callback(unwrapped_net) @@ -276,7 +276,7 @@ def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eva footsteps.output_dir + "checkpoints/Step_1_final.trch", ) - net_2 = make_network(input_shape, include_last_step=True) + net_2 = make_network(input_shape, include_last_step=True, use_label=True) net_2.regis_net.netPhi.load_state_dict(net.regis_net.state_dict())