Skip to content

Commit

Permalink
add val and test dataloaders
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Aug 29, 2024
1 parent 7ac83b0 commit 270e885
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions lightning_ir/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ def setup_inference(self, stage: Literal["validate", "test"]) -> None:
raise ValueError(
"Inference Dataset must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset."
)
if stage == "validate":
self.val_dataloader = self.inference_dataloader
elif stage == "test":
self.test_dataloader = self.inference_dataloader
else:
raise ValueError(f"Unknown stage {stage}")

def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
if stage == "fit":
Expand All @@ -94,6 +88,12 @@ def train_dataloader(self) -> DataLoader:
prefetch_factor=16 if self.num_workers > 0 else None,
)

def val_dataloader(self) -> List[DataLoader]:
return self.inference_dataloader()

def test_dataloader(self) -> List[DataLoader]:
return self.inference_dataloader()

def inference_dataloader(self) -> List[DataLoader]:
inference_datasets = self.inference_datasets or []
return [
Expand Down

0 comments on commit 270e885

Please sign in to comment.