diff --git a/examples/trompt_multi_gpu.py b/examples/trompt_multi_gpu.py index d2467389..b4e1a12e 100644 --- a/examples/trompt_multi_gpu.py +++ b/examples/trompt_multi_gpu.py @@ -65,10 +65,13 @@ def train( with torch.no_grad(): metric.update(out.mean(dim=1).argmax(dim=-1), tf.y) - _, num_layers, num_classes = out.size() + batch_size, num_layers, num_classes = out.size() # [batch_size * num_layers, num_classes] pred = out.view(-1, num_classes) - y = tf.y.repeat_interleave(num_layers) + y = tf.y.repeat_interleave( + num_layers, + output_size=num_layers * batch_size, + ) # Layer-wise logit loss loss = F.cross_entropy(pred, y) loss.backward()