Skip to content

Commit

Permalink
fix scored batch loss function + cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
fschlatt committed Nov 21, 2024
1 parent 72705b8 commit d426d24
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
4 changes: 2 additions & 2 deletions lightning_ir/bi_encoder/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[to

num_queries = len(batch.queries)
output.scores = output.scores.view(num_queries, -1)
targets = batch.targets.view(*output.scores.shape, -1)
batch.targets = batch.targets.view(*output.scores.shape, -1)
losses = []
for loss_function, _ in self.loss_functions:
if isinstance(loss_function, InBatchLossFunction):
Expand All @@ -157,7 +157,7 @@ def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[to
elif isinstance(loss_function, EmbeddingLossFunction):
losses.append(loss_function.compute_loss(output))
elif isinstance(loss_function, ScoringLossFunction):
losses.append(loss_function.compute_loss(output, targets))
losses.append(loss_function.compute_loss(output, batch))
else:
raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
if self.config.sparsification is not None:
Expand Down
55 changes: 29 additions & 26 deletions lightning_ir/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import torch

from lightning_ir.data import TrainBatch

if TYPE_CHECKING:
from ..base import LightningIROutput
from ..bi_encoder import BiEncoderOutput
Expand All @@ -22,15 +20,18 @@ def process_scores(self, output: LightningIROutput) -> torch.Tensor:
raise ValueError("Expected scores in LightningIROutput")
return output.scores

def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
def process_targets(self, scores: torch.Tensor, batch: TrainBatch) -> torch.Tensor:
targets = batch.targets
if targets is None:
raise ValueError("Expected targets in TrainBatch")
if targets.ndim > scores.ndim:
return targets.max(-1).values
return targets


class ScoringLossFunction(LossFunction):
@abstractmethod
def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: ...
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor: ...


class EmbeddingLossFunction(LossFunction):
Expand All @@ -52,9 +53,9 @@ class MarginMSE(PairwiseLossFunction):
def __init__(self, margin: float | Literal["scores"] = 1.0):
self.margin = margin

def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
scores = self.process_scores(output)
targets = self.process_targets(scores, targets)
targets = self.process_targets(scores, batch)
query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
pos = scores[query_idcs, pos_idcs]
neg = scores[query_idcs, neg_idcs]
Expand All @@ -80,9 +81,9 @@ def __init__(self):


class RankNet(PairwiseLossFunction):
def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
scores = self.process_scores(output)
targets = self.process_targets(scores, targets)
targets = self.process_targets(scores, batch)
query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
pos = scores[query_idcs, pos_idcs]
neg = scores[query_idcs, neg_idcs]
Expand All @@ -92,19 +93,19 @@ def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torc


class KLDivergence(ListwiseLossFunction):
def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
scores = self.process_scores(output)
targets = self.process_targets(scores, targets)
targets = self.process_targets(scores, batch)
scores = torch.nn.functional.log_softmax(scores, dim=-1)
targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
loss = torch.nn.functional.kl_div(scores, targets, log_target=True, reduction="batchmean")
return loss


class LocalizedContrastiveEstimation(ListwiseLossFunction):
def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
scores = self.process_scores(output)
targets = self.process_targets(scores, targets)
targets = self.process_targets(scores, batch)
targets = targets.argmax(dim=1)
loss = torch.nn.functional.cross_entropy(scores, targets)
return loss
Expand Down Expand Up @@ -166,10 +167,10 @@ def get_ndcg(
ndcg = dcg / (idcg.clamp(min=1e-12))
return ndcg

def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
scores = self.process_scores(output)
scores = self.process_scores(output)
targets = self.process_targets(scores, targets)
targets = self.process_targets(scores, batch)
approx_ranks = self.get_approx_ranks(scores, self.temperature)
ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
loss = 1 - ndcg
Expand All @@ -190,9 +191,9 @@ def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) ->
mrr = mrr.max(dim=-1)[0]
return mrr

def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
scores = self.process_scores(output)
targets = self.process_targets(scores, targets)
targets = self.process_targets(scores, batch)
approx_ranks = self.get_approx_ranks(scores, self.temperature)
mrr = self.get_mrr(approx_ranks, targets, k=None)
loss = 1 - mrr
Expand All @@ -208,9 +209,9 @@ def __init__(
super().__init__(temperature)
self.discount = discount

def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
scores = self.process_scores(output)
targets = self.process_targets(scores, targets)
targets = self.process_targets(scores, batch)
approx_ranks = self.get_approx_ranks(scores, self.temperature)
ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
Expand Down Expand Up @@ -361,9 +362,9 @@ def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = Non
def _sort_mask(
self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
) -> torch.Tensor:
assert output.scores is not None and batch.targets is not None
targets = self.process_targets(output.scores, batch.targets)
idcs = targets.argsort(descending=True).argsort()
scores = self.process_scores(output)
targets = self.process_targets(scores, batch)
idcs = targets.argsort(descending=True).argsort().cpu()
idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
return mask.scatter(1, block_idcs, mask.gather(1, idcs))
Expand Down Expand Up @@ -392,9 +393,10 @@ def _get_neg_mask(
) -> torch.Tensor:
neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
assert batch.targets is not None
max_score, _ = batch.targets.max(dim=-1, keepdim=True)
score_diff = max_score - batch.targets
scores = self.process_scores(output)
targets = self.process_targets(scores, batch).cpu()
max_score, _ = targets.max(dim=-1, keepdim=True)
score_diff = (max_score - targets).cpu()
score_mask = score_diff.ge(self.min_target_diff)
block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
Expand All @@ -404,9 +406,10 @@ def _get_neg_mask(
additional_neg_samples = num_neg_samples - min_num_neg_samples
for query_idx, neg_samples in enumerate(additional_neg_samples):
neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
additional_neg_idcs = torch.randperm(neg_idcs.shape[0])[:neg_samples]
additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])][:neg_samples]
assert neg_mask[query_idx, additional_neg_idcs].all().item()
neg_mask[query_idx, additional_neg_idcs] = False
assert neg_mask.sum(dim=1).eq(min_num_neg_samples).all()
assert neg_mask[query_idx].sum().eq(min_num_neg_samples).item()
return neg_mask


Expand Down

0 comments on commit d426d24

Please sign in to comment.