Skip to content

Commit

Permalink
no stream sync
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Dec 27, 2024
1 parent 30ac943 commit 4d07d34
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions examples/trompt_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,13 @@ def train(
with torch.no_grad():
metric.update(out.mean(dim=1).argmax(dim=-1), tf.y)

_, num_layers, num_classes = out.size()
batch_size, num_layers, num_classes = out.size()
# [batch_size * num_layers, num_classes]
pred = out.view(-1, num_classes)
y = tf.y.repeat_interleave(num_layers)
y = tf.y.repeat_interleave(
num_layers,
output_size=num_layers * batch_size,
)
# Layer-wise logit loss
loss = F.cross_entropy(pred, y)
loss.backward()
Expand Down

0 comments on commit 4d07d34

Please sign in to comment.