Skip to content

Commit

Permalink
add search batch to signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Aug 30, 2024
1 parent 18fe9ab commit 169f45d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
6 changes: 3 additions & 3 deletions lightning_ir/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from lightning import LightningModule
from transformers import BatchEncoding

from ..data import RankBatch, TrainBatch
from ..data import RankBatch, SearchBatch, TrainBatch
from ..loss.loss import InBatchLossFunction, LossFunction
from .config import LightningIRConfig
from .model import LightningIRModel, LightningIROutput
Expand Down Expand Up @@ -95,7 +95,7 @@ def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Se
with torch.no_grad():
return self.forward(batch)

def forward(self, batch: TrainBatch | RankBatch) -> LightningIROutput:
def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
"""Handles the forward pass of the model.
:param batch: Batch of training or ranking data
Expand Down Expand Up @@ -164,7 +164,7 @@ def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
return total_loss

def validation_step(
self, batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0
self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
) -> LightningIROutput:
"""Handles the validation step for the model.
Expand Down
6 changes: 4 additions & 2 deletions lightning_ir/cross_encoder/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from ..base.module import LightningIRModule
from ..data import RankBatch, TrainBatch
from ..data import RankBatch, SearchBatch, TrainBatch
from ..loss.loss import InBatchLossFunction, LossFunction
from .config import CrossEncoderConfig
from .model import CrossEncoderModel, CrossEncoderOutput
Expand All @@ -24,7 +24,9 @@ def __init__(
self.config: CrossEncoderConfig
self.tokenizer: CrossEncoderTokenizer

def forward(self, batch: RankBatch) -> CrossEncoderOutput:
def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
if isinstance(batch, SearchBatch):
raise NotImplementedError("Searching is not available for cross-encoders")
queries = batch.queries
docs = [d for docs in batch.docs for d in docs]
num_docs = [len(docs) for docs in batch.docs]
Expand Down

0 comments on commit 169f45d

Please sign in to comment.