diff --git a/lightning_ir/loss/loss.py b/lightning_ir/loss/loss.py index bb27ace..20a9b5e 100644 --- a/lightning_ir/loss/loss.py +++ b/lightning_ir/loss/loss.py @@ -61,7 +61,7 @@ def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Ten target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs] else: raise ValueError("invalid margin type") - loss = torch.nn.functional.mse_loss(margin, target_margin.clamp(min=0)) + loss = torch.nn.functional.mse_loss(margin, target_margin) return loss