From 78e373a88565434de38aa841b021237d3d027c41 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Tue, 26 Nov 2024 16:06:43 +0100 Subject: [PATCH] fix for losses taking batches --- lightning_ir/cross_encoder/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightning_ir/cross_encoder/module.py b/lightning_ir/cross_encoder/module.py index 7d40ab3..8ac6615 100644 --- a/lightning_ir/cross_encoder/module.py +++ b/lightning_ir/cross_encoder/module.py @@ -73,11 +73,11 @@ def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List raise ValueError("scores and targets must be set in the output and batch") output.scores = output.scores.view(len(batch.query_ids), -1) - targets = batch.targets.view(*output.scores.shape, -1) + batch.targets = batch.targets.view(*output.scores.shape, -1) losses = [] for loss_function, _ in self.loss_functions: if not isinstance(loss_function, ScoringLossFunction): raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function") - losses.append(loss_function.compute_loss(output, targets)) + losses.append(loss_function.compute_loss(output, batch)) return losses