Commit

fix validation with loss functions taking batch
fschlatt committed Nov 25, 2024
1 parent 8a2d3b8 commit a4127af
Showing 2 changed files with 33 additions and 62 deletions.
85 changes: 28 additions & 57 deletions lightning_ir/base/module.py
@@ -212,13 +212,7 @@ def validation_step(
             return output
 
         dataset_id = self.get_dataset_id(dataloader_idx)
-        metrics = self.validate(
-            output=output,
-            query_ids=batch.query_ids,
-            doc_ids=batch.doc_ids,
-            qrels=batch.qrels,
-            targets=getattr(batch, "targets", None),
-        )
+        metrics = self.validate(output, batch)
         for key, value in metrics.items():
             key = f"{dataset_id}/{key}"
             self.log(key, value, batch_size=len(batch.queries))
@@ -265,72 +259,48 @@ def get_dataset_id(self, dataloader_idx: int) -> str:
     def validate(
         self,
         output: LightningIROutput,
-        query_ids: Sequence[str] | None = None,
-        doc_ids: Sequence[Sequence[str]] | None = None,
-        qrels: Sequence[Dict[str, int]] | None = None,
-        targets: torch.Tensor | None = None,
-        num_docs: Sequence[int] | int | None = None,
+        batch: TrainBatch | RankBatch | SearchBatch,
     ) -> Dict[str, float]:
         """Validates the model output with the evaluation metrics and loss functions.
         :param output: Model output
         :type output: LightningIROutput
-        :param query_ids: ids of the queries, defaults to None
-        :type query_ids: Sequence[str] | None, optional
-        :param doc_ids: ids of the documents, defaults to None
-        :type doc_ids: Sequence[Sequence[str]] | None, optional
-        :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
-        :type qrels: Sequence[Dict[str, int]] | None, optional
-        :param targets: Target tensor used during fine-tuning, defaults to None
-        :type targets: torch.Tensor | None, optional
-        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_doc)`
-            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
-            sequence contains one value per query specifying the number of documents for that query. If an integer,
-            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
-            the number of documents by the number of queries, defaults to None
-        :raises ValueError: If num_docs can not be parsed and query_ids are not set
-        :raises ValueError: If num_docs can not be parsed and doc_ids are not set
-        :return: _description_
+        :param batch: Batch of validation or testing data
+        :type batch: TrainBatch | RankBatch | SearchBatch
+        :return: Dictionary of evaluation metrics
         :rtype: Dict[str, float]
         """
         metrics: Dict[str, float] = {}
         if self.evaluation_metrics is None or output.scores is None:
             return metrics
-        if query_ids is None:
-            if num_docs is None:
-                raise ValueError("num_docs must be set if query_ids is not set")
-            query_ids = tuple(str(i) for i in range(len(num_docs)))
-        if doc_ids is None:
-            if num_docs is None:
-                raise ValueError("num_docs must be set if doc_ids is not set")
-            doc_ids = tuple(tuple(f"{i}-{j}" for j in range(docs)) for i, docs in enumerate(num_docs))
-        metrics.update(self.validate_metrics(output, query_ids, doc_ids, qrels))
-        metrics.update(self.validate_loss(output, query_ids, targets))
+        metrics.update(self.validate_metrics(output, batch))
+        metrics.update(self.validate_loss(output, batch))
         return metrics
 
     def validate_metrics(
         self,
         output: LightningIROutput,
-        query_ids: Sequence[str],
-        doc_ids: Sequence[Sequence[str]],
-        qrels: Sequence[Dict[str, int]] | None,
+        batch: TrainBatch | RankBatch | SearchBatch,
     ) -> Dict[str, float]:
         """Validates the model output with the evaluation metrics.
         :param output: Model output
         :type output: LightningIROutput
-        :param query_ids: ids of the queries
-        :type query_ids: Sequence[str]
-        :param doc_ids: ids of the documents
-        :type doc_ids: Sequence[Sequence[str]]
-        :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
-        :type qrels: Sequence[Dict[str, int]] | None
+        :param batch: Batch of validation or testing data
+        :type batch: TrainBatch | RankBatch | SearchBatch
         :return: Evaluation metrics
         :rtype: Dict[str, float]
         """
         metrics: Dict[str, float] = {}
+        qrels = batch.qrels
         if self.evaluation_metrics is None or qrels is None:
             return metrics
+        query_ids = batch.query_ids
+        doc_ids = batch.doc_ids
+        if query_ids is None:
+            raise ValueError("query_ids must be set")
+        if doc_ids is None:
+            raise ValueError("doc_ids must be set")
         evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
         ir_measures_qrels = create_qrels_from_dicts(qrels)
         if evaluation_metrics and qrels is not None and output.scores is not None:
@@ -339,24 +309,27 @@ def validate_metrics(
         return metrics
 
     def validate_loss(
-        self, output: LightningIROutput, query_ids: Sequence[str], targets: torch.Tensor | None
+        self,
+        output: LightningIROutput,
+        batch: TrainBatch | RankBatch | SearchBatch,
     ) -> Dict[str, float]:
         """Validates the model output with the loss functions.
         :param output: Model output
         :type output: LightningIROutput
-        :param query_ids: ids of the queries
-        :type query_ids: Sequence[str]
-        :param targets: Target tensor used during fine-tuning
-        :type targets: torch.Tensor | None
-        :return: Loss metrics
+        :param batch: Batch of validation or testing data
+        :type batch: TrainBatch | RankBatch | SearchBatch
+        :return: Evaluation metrics
        :rtype: Dict[str, float]
         """
         metrics: Dict[str, float] = {}
+        query_ids = batch.query_ids
+        if query_ids is None:
+            raise ValueError("query_ids must be set")
         if (
             self.evaluation_metrics is None
             or "loss" not in self.evaluation_metrics
-            or targets is None
+            or getattr(batch, "targets", None) is None
             or self.loss_functions is None
             or output.scores is None
         ):
@@ -366,9 +339,7 @@
             # NOTE skip in-batch losses because they can use a lot of memory
             if isinstance(loss_function, InBatchLossFunction):
                 continue
-            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(
-                output, targets
-            ).item()
+            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch).item()
         return metrics
 
     def on_validation_epoch_end(self) -> None:
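
The counterpart to this change lives in the loss functions themselves: after this commit, compute_loss receives the whole batch instead of a bare targets tensor. As a rough sketch of what a conforming loss might look like (not code from this commit; the class name, import paths, and shape handling are assumptions):

import torch

from lightning_ir import LightningIROutput, TrainBatch


class ExampleBatchLoss:  # hypothetical class, for illustration only
    """Toy supervised loss following the new compute_loss(output, batch) interface."""

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        # The targets now travel with the batch instead of being passed separately.
        targets = batch.targets
        if targets is None:
            raise ValueError("batch must contain targets to compute a loss")
        scores = output.scores
        # Assumes scores and targets share a shape; a plain MSE stands in for the
        # library's actual supervised losses.
        return torch.nn.functional.mse_loss(scores, targets.to(scores))
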
10 changes: 5 additions & 5 deletions tests/test_loss.py
@@ -54,7 +54,7 @@ def output(batch_size: int, depth: int) -> LightningIROutput:
 
 
 @pytest.fixture(scope="module")
-def labels(batch_size: int, depth: int) -> torch.Tensor:
+def targets(batch_size: int, depth: int) -> torch.Tensor:
     tensor = torch.randint(0, 5, (batch_size, depth))
     return tensor
 
@@ -66,11 +66,11 @@ def embeddings(batch_size: int, sequence_length: int, embedding_dim: int) -> tor
 
 
 @pytest.fixture(scope="module")
-def batch(batch_size: int, depth: int) -> TrainBatch:
+def batch(batch_size: int, depth: int, targets: torch.Tensor) -> TrainBatch:
     return TrainBatch(
         queries=["query"] * batch_size,
         docs=[[f"doc{i}" for i in range(depth)]] * batch_size,
-        targets=torch.randint(0, 5, (batch_size, depth)),
+        targets=targets,
     )
 
 
@@ -97,8 +97,8 @@ def batch(batch_size: int, depth: int) -> TrainBatch:
         "SupervisedMarginMSE",
     ],
 )
-def test_loss_func(output: LightningIROutput, labels: torch.Tensor, loss_func: ScoringLossFunction):
-    loss = loss_func.compute_loss(output, labels)
+def test_loss_func(output: LightningIROutput, batch: TrainBatch, loss_func: ScoringLossFunction):
+    loss = loss_func.compute_loss(output, batch)
     assert loss >= 0
     assert loss.requires_grad
 
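
Outside of pytest, the updated test logic corresponds roughly to this standalone snippet (same caveats as the sketch above: the toy loss class, the import paths, and the direct LightningIROutput construction are assumptions, not library code):

import torch

from lightning_ir import LightningIROutput, TrainBatch


class ToyMSELoss:  # stand-in for any ScoringLossFunction with the batch-based interface
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        # Read the targets from the batch, mirroring the updated signature.
        return torch.nn.functional.mse_loss(output.scores, batch.targets.to(output.scores))


batch_size, depth = 2, 3
batch = TrainBatch(
    queries=["query"] * batch_size,
    docs=[[f"doc{i}" for i in range(depth)]] * batch_size,
    targets=torch.randint(0, 5, (batch_size, depth)),
)
output = LightningIROutput(scores=torch.randn(batch_size, depth, requires_grad=True))

loss = ToyMSELoss().compute_loss(output, batch)
assert loss >= 0 and loss.requires_grad
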
