Skip to content

Commit

Permalink
Merge pull request #29 from uncbiag/fix-multi-training
Browse files Browse the repository at this point in the history
fix label augmentation and make_network parameter issues
  • Loading branch information
HastingsGreer authored Nov 19, 2024
2 parents a9b1310 + 9681f0d commit 8a11d35
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions training/train_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def train(
moving_image, fixed_image, moving_label, fixed_label = moving_image.cuda(), fixed_image.cuda(), moving_label.cuda(), fixed_label.cuda()
if data_augmenter is not None:
with torch.no_grad():
moving_image, fixed_image, moving_image, fixed_image = data_augmenter(moving_image, fixed_image, moving_label, fixed_label)
train_kernel(optimizer, net, moving_image, fixed_image, moving_image, fixed_image, writer, iteration)
moving_image, fixed_image, moving_label, fixed_label = data_augmenter(moving_image, fixed_image, moving_label, fixed_label)
train_kernel(optimizer, net, moving_image, fixed_image, moving_label, fixed_label, writer, iteration)
iteration += 1

step_callback(unwrapped_net)
Expand Down Expand Up @@ -276,7 +276,7 @@ def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eva
footsteps.output_dir + "checkpoints/Step_1_final.trch",
)

net_2 = make_network(input_shape, include_last_step=True)
net_2 = make_network(input_shape, include_last_step=True, use_label=True)

net_2.regis_net.netPhi.load_state_dict(net.regis_net.state_dict())

Expand Down

0 comments on commit 8a11d35

Please sign in to comment.