From a4127afe3eda966febb9f21bac3a05c21896c64f Mon Sep 17 00:00:00 2001
From: Ferdinand Schlatt
Date: Mon, 25 Nov 2024 17:22:23 +0100
Subject: [PATCH] fix validation with loss functions taking batch

---
 lightning_ir/base/module.py | 85 ++++++++++++-------------------------
 tests/test_loss.py          | 10 ++---
 2 files changed, 33 insertions(+), 62 deletions(-)

diff --git a/lightning_ir/base/module.py b/lightning_ir/base/module.py
index 801ab44..2eebfb0 100644
--- a/lightning_ir/base/module.py
+++ b/lightning_ir/base/module.py
@@ -212,13 +212,7 @@ def validation_step(
             return output
 
         dataset_id = self.get_dataset_id(dataloader_idx)
-        metrics = self.validate(
-            output=output,
-            query_ids=batch.query_ids,
-            doc_ids=batch.doc_ids,
-            qrels=batch.qrels,
-            targets=getattr(batch, "targets", None),
-        )
+        metrics = self.validate(output, batch)
         for key, value in metrics.items():
             key = f"{dataset_id}/{key}"
             self.log(key, value, batch_size=len(batch.queries))
@@ -265,72 +259,48 @@ def get_dataset_id(self, dataloader_idx: int) -> str:
     def validate(
         self,
         output: LightningIROutput,
-        query_ids: Sequence[str] | None = None,
-        doc_ids: Sequence[Sequence[str]] | None = None,
-        qrels: Sequence[Dict[str, int]] | None = None,
-        targets: torch.Tensor | None = None,
-        num_docs: Sequence[int] | int | None = None,
+        batch: TrainBatch | RankBatch | SearchBatch,
     ) -> Dict[str, float]:
         """Validates the model output with the evaluation metrics and loss functions.
 
         :param output: Model output
         :type output: LightningIROutput
-        :param query_ids: ids of the queries, defaults to None
-        :type query_ids: Sequence[str] | None, optional
-        :param doc_ids: ids of the documents, defaults to None
-        :type doc_ids: Sequence[Sequence[str]] | None, optional
-        :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
-        :type qrels: Sequence[Dict[str, int]] | None, optional
-        :param targets: Target tensor used during fine-tuning, defaults to None
-        :type targets: torch.Tensor | None, optional
-        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_doc)`
-            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
-            sequence contains one value per query specifying the number of documents for that query. If an integer,
-            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
-            the number of documents by the number of queries, defaults to None
-        :raises ValueError: If num_docs can not be parsed and query_ids are not set
-        :raises ValueError: If num_docs can not be parsed and doc_ids are not set
-        :return: _description_
+        :param batch: Batch of validation or testing data
+        :type batch: TrainBatch | RankBatch | SearchBatch
+        :return: Dictionary of evaluation metrics
         :rtype: Dict[str, float]
         """
         metrics: Dict[str, float] = {}
         if self.evaluation_metrics is None or output.scores is None:
             return metrics
-        if query_ids is None:
-            if num_docs is None:
-                raise ValueError("num_docs must be set if query_ids is not set")
-            query_ids = tuple(str(i) for i in range(len(num_docs)))
-        if doc_ids is None:
-            if num_docs is None:
-                raise ValueError("num_docs must be set if doc_ids is not set")
-            doc_ids = tuple(tuple(f"{i}-{j}" for j in range(docs)) for i, docs in enumerate(num_docs))
-        metrics.update(self.validate_metrics(output, query_ids, doc_ids, qrels))
-        metrics.update(self.validate_loss(output, query_ids, targets))
+        metrics.update(self.validate_metrics(output, batch))
+        metrics.update(self.validate_loss(output, batch))
         return metrics
 
     def validate_metrics(
         self,
         output: LightningIROutput,
-        query_ids: Sequence[str],
-        doc_ids: Sequence[Sequence[str]],
-        qrels: Sequence[Dict[str, int]] | None,
+        batch: TrainBatch | RankBatch | SearchBatch,
     ) -> Dict[str, float]:
         """Validates the model output with the evaluation metrics.
 
         :param output: Model output
         :type output: LightningIROutput
-        :param query_ids: ids of the queries
-        :type query_ids: Sequence[str]
-        :param doc_ids: ids of the documents
-        :type doc_ids: Sequence[Sequence[str]]
-        :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
-        :type qrels: Sequence[Dict[str, int]] | None
+        :param batch: Batch of validation or testing data
+        :type batch: TrainBatch | RankBatch | SearchBatch
         :return: Evaluation metrics
         :rtype: Dict[str, float]
         """
         metrics: Dict[str, float] = {}
+        qrels = batch.qrels
         if self.evaluation_metrics is None or qrels is None:
             return metrics
+        query_ids = batch.query_ids
+        doc_ids = batch.doc_ids
+        if query_ids is None:
+            raise ValueError("query_ids must be set")
+        if doc_ids is None:
+            raise ValueError("doc_ids must be set")
         evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
         ir_measures_qrels = create_qrels_from_dicts(qrels)
         if evaluation_metrics and qrels is not None and output.scores is not None:
@@ -339,24 +309,27 @@ def validate_metrics(
         return metrics
 
     def validate_loss(
-        self, output: LightningIROutput, query_ids: Sequence[str], targets: torch.Tensor | None
+        self,
+        output: LightningIROutput,
+        batch: TrainBatch | RankBatch | SearchBatch,
     ) -> Dict[str, float]:
         """Validates the model output with the loss functions.
 
        :param output: Model output
         :type output: LightningIROutput
-        :param query_ids: ids of the queries
-        :type query_ids: Sequence[str]
-        :param targets: Target tensor used during fine-tuning
-        :type targets: torch.Tensor | None
+        :param batch: Batch of validation or testing data
+        :type batch: TrainBatch | RankBatch | SearchBatch
         :return: Loss metrics
         :rtype: Dict[str, float]
         """
         metrics: Dict[str, float] = {}
+        query_ids = batch.query_ids
+        if query_ids is None:
+            raise ValueError("query_ids must be set")
         if (
             self.evaluation_metrics is None
             or "loss" not in self.evaluation_metrics
-            or targets is None
+            or getattr(batch, "targets", None) is None
             or self.loss_functions is None
             or output.scores is None
         ):
@@ -366,9 +339,7 @@
             # NOTE skip in-batch losses because they can use a lot of memory
             if isinstance(loss_function, InBatchLossFunction):
                 continue
-            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(
-                output, targets
-            ).item()
+            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch).item()
         return metrics
 
     def on_validation_epoch_end(self) -> None:
diff --git a/tests/test_loss.py b/tests/test_loss.py
index d648b1c..7a3259f 100644
--- a/tests/test_loss.py
+++ b/tests/test_loss.py
@@ -54,7 +54,7 @@ def output(batch_size: int, depth: int) -> LightningIROutput:
 
 
 @pytest.fixture(scope="module")
-def labels(batch_size: int, depth: int) -> torch.Tensor:
+def targets(batch_size: int, depth: int) -> torch.Tensor:
     tensor = torch.randint(0, 5, (batch_size, depth))
     return tensor
 
@@ -66,11 +66,11 @@ def embeddings(batch_size: int, sequence_length: int, embedding_dim: int) -> tor
 
 
 @pytest.fixture(scope="module")
-def batch(batch_size: int, depth: int) -> TrainBatch:
+def batch(batch_size: int, depth: int, targets: torch.Tensor) -> TrainBatch:
     return TrainBatch(
         queries=["query"] * batch_size,
         docs=[[f"doc{i}" for i in range(depth)]] * batch_size,
-        targets=torch.randint(0, 5, (batch_size, depth)),
+        targets=targets,
     )
 
 
@@ -97,8 +97,8 @@
         "SupervisedMarginMSE",
     ],
 )
-def test_loss_func(output: LightningIROutput, labels: torch.Tensor, loss_func: ScoringLossFunction):
-    loss = loss_func.compute_loss(output, labels)
+def test_loss_func(output: LightningIROutput, batch: TrainBatch, loss_func: ScoringLossFunction):
+    loss = loss_func.compute_loss(output, batch)
     assert loss >= 0
     assert loss.requires_grad
 
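-- 
With this patch, `validate`, `validate_metrics`, `validate_loss`, and the loss
functions' `compute_loss` all take the full batch object in place of unpacked
`query_ids`/`doc_ids`/`targets`. A minimal sketch of the new calling convention,
mirroring the fixtures in tests/test_loss.py (the top-level import locations,
the loss module path, and the score shape are assumptions, not confirmed by
this patch):

    import torch

    from lightning_ir import LightningIROutput, TrainBatch  # import paths assumed
    from lightning_ir.loss.loss import SupervisedMarginMSE  # module path assumed

    batch_size, depth = 2, 10

    # One score per query-document pair, shaped as in the test fixtures (assumed).
    output = LightningIROutput(scores=torch.randn(batch_size, depth, requires_grad=True))
    batch = TrainBatch(
        queries=["query"] * batch_size,
        docs=[[f"doc{i}" for i in range(depth)]] * batch_size,
        targets=torch.randint(0, 5, (batch_size, depth)),
    )

    # Previously: loss_function.compute_loss(output, targets)
    loss = SupervisedMarginMSE().compute_loss(output, batch)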