From 34b7be30a073939c52bf288c1afd658f05a8188d Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Wed, 30 Oct 2024 15:33:40 +0100 Subject: [PATCH] pass model output to loss functions --- lightning_ir/base/module.py | 35 ++++++------ lightning_ir/bi_encoder/module.py | 34 +++++++----- lightning_ir/cross_encoder/module.py | 15 +++-- lightning_ir/loss/loss.py | 83 +++++++++++++++++----------- tests/test_loss.py | 34 +++++------- 5 files changed, 109 insertions(+), 92 deletions(-) diff --git a/lightning_ir/base/module.py b/lightning_ir/base/module.py index 19916a3..184f7ff 100644 --- a/lightning_ir/base/module.py +++ b/lightning_ir/base/module.py @@ -211,7 +211,7 @@ def validation_step( dataset_id = self.get_dataset_id(dataloader_idx) metrics = self.validate( - scores=output.scores, + output=output, query_ids=batch.query_ids, doc_ids=batch.doc_ids, qrels=batch.qrels, @@ -262,7 +262,7 @@ def get_dataset_id(self, dataloader_idx: int) -> str: def validate( self, - scores: torch.Tensor | None = None, + output: LightningIROutput, query_ids: Sequence[str] | None = None, doc_ids: Sequence[Sequence[str]] | None = None, qrels: Sequence[Dict[str, int]] | None = None, @@ -271,8 +271,8 @@ def validate( ) -> Dict[str, float]: """Validates the model output with the evaluation metrics and loss functions. - :param scores: Model output scores, defaults to None - :type scores: torch.Tensor | None, optional + :param output: Model output + :type output: LightningIROutput :param query_ids: ids of the queries, defaults to None :type query_ids: Sequence[str] | None, optional :param doc_ids: ids of the documents, defaults to None @@ -289,7 +289,7 @@ def validate( :rtype: Dict[str, float] """ metrics: Dict[str, float] = {} - if self.evaluation_metrics is None or scores is None: + if self.evaluation_metrics is None or output.scores is None: return metrics if query_ids is None: if num_docs is None: @@ -299,21 +299,21 @@ def validate( if num_docs is None: raise ValueError("num_docs must be set if doc_ids is not set") doc_ids = tuple(tuple(f"{i}-{j}" for j in range(docs)) for i, docs in enumerate(num_docs)) - metrics.update(self.validate_metrics(scores, query_ids, doc_ids, qrels)) - metrics.update(self.validate_loss(scores, query_ids, targets)) + metrics.update(self.validate_metrics(output, query_ids, doc_ids, qrels)) + metrics.update(self.validate_loss(output, query_ids, targets)) return metrics def validate_metrics( self, - scores: torch.Tensor, + output: LightningIROutput, query_ids: Sequence[str], doc_ids: Sequence[Sequence[str]], qrels: Sequence[Dict[str, int]] | None, ) -> Dict[str, float]: """Validates the model output with the evaluation metrics. - :param scores: Model output scores - :type scores: torch.Tensor + :param output: Model output + :type output: LightningIROutput :param query_ids: ids of the queries :type query_ids: Sequence[str] :param doc_ids: ids of the documents @@ -328,18 +328,18 @@ def validate_metrics( return metrics evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"] ir_measures_qrels = create_qrels_from_dicts(qrels) - if evaluation_metrics and qrels is not None: - run = create_run_from_scores(query_ids, doc_ids, scores) + if evaluation_metrics and qrels is not None and output.scores is not None: + run = create_run_from_scores(query_ids, doc_ids, output.scores) metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics)) return metrics def validate_loss( - self, scores: torch.Tensor, query_ids: Sequence[str], targets: torch.Tensor | None + self, output: LightningIROutput, query_ids: Sequence[str], targets: torch.Tensor | None ) -> Dict[str, float]: """Validates the model output with the loss functions. - :param scores: Model output scores - :type scores: torch.Tensor + :param output: Model output + :type output: LightningIROutput :param query_ids: ids of the queries :type query_ids: Sequence[str] :param targets: Target tensor used during fine-tuning @@ -353,15 +353,16 @@ def validate_loss( or "loss" not in self.evaluation_metrics or targets is None or self.loss_functions is None + or output.scores is None ): return metrics - scores = scores.view(len(query_ids), -1) + output.scores = output.scores.view(len(query_ids), -1) for loss_function, _ in self.loss_functions: # NOTE skip in-batch losses because they can use a lot of memory if isinstance(loss_function, InBatchLossFunction): continue metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss( - scores, targets + output, targets ).item() return metrics diff --git a/lightning_ir/bi_encoder/module.py b/lightning_ir/bi_encoder/module.py index 1551717..1495301 100644 --- a/lightning_ir/bi_encoder/module.py +++ b/lightning_ir/bi_encoder/module.py @@ -3,7 +3,7 @@ import torch -from ..base import LightningIRModule +from ..base import LightningIRModule, LightningIROutput from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction from ..retrieve import SearchConfig, Searcher @@ -79,36 +79,40 @@ def compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[tor if self.loss_functions is None: raise ValueError("Loss function is not set") - scores = output.scores - query_embeddings = output.query_embeddings - doc_embeddings = output.doc_embeddings - if batch.targets is None or query_embeddings is None or doc_embeddings is None or scores is None: + if ( + batch.targets is None + or output.query_embeddings is None + or output.doc_embeddings is None + or output.scores is None + ): raise ValueError( "targets, scores, query_embeddings, and doc_embeddings must be set in " "the output and batch" ) num_queries = len(batch.queries) - scores = scores.view(num_queries, -1) - targets = batch.targets.view(*scores.shape, -1) + output.scores = output.scores.view(num_queries, -1) + targets = batch.targets.view(*output.scores.shape, -1) losses = [] for loss_function, _ in self.loss_functions: if isinstance(loss_function, InBatchLossFunction): - pos_idcs, neg_idcs = loss_function.get_ib_idcs(*scores.shape) - ib_doc_embeddings = self.get_ib_doc_embeddings(doc_embeddings, pos_idcs, neg_idcs, num_queries) - ib_scores = self.model.score(query_embeddings, ib_doc_embeddings) + pos_idcs, neg_idcs = loss_function.get_ib_idcs(*output.scores.shape) + ib_doc_embeddings = self.get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries) + ib_scores = self.model.score(output.query_embeddings, ib_doc_embeddings) ib_scores = ib_scores.view(num_queries, -1) - losses.append(loss_function.compute_loss(ib_scores)) + losses.append(loss_function.compute_loss(LightningIROutput(ib_scores))) elif isinstance(loss_function, EmbeddingLossFunction): - losses.append(loss_function.compute_loss(query_embeddings, doc_embeddings)) + losses.append(loss_function.compute_loss(output)) elif isinstance(loss_function, ScoringLossFunction): - losses.append(loss_function.compute_loss(scores, targets)) + losses.append(loss_function.compute_loss(output, targets)) else: raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}") if self.config.sparsification is not None: query_num_nonzero = ( - torch.nonzero(query_embeddings.embeddings).shape[0] / query_embeddings.embeddings.shape[0] + torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0] + ) + doc_num_nonzero = ( + torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0] ) - doc_num_nonzero = torch.nonzero(doc_embeddings.embeddings).shape[0] / doc_embeddings.embeddings.shape[0] self.log("query_num_nonzero", query_num_nonzero) self.log("doc_num_nonzero", doc_num_nonzero) return losses diff --git a/lightning_ir/cross_encoder/module.py b/lightning_ir/cross_encoder/module.py index 152bd93..d2e019e 100644 --- a/lightning_ir/cross_encoder/module.py +++ b/lightning_ir/cross_encoder/module.py @@ -4,7 +4,7 @@ from ..base.module import LightningIRModule from ..data import RankBatch, SearchBatch, TrainBatch -from ..loss.loss import InBatchLossFunction, LossFunction +from ..loss.loss import LossFunction, ScoringLossFunction from .config import CrossEncoderConfig from .model import CrossEncoderModel, CrossEncoderOutput from .tokenizer import CrossEncoderTokenizer @@ -38,16 +38,15 @@ def compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[ if self.loss_functions is None: raise ValueError("loss_functions must be set in the module") output = self.forward(batch) - scores = output.scores - if scores is None or batch.targets is None: + if output.scores is None or batch.targets is None: raise ValueError("scores and targets must be set in the output and batch") - scores = scores.view(len(batch.query_ids), -1) - targets = batch.targets.view(*scores.shape, -1) + output.scores = output.scores.view(len(batch.query_ids), -1) + targets = batch.targets.view(*output.scores.shape, -1) losses = [] for loss_function, _ in self.loss_functions: - if isinstance(loss_function, InBatchLossFunction): - raise NotImplementedError("InBatchLossFunction not implemented for cross-encoders") - losses.append(loss_function.compute_loss(scores, targets)) + if not isinstance(loss_function, ScoringLossFunction): + raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function") + losses.append(loss_function.compute_loss(output, targets)) return losses diff --git a/lightning_ir/loss/loss.py b/lightning_ir/loss/loss.py index 20a9b5e..d8c9a31 100644 --- a/lightning_ir/loss/loss.py +++ b/lightning_ir/loss/loss.py @@ -6,21 +6,23 @@ import torch if TYPE_CHECKING: - from ..bi_encoder import BiEncoderEmbedding + from ..base import LightningIROutput + from ..bi_encoder import BiEncoderOutput class LossFunction(ABC): @abstractmethod - def compute_loss(self, *args, **kwargs) -> torch.Tensor: ... + def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor: ... + + def process_scores(self, output: LightningIROutput) -> torch.Tensor: + if output.scores is None: + raise ValueError("Expected scores in LightningIROutput") + return output.scores class ScoringLossFunction(LossFunction): @abstractmethod - def compute_loss( - self, - scores: torch.Tensor, - targets: torch.Tensor, - ) -> torch.Tensor: ... + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: ... def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: if targets.ndim > scores.ndim: @@ -30,9 +32,7 @@ def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch. class EmbeddingLossFunction(LossFunction): @abstractmethod - def compute_loss( - self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding - ) -> torch.Tensor: ... + def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ... class PairwiseLossFunction(ScoringLossFunction): @@ -49,7 +49,8 @@ class MarginMSE(PairwiseLossFunction): def __init__(self, margin: float | Literal["scores"] = 1.0): self.margin = margin - def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: + scores = self.process_scores(output) targets = self.process_targets(scores, targets) query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets) pos = scores[query_idcs, pos_idcs] @@ -76,7 +77,8 @@ def __init__(self): class RankNet(PairwiseLossFunction): - def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: + scores = self.process_scores(output) targets = self.process_targets(scores, targets) query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets) pos = scores[query_idcs, pos_idcs] @@ -87,7 +89,8 @@ def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Ten class KLDivergence(ListwiseLossFunction): - def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: + scores = self.process_scores(output) targets = self.process_targets(scores, targets) scores = torch.nn.functional.log_softmax(scores, dim=-1) targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1) @@ -96,7 +99,8 @@ def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Ten class LocalizedContrastiveEstimation(ListwiseLossFunction): - def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: + scores = self.process_scores(output) targets = self.process_targets(scores, targets) targets = targets.argmax(dim=1) loss = torch.nn.functional.cross_entropy(scores, targets) @@ -159,7 +163,9 @@ def get_ndcg( ndcg = dcg / (idcg.clamp(min=1e-12)) return ndcg - def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: + scores = self.process_scores(output) + scores = self.process_scores(output) targets = self.process_targets(scores, targets) approx_ranks = self.get_approx_ranks(scores, self.temperature) ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains) @@ -181,7 +187,8 @@ def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> mrr = mrr.max(dim=-1)[0] return mrr - def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: + scores = self.process_scores(output) targets = self.process_targets(scores, targets) approx_ranks = self.get_approx_ranks(scores, self.temperature) mrr = self.get_mrr(approx_ranks, targets, k=None) @@ -198,7 +205,8 @@ def __init__( super().__init__(temperature) self.discount = discount - def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: + scores = self.process_scores(output) targets = self.process_targets(scores, targets) approx_ranks = self.get_approx_ranks(scores, self.temperature) ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1 @@ -262,7 +270,7 @@ def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> tor return pred_sorted_targets -class InBatchLossFunction(ScoringLossFunction): +class InBatchLossFunction(LossFunction): def __init__( self, pos_sampling_technique: Literal["all", "first"] = "all", @@ -305,12 +313,13 @@ def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, to neg_idcs = neg_idcs.reshape(-1) return pos_idcs, neg_idcs - def compute_loss(self, scores: torch.Tensor) -> torch.Tensor: - raise NotImplementedError("InBatchLossFunction.compute_loss must be implemented by subclasses") + def compute_loss(self, output: LightningIROutput) -> torch.Tensor: + return super().compute_loss(output) class InBatchCrossEntropy(InBatchLossFunction): - def compute_loss(self, scores: torch.Tensor) -> torch.Tensor: + def compute_loss(self, output: LightningIROutput) -> torch.Tensor: + scores = self.process_scores(output) targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device) loss = torch.nn.functional.cross_entropy(scores, targets) return loss @@ -321,27 +330,39 @@ def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None self.query_weight = query_weight self.doc_weight = doc_weight + def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]: + query_embeddings = output.query_embeddings + doc_embeddings = output.doc_embeddings + if query_embeddings is None: + raise ValueError("Expected query_embeddings in BiEncoderOutput") + if doc_embeddings is None: + raise ValueError("Expected doc_embeddings in BiEncoderOutput") + return query_embeddings.embeddings, doc_embeddings.embeddings + class L2Regularization(RegularizationLossFunction): - def compute_loss(self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding) -> torch.Tensor: - query_loss = self.query_weight * query_embeddings.embeddings.norm(dim=-1).mean() - doc_loss = self.doc_weight * doc_embeddings.embeddings.norm(dim=-1).mean() + def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: + query_embeddings, doc_embeddings = self.process_embeddings(output) + query_loss = self.query_weight * query_embeddings.norm(dim=-1).mean() + doc_loss = self.doc_weight * doc_embeddings.norm(dim=-1).mean() loss = query_loss + doc_loss return loss class L1Regularization(RegularizationLossFunction): - def compute_loss(self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding) -> torch.Tensor: - query_loss = self.query_weight * query_embeddings.embeddings.norm(p=1, dim=-1).mean() - doc_loss = self.doc_weight * doc_embeddings.embeddings.norm(p=1, dim=-1).mean() + def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: + query_embeddings, doc_embeddings = self.process_embeddings(output) + query_loss = self.query_weight * query_embeddings.norm(p=1, dim=-1).mean() + doc_loss = self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean() loss = query_loss + doc_loss return loss class FLOPSRegularization(RegularizationLossFunction): - def compute_loss(self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding) -> torch.Tensor: - query_loss = torch.sum(torch.mean(torch.abs(query_embeddings.embeddings), dim=0) ** 2) - doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings.embeddings), dim=0) ** 2) - anti_zero = 1 / (torch.sum(query_embeddings.embeddings) ** 2) + 1 / (torch.sum(doc_embeddings.embeddings) ** 2) + def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: + query_embeddings, doc_embeddings = self.process_embeddings(output) + query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2) + doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2) + anti_zero = 1 / (torch.sum(query_embeddings) ** 2) + 1 / (torch.sum(doc_embeddings) ** 2) loss = self.query_weight * query_loss + self.doc_weight * doc_loss + anti_zero return loss diff --git a/tests/test_loss.py b/tests/test_loss.py index 8334ec3..8342111 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -3,7 +3,8 @@ import pytest import torch -from lightning_ir.bi_encoder.model import BiEncoderEmbedding +from lightning_ir.base.model import LightningIROutput +from lightning_ir.bi_encoder.model import BiEncoderEmbedding, BiEncoderOutput from lightning_ir.loss.loss import ( ApproxMRR, ApproxNDCG, @@ -46,9 +47,8 @@ def embedding_dim() -> int: @pytest.fixture(scope="module") -def scores(batch_size: int, depth: int) -> torch.Tensor: - tensor = torch.randn((batch_size, depth), requires_grad=True) - return tensor +def output(batch_size: int, depth: int) -> LightningIROutput: + return LightningIROutput(torch.randn((batch_size, depth), requires_grad=True)) @pytest.fixture(scope="module") @@ -76,36 +76,28 @@ def embeddings(batch_size: int, sequence_length: int, embedding_dim: int) -> tor SupervisedMarginMSE, ], ) -def test_loss_func( - scores: torch.Tensor, - labels: torch.Tensor, - LossFunc: Type[ScoringLossFunction], -): +def test_loss_func(output: LightningIROutput, labels: torch.Tensor, LossFunc: Type[ScoringLossFunction]): loss_func = LossFunc() - loss = loss_func.compute_loss(scores, labels) + loss = loss_func.compute_loss(output, labels) assert loss >= 0 assert loss.requires_grad -@pytest.mark.parametrize( - "InBatchLossFunc", - [InBatchCrossEntropy], -) -def test_in_batch_loss_func(InBatchLossFunc: Type[InBatchLossFunction], scores: torch.Tensor): +@pytest.mark.parametrize("InBatchLossFunc", [InBatchCrossEntropy]) +def test_in_batch_loss_func(InBatchLossFunc: Type[InBatchLossFunction], output: LightningIROutput): loss_func = InBatchLossFunc() - loss = loss_func.compute_loss(scores) + loss = loss_func.compute_loss(output) assert loss >= 0 assert loss.requires_grad @pytest.mark.parametrize("RegularizationLossFunc", [L1Regularization, L2Regularization, FLOPSRegularization]) -def test_regularization_loss_func( - RegularizationLossFunc: Type[RegularizationLossFunction], - embeddings: torch.Tensor, -): +def test_regularization_loss_func(RegularizationLossFunc: Type[RegularizationLossFunction], embeddings: torch.Tensor): loss_func = RegularizationLossFunc() loss = loss_func.compute_loss( - BiEncoderEmbedding(embeddings, torch.empty(0)), BiEncoderEmbedding(embeddings, torch.empty(0)) + BiEncoderOutput( + None, BiEncoderEmbedding(embeddings, torch.empty(0)), BiEncoderEmbedding(embeddings, torch.empty(0)) + ) ) assert loss >= 0 assert loss.requires_grad