Skip to content

Commit

Permalink
Remove Label smoothing
Browse files Browse the repository at this point in the history
  • Loading branch information
zrthxn committed May 18, 2024
1 parent 842e1c8 commit 19653d7
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def train(head: torch.nn.Module,
vector, _ = frozen_model.encode(seqs.src_tokens, padding_mask=mask)
probs = head(vector)

loss = torch.nn.functional.binary_cross_entropy(probs, labels, label_smoothing=label_smoothing)
loss = torch.nn.functional.binary_cross_entropy(probs, labels)
if loss.isnan().any().item():
logger.error(seqs); logger.error(labels)
raise RuntimeError("Train loss is NaN! Something is wrong in the model!")
Expand Down

0 comments on commit 19653d7

Please sign in to comment.