Skip to content

Commit

Permalink
fix for losses taking batches
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Nov 26, 2024
1 parent 0bf6c3d commit 78e373a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lightning_ir/cross_encoder/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List
raise ValueError("scores and targets must be set in the output and batch")

output.scores = output.scores.view(len(batch.query_ids), -1)
targets = batch.targets.view(*output.scores.shape, -1)
batch.targets = batch.targets.view(*output.scores.shape, -1)

losses = []
for loss_function, _ in self.loss_functions:
if not isinstance(loss_function, ScoringLossFunction):
raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function")
losses.append(loss_function.compute_loss(output, targets))
losses.append(loss_function.compute_loss(output, batch))
return losses

0 comments on commit 78e373a

Please sign in to comment.