
Commit

add data documentation
fschlatt committed Nov 25, 2024
1 parent a4127af commit 4daa949
Showing 5 changed files with 429 additions and 62 deletions.
6 changes: 6 additions & 0 deletions lightning_ir/data/__init__.py
@@ -1,3 +1,9 @@
"""
Lightning IR data module.
This module provides classes for handling data in Lightning IR, including data modules, datasets, and data samples.
"""

from .data import DocSample, IndexBatch, QuerySample, RankBatch, RankSample, SearchBatch, TrainBatch
from .datamodule import LightningIRDataModule
from .dataset import DocDataset, QueryDataset, RunDataset, TupleDataset
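
For orientation, a minimal sketch of the import surface exposed above (names are taken verbatim from the imports; the constructed sample is illustrative):

from lightning_ir.data import LightningIRDataModule, QuerySample, RankSample

# A bare sample constructed directly; fields mirror the dataclass definitions below.
sample = QuerySample(query_id="q1", query="what is information retrieval")
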
111 changes: 109 additions & 2 deletions lightning_ir/data/data.py
@@ -1,11 +1,34 @@
"""
Basic sample classes for Lightning IR.
This module defines the basic sample classes for Lightning IR. A sample is a single entry in a dataset and can be grouped
into batches for processing.
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Sequence

import torch
from ir_datasets.formats.base import GenericDoc, GenericQuery


@dataclass
class RankSample:
"""A sample of ranking data containing a query, a ranked list of documents, and optionally targets and qrels.
:param query_id: Id of the query
:type query_id: str
:param query: Query text
:type query: str
:param doc_ids: List of document ids
:type doc_ids: Sequence[str]
:param docs: List of document texts
:type docs: Sequence[str]
:param targets: Optional list of target labels denoting the relevance of a document for the query
:type targets: torch.Tensor, optional
:param qrels: Optional list of dictionaries mapping document ids to relevance labels
:type qrels: List[Dict[str, Any]], optional
"""

query_id: str
query: str
doc_ids: Sequence[str]
@@ -16,28 +39,75 @@ class RankSample:
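
As a hedged sketch, a RankSample can be constructed by hand; the ids and texts below are made up, and the optional targets and qrels are omitted:

from lightning_ir.data import RankSample

sample = RankSample(
    query_id="q1",
    query="what is information retrieval",
    doc_ids=["d1", "d2"],
    docs=["a relevant passage", "an off-topic passage"],
)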

@dataclass
class QuerySample:
"""A sample of query data containing a query and its id.
:param query_id: Id of the query
:type query_id: str
:param query: Query text
:type query: str
"""

query_id: str
query: str

@classmethod
def from_ir_dataset_sample(cls, sample):
def from_ir_dataset_sample(cls, sample: GenericQuery) -> "QuerySample":
"""Create a QuerySample from a an ir_datasets sample.
:param sample: ir_datasets sample
:type sample: GenericQuery
:return: Query sample
:rtype: QuerySample
"""
return cls(sample[0], sample[1])
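
A hedged usage sketch: GenericQuery is a named tuple of (query_id, text), which is why the positional indexing in the classmethod above works:

from ir_datasets.formats.base import GenericQuery

from lightning_ir.data import QuerySample

query = GenericQuery("q1", "what is information retrieval")
sample = QuerySample.from_ir_dataset_sample(query)
assert sample.query_id == "q1" and sample.query == "what is information retrieval"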


@dataclass
class DocSample:
"""A sample of document data containing a document and its id.
:param doc_id: Id of the document
:type doc_id: str
:param doc: Document text
:type doc: str
"""

doc_id: str
doc: str

@classmethod
def from_ir_dataset_sample(cls, sample, text_fields: Sequence[str] | None = None):
def from_ir_dataset_sample(cls, sample: GenericDoc, text_fields: Sequence[str] | None = None) -> "DocSample":
"""Create a DocSample from an ir_datasets sample.
:param sample: ir_datasets sample
:type sample: GenericDoc
:param text_fields: Optional fields to parse the text from. If None, uses the sample's ``default_text()``,
defaults to None
:type text_fields: Sequence[str] | None, optional
:return: Doc sample
:rtype: DocSample
"""
if text_fields is not None:
return cls(sample[0], " ".join(getattr(sample, field) for field in text_fields))
return cls(sample[0], sample.default_text())
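
A short sketch of both code paths, assuming GenericDoc is a named tuple of (doc_id, text) whose default_text() returns the text field:

from ir_datasets.formats.base import GenericDoc

from lightning_ir.data import DocSample

doc = GenericDoc("d1", "A short passage about neural ranking.")
by_default = DocSample.from_ir_dataset_sample(doc)  # uses default_text()
by_fields = DocSample.from_ir_dataset_sample(doc, ["text"])  # joins the named fields
assert by_default.doc == by_fields.doc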


@dataclass
class RankBatch:
"""A batch of ranking data combining multiple :py:class:`.RankSample` instances
:param queries: List of query texts
:type queries: Sequence[str]
:param docs: List of list of document texts
:type docs: Sequence[Sequence[str]]
:param query_ids: Optional list of query ids
:type query_ids: Sequence[str], optional
:param doc_ids: Optional list of list of document ids
:type doc_ids: Sequence[Sequence[str]], optional
:param qrels: Optional list of dictionaries mapping document ids to relevance labels
:type qrels: List[Dict[str, Any]], optional
"""

queries: Sequence[str]
docs: Sequence[Sequence[str]]
query_ids: Sequence[str] | None = None
@@ -47,17 +117,54 @@ class RankBatch:
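
Illustratively, a batch for a single query with two candidate documents; all values are made up and the optional qrels are omitted:

from lightning_ir.data import RankBatch

batch = RankBatch(
    queries=["what is information retrieval"],
    docs=[["first candidate passage", "second candidate passage"]],
    query_ids=["q1"],
    doc_ids=[["d1", "d2"]],
)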

@dataclass
class TrainBatch(RankBatch):
"""A batch of ranking data that combines multiple :py:class:`.RankSample` instances
:param queries: List of query texts
:type queries: Sequence[str]
:param docs: List of list of document texts
:type docs: Sequence[Sequence[str]]
:param query_ids: Optional list of query ids
:type query_ids: Sequence[str], optional
:param doc_ids: Optional list of list of document ids
:type doc_ids: Sequence[Sequence[str]], optional
:param qrels: Optional list of dictionaries mapping document ids to relevance labels
:type qrels: List[Dict[str, Any]], optional
:param targets: Optional list of target labels denoting the relevance of a document for the query
:type targets: torch.Tensor, optional
"""

targets: torch.Tensor | None = None
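
A sketch with binary targets; the exact target shape and dtype depend on the loss being trained, so the tensor below is illustrative only:

import torch

from lightning_ir.data import TrainBatch

batch = TrainBatch(
    queries=["what is information retrieval"],
    docs=[["a relevant passage", "an off-topic passage"]],
    targets=torch.tensor([[1.0], [0.0]]),  # hypothetical shape: one label per document
)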


@dataclass
class IndexBatch:
"""A batch of index that combines multiple :py:class`.DocSample` instances
:param doc_ids: List of document ids
:type doc_ids: Sequence[str]
:param docs: List of document texts
:type docs: Sequence[str]
"""

doc_ids: Sequence[str]
docs: Sequence[str]
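
A minimal sketch; doc_ids and docs are parallel sequences:

from lightning_ir.data import IndexBatch

batch = IndexBatch(doc_ids=["d1", "d2"], docs=["first passage", "second passage"])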


@dataclass
class SearchBatch:
"""A batch of search data that combines multiple :py:class:`.QuerySample` instances. Optionaly includes document ids
and qrels.
:param query_ids: List of query ids
:type query_ids: Sequence[str]
:param queries: List of query texts
:type queries: Sequence[str]
:param doc_ids: Optional list of list of document ids
:type doc_ids: Sequence[Sequence[str]], optional
:param qrels: Optional list of dictionaries mapping document ids to relevance labels
:type qrels: List[Dict[str, Any]], optional
"""

query_ids: Sequence[str]
queries: Sequence[str]
doc_ids: Sequence[Sequence[str]] | None = None
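
A minimal sketch; doc_ids and qrels default to None and are typically filled in by the search itself or supplied for evaluation:

from lightning_ir.data import SearchBatch

batch = SearchBatch(query_ids=["q1"], queries=["what is information retrieval"])
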
71 changes: 63 additions & 8 deletions lightning_ir/data/datamodule.py
@@ -1,3 +1,10 @@
"""
DataModule for Lightning IR that handles batching and collation of data samples.
This module defines the LightningIRDataModule class that handles batching and collation of data samples for training and
inference in Lightning IR.
"""

from __future__ import annotations

from collections import defaultdict
@@ -17,20 +24,41 @@ def __init__(
train_dataset: RunDataset | TupleDataset | None = None,
train_batch_size: int | None = None,
shuffle_train: bool = True,
inference_batch_size: int | None = None,
inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None = None,
inference_batch_size: int | None = None,
num_workers: int = 0,
) -> None:
"""Initializes a new Lightning IR DataModule.
:param train_dataset: A training dataset, defaults to None
:type train_dataset: RunDataset | TupleDataset | None, optional
:param train_batch_size: Batch size to use for training, defaults to None
:type train_batch_size: int | None, optional
:param shuffle_train: Whether to shuffle the training data, defaults to True
:type shuffle_train: bool, optional
:param inference_datasets: List of datasets to use for inference (indexing, searching, and re-ranking),
defaults to None
:type inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None, optional
:param inference_batch_size: Batch size to use for inference, defaults to None
:type inference_batch_size: int | None, optional
:param num_workers: Number of workers for loading data in parallel, defaults to 0
:type num_workers: int, optional
"""
super().__init__()
self.num_workers = num_workers

self.train_dataset = train_dataset
self.train_batch_size = train_batch_size
self.shuffle_train = shuffle_train
self.inference_batch_size = inference_batch_size
self.train_dataset = train_dataset
self.inference_datasets = inference_datasets
self.inference_batch_size = inference_batch_size

if (self.train_batch_size is not None) != (self.train_dataset is not None):
raise ValueError("Both train_batch_size and train_dataset must be provided.")
if (self.inference_batch_size is not None) != (self.inference_datasets is not None):
raise ValueError("Both train_batch_size and train_dataset must be provided.")

def setup_inference(self, stage: Literal["validate", "test"]) -> None:
def _setup_inference(self, stage: Literal["validate", "test"]) -> None:
if self.inference_datasets is None:
return
for inference_dataset in self.inference_datasets:
@@ -48,39 +76,66 @@ def setup_inference(self, stage: Literal["validate", "test"]) -> None:
)

def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
"""Sets up the data module for a given stage.
:param stage: Stage to set up the data module for
:type stage: Literal['fit', 'validate', 'test']
:raises ValueError: If the stage is `fit` and no training dataset is provided
"""
if stage == "fit":
if self.train_dataset is None:
raise ValueError("A training dataset and config must be provided.")
if stage == "fit":
stage = "validate"
self.setup_inference(stage)
self._setup_inference(stage)

def train_dataloader(self) -> DataLoader:
"""Returns a dataloader for training.
:raises ValueError: If no training dataset is found
:return: Dataloader for training
:rtype: DataLoader
"""
if self.train_dataset is None:
raise ValueError("No training dataset found.")
return DataLoader(
self.train_dataset,
batch_size=self.train_batch_size,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
collate_fn=self._collate_fn,
shuffle=(False if isinstance(self.train_dataset, IterableDataset) else self.shuffle_train),
prefetch_factor=16 if self.num_workers > 0 else None,
)

def val_dataloader(self) -> List[DataLoader]:
"""Returns a list of dataloaders for validation.
:return: Dataloaders for validation
:rtype: List[DataLoader]
"""
return self.inference_dataloader()

def test_dataloader(self) -> List[DataLoader]:
"""Returns a list of dataloaders for testing.
:return: Dataloaders for testing
:rtype: List[DataLoader]
"""
return self.inference_dataloader()

def inference_dataloader(self) -> List[DataLoader]:
"""Returns a list of dataloaders for inference (testing or validation).
:return: Dataloaders for inference
:rtype: List[DataLoader]
"""
inference_datasets = self.inference_datasets or []
return [
DataLoader(
dataset,
batch_size=self.inference_batch_size,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
collate_fn=self._collate_fn,
prefetch_factor=16 if self.num_workers > 0 else None,
)
for dataset in inference_datasets
@@ -134,7 +189,7 @@ def _parse_batch(
return IndexBatch(**kwargs)
raise ValueError("Invalid dataset configuration.")

def collate_fn(
def _collate_fn(
self,
samples: Sequence[RankSample | QuerySample | DocSample] | RankSample | QuerySample | DocSample,
) -> TrainBatch | RankBatch | IndexBatch | SearchBatch:
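
Continuing the construction sketch above, a hedged end-to-end use of the stage hooks:

datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
val_loaders = datamodule.val_dataloader()  # one DataLoader per inference dataset; empty list if none
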
