Skip to content

Commit

Permalink
add option to weight loss functions
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Jul 30, 2024
1 parent ca1e8d0 commit 6d3b06a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 47 deletions.
33 changes: 12 additions & 21 deletions lightning_ir/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,15 @@

from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Sequence
from typing import TYPE_CHECKING, Any, Dict, Mapping, Sequence

import torch
from lightning import LightningModule
from transformers import AutoConfig, AutoModel, BatchEncoding

from ..loss.loss import InBatchLossFunction, LossFunction
from . import (
LightningIRConfig,
LightningIRModel,
LightningIRModelClassFactory,
LightningIROutput,
)
from .validation_utils import (
create_qrels_from_dicts,
create_run_from_scores,
evaluate_run,
)
from . import LightningIRConfig, LightningIRModel, LightningIRModelClassFactory, LightningIROutput
from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run

if TYPE_CHECKING:
from ..data import RankBatch, TrainBatch
Expand All @@ -31,7 +22,7 @@ def __init__(
model_name_or_path: str | None = None,
config: LightningIRConfig | None = None,
model: LightningIRModel | None = None,
loss_functions: Sequence[LossFunction] | None = None,
loss_functions: Sequence[LossFunction] | Mapping[LossFunction, float] | None = None,
evaluation_metrics: Sequence[str] | None = None,
):
super().__init__()
Expand All @@ -52,6 +43,8 @@ def __init__(

self.model: LightningIRModel = model
self.config = self.model.config
if loss_functions is not None and not isinstance(loss_functions, dict):
loss_functions = {loss_function: 1.0 for loss_function in loss_functions}
self.loss_functions = loss_functions
self.evaluation_metrics = evaluation_metrics
self.tokenizer = self.config.__class__.tokenizer_class.from_pretrained(
Expand Down Expand Up @@ -89,19 +82,17 @@ def prepare_input(
encodings[key] = encodings[key].to(self.device)
return encodings

def compute_losses(
self,
batch: TrainBatch,
loss_functions: Sequence[LossFunction] | None,
) -> Dict[str, torch.Tensor]:
def compute_losses(self, batch: TrainBatch) -> Dict[LossFunction, torch.Tensor]:
raise NotImplementedError

def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
if self.loss_functions is None:
raise ValueError("Loss function is not set")
losses = self.compute_losses(batch, self.loss_functions)
for key, loss in losses.items():
self.log(key, loss)
losses = self.compute_losses(batch)
total_loss = torch.tensor(0)
for loss_function, loss in losses.items():
self.log(loss_function.__class__.__name__, loss)
total_loss = total_loss + loss * self.loss_functions[loss_function]
loss = sum(losses.values(), torch.tensor(0))
self.log("loss", loss, prog_bar=True)
return loss
Expand Down
24 changes: 9 additions & 15 deletions lightning_ir/bi_encoder/module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Sequence
from typing import Dict, Mapping, Sequence

import torch

Expand Down Expand Up @@ -75,15 +75,9 @@ def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutpu
batch.doc_ids = doc_ids
return output

def compute_losses(
self,
batch: TrainBatch,
loss_functions: Sequence[LossFunction] | None,
) -> Dict[str, torch.Tensor]:
if loss_functions is None:
if self.loss_functions is None:
raise ValueError("Loss functions are not set")
loss_functions = self.loss_functions
def compute_losses(self, batch: TrainBatch) -> Dict[LossFunction, torch.Tensor]:
if self.loss_functions is None:
raise ValueError("Loss function is not set")
output = self.forward(batch)

scores = output.scores
Expand All @@ -97,20 +91,20 @@ def compute_losses(
num_queries = len(batch.queries)
scores = scores.view(num_queries, -1)
targets = batch.targets.view(*scores.shape, -1)
losses = {}
for loss_function in loss_functions:
losses: Dict[LossFunction, torch.Tensor] = {}
for loss_function in self.loss_functions:
if isinstance(loss_function, InBatchLossFunction):
pos_idcs, neg_idcs = loss_function.get_ib_idcs(*scores.shape)
ib_doc_embeddings = self.get_ib_doc_embeddings(doc_embeddings, pos_idcs, neg_idcs, num_queries)
ib_scores = self.model.score(query_embeddings, ib_doc_embeddings)
ib_scores = ib_scores.view(num_queries, -1)
losses[loss_function.__class__.__name__] = loss_function.compute_loss(ib_scores)
losses[loss_function] = loss_function.compute_loss(ib_scores)
elif isinstance(loss_function, EmbeddingLossFunction):
losses[loss_function.__class__.__name__] = loss_function.compute_loss(
losses[loss_function] = loss_function.compute_loss(
query_embeddings.embeddings, doc_embeddings.embeddings
)
elif isinstance(loss_function, ScoringLossFunction):
losses[loss_function.__class__.__name__] = loss_function.compute_loss(scores, targets)
losses[loss_function] = loss_function.compute_loss(scores, targets)
else:
raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
if self.config.sparsification is not None:
Expand Down
16 changes: 5 additions & 11 deletions lightning_ir/cross_encoder/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,9 @@ def forward(self, batch: RankBatch) -> CrossEncoderOutput:
output = self.model.forward(encoding["encoding"])
return output

def compute_losses(
self,
batch: TrainBatch,
loss_functions: Sequence[LossFunction] | None,
) -> Dict[str, torch.Tensor]:
if loss_functions is None:
if self.loss_functions is None:
raise ValueError("Loss functions are not set")
loss_functions = self.loss_functions
def compute_losses(self, batch: TrainBatch) -> Dict[LossFunction, torch.Tensor]:
if self.loss_functions is None:
raise ValueError("loss_functions must be set in the module")
output = self.forward(batch)
scores = output.scores
if scores is None or batch.targets is None:
Expand All @@ -50,8 +44,8 @@ def compute_losses(
targets = batch.targets.view(*scores.shape, -1)

losses = {}
for loss_function in loss_functions:
for loss_function in self.loss_functions:
if isinstance(loss_function, InBatchLossFunction):
raise NotImplementedError("InBatchLossFunction not implemented for cross-encoders")
losses[loss_function.__class__.__name__] = loss_function.compute_loss(scores, targets)
losses[loss_function] = loss_function.compute_loss(scores, targets)
return losses

0 comments on commit 6d3b06a

Please sign in to comment.