From 270e885676acecd11b0df0829c5bab1473390026 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Thu, 29 Aug 2024 16:08:09 +0200 Subject: [PATCH] add val and test dataloaders --- lightning_ir/data/datamodule.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lightning_ir/data/datamodule.py b/lightning_ir/data/datamodule.py index 36f3b76..f3f56b2 100644 --- a/lightning_ir/data/datamodule.py +++ b/lightning_ir/data/datamodule.py @@ -67,12 +67,6 @@ def setup_inference(self, stage: Literal["validate", "test"]) -> None: raise ValueError( "Inference Dataset must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset." ) - if stage == "validate": - self.val_dataloader = self.inference_dataloader - elif stage == "test": - self.test_dataloader = self.inference_dataloader - else: - raise ValueError(f"Unknown stage {stage}") def setup(self, stage: Literal["fit", "validate", "test"]) -> None: if stage == "fit": @@ -94,6 +88,12 @@ def train_dataloader(self) -> DataLoader: prefetch_factor=16 if self.num_workers > 0 else None, ) + def val_dataloader(self) -> List[DataLoader]: + return self.inference_dataloader() + + def test_dataloader(self) -> List[DataLoader]: + return self.inference_dataloader() + def inference_dataloader(self) -> List[DataLoader]: inference_datasets = self.inference_datasets or [] return [