Skip to content

Commit

Permalink
remove tokenizer from datamodule
Browse files · Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Oct 28, 2024
1 parent dc9740f commit 2b04e73
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 41 deletions.
8 changes: 2 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,8 @@ def cross_encoder_module(cross_encoder_config: CrossEncoderConfig, model_name_or


@pytest.fixture()
def query_datamodule(model_name_or_path: str) -> LightningIRDataModule:
def query_datamodule() -> LightningIRDataModule:
datamodule = LightningIRDataModule(
model_name_or_path=model_name_or_path,
config=BiEncoderConfig(),
num_workers=0,
inference_batch_size=2,
inference_datasets=[QueryDataset("lightning-ir", num_queries=2)],
Expand All @@ -173,10 +171,8 @@ def query_datamodule(model_name_or_path: str) -> LightningIRDataModule:


@pytest.fixture()
def doc_datamodule(model_name_or_path: str) -> LightningIRDataModule:
def doc_datamodule() -> LightningIRDataModule:
datamodule = LightningIRDataModule(
model_name_or_path=model_name_or_path,
config=BiEncoderConfig(),
num_workers=0,
inference_batch_size=2,
inference_datasets=[DocDataset("lightning-ir")],
Expand Down
43 changes: 8 additions & 35 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,16 @@

import pytest
import torch
from _pytest.fixtures import SubRequest

from lightning_ir.bi_encoder.config import BiEncoderConfig
from lightning_ir.cross_encoder.model import CrossEncoderConfig
from lightning_ir.data.data import IndexBatch, SearchBatch, TrainBatch
from lightning_ir.data.datamodule import LightningIRDataModule
from lightning_ir.data.dataset import RunDataset, TupleDataset

from .conftest import RUNS_DIR


@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"])
def rank_run_datamodule(
model_name_or_path: str,
inference_datasets: Sequence[RunDataset],
request: SubRequest,
) -> LightningIRDataModule:
@pytest.fixture()
def rank_run_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule:
train_dataset = RunDataset(
RUNS_DIR / "lightning-ir.tsv",
depth=5,
Expand All @@ -27,8 +20,6 @@ def rank_run_datamodule(
targets="rank",
)
datamodule = LightningIRDataModule(
model_name_or_path=model_name_or_path,
config=request.param,
num_workers=0,
train_batch_size=2,
inference_batch_size=2,
Expand All @@ -39,12 +30,8 @@ def rank_run_datamodule(
return datamodule


@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"])
def relevance_run_datamodule(
model_name_or_path: str,
inference_datasets: Sequence[RunDataset],
request: SubRequest,
) -> LightningIRDataModule:
@pytest.fixture()
def relevance_run_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule:
train_dataset = RunDataset(
RUNS_DIR / "lightning-ir.tsv",
depth=5,
Expand All @@ -53,8 +40,6 @@ def relevance_run_datamodule(
targets="relevance",
)
datamodule = LightningIRDataModule(
model_name_or_path=model_name_or_path,
config=request.param,
num_workers=0,
train_batch_size=2,
inference_batch_size=2,
Expand All @@ -65,12 +50,8 @@ def relevance_run_datamodule(
return datamodule


@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"])
def single_relevant_run_datamodule(
model_name_or_path: str,
inference_datasets: Sequence[RunDataset],
request: SubRequest,
) -> LightningIRDataModule:
@pytest.fixture()
def single_relevant_run_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule:
train_dataset = RunDataset(
RUNS_DIR / "lightning-ir.tsv",
depth=5,
Expand All @@ -79,8 +60,6 @@ def single_relevant_run_datamodule(
targets="relevance",
)
datamodule = LightningIRDataModule(
model_name_or_path=model_name_or_path,
config=request.param,
num_workers=0,
train_batch_size=2,
inference_batch_size=2,
Expand All @@ -91,16 +70,10 @@ def single_relevant_run_datamodule(
return datamodule


@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"])
def tuples_datamodule(
model_name_or_path: str,
inference_datasets: Sequence[RunDataset],
request: SubRequest,
) -> LightningIRDataModule:
@pytest.fixture()
def tuples_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule:
train_dataset = TupleDataset("lightning-ir", targets="order", num_docs=2)
datamodule = LightningIRDataModule(
model_name_or_path=model_name_or_path,
config=request.param,
num_workers=0,
train_batch_size=2,
inference_batch_size=2,
Expand Down
1 change: 1 addition & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_seralize_deserialize(module: LightningIRModule, tmp_path: Path):
"_commit_hash",
"transformers_version",
"model_type",
"_attn_implementation_autoset",
):
continue
assert getattr(new_model.config, key) == value
Expand Down

0 comments on commit 2b04e73

Please sign in to comment.