diff --git a/lightning_ir/base/module.py b/lightning_ir/base/module.py index 77aec0e..017b2b8 100644 --- a/lightning_ir/base/module.py +++ b/lightning_ir/base/module.py @@ -6,7 +6,7 @@ from lightning import LightningModule from transformers import BatchEncoding -from ..data import RankBatch, TrainBatch +from ..data import RankBatch, SearchBatch, TrainBatch from ..loss.loss import InBatchLossFunction, LossFunction from .config import LightningIRConfig from .model import LightningIRModel, LightningIROutput @@ -95,7 +95,7 @@ def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Se with torch.no_grad(): return self.forward(batch) - def forward(self, batch: TrainBatch | RankBatch) -> LightningIROutput: + def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput: """Handles the forward pass of the model. :param batch: Batch of training or ranking data @@ -164,7 +164,7 @@ def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor: return total_loss def validation_step( - self, batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0 + self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0 ) -> LightningIROutput: """Handles the validation step for the model. diff --git a/lightning_ir/cross_encoder/module.py b/lightning_ir/cross_encoder/module.py index e38cc2e..11686e2 100644 --- a/lightning_ir/cross_encoder/module.py +++ b/lightning_ir/cross_encoder/module.py @@ -3,7 +3,7 @@ import torch from ..base.module import LightningIRModule -from ..data import RankBatch, TrainBatch +from ..data import RankBatch, SearchBatch, TrainBatch from ..loss.loss import InBatchLossFunction, LossFunction from .config import CrossEncoderConfig from .model import CrossEncoderModel, CrossEncoderOutput @@ -24,7 +24,9 @@ def __init__( self.config: CrossEncoderConfig self.tokenizer: CrossEncoderTokenizer - def forward(self, batch: RankBatch) -> CrossEncoderOutput: + def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput: + if isinstance(batch, SearchBatch): + raise NotImplementedError("Searching is not available for cross-encoders") queries = batch.queries docs = [d for docs in batch.docs for d in docs] num_docs = [len(docs) for docs in batch.docs]