From 2b04e739b09a2442a70942c1c53726edf087826d Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Mon, 28 Oct 2024 10:51:33 +0100 Subject: [PATCH] remove tokenizer from datamodule --- tests/conftest.py | 8 ++------ tests/test_data.py | 43 ++++++++----------------------------------- tests/test_model.py | 1 + 3 files changed, 11 insertions(+), 41 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bae1874..dc8171d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -160,10 +160,8 @@ def cross_encoder_module(cross_encoder_config: CrossEncoderConfig, model_name_or @pytest.fixture() -def query_datamodule(model_name_or_path: str) -> LightningIRDataModule: +def query_datamodule() -> LightningIRDataModule: datamodule = LightningIRDataModule( - model_name_or_path=model_name_or_path, - config=BiEncoderConfig(), num_workers=0, inference_batch_size=2, inference_datasets=[QueryDataset("lightning-ir", num_queries=2)], @@ -173,10 +171,8 @@ def query_datamodule(model_name_or_path: str) -> LightningIRDataModule: @pytest.fixture() -def doc_datamodule(model_name_or_path: str) -> LightningIRDataModule: +def doc_datamodule() -> LightningIRDataModule: datamodule = LightningIRDataModule( - model_name_or_path=model_name_or_path, - config=BiEncoderConfig(), num_workers=0, inference_batch_size=2, inference_datasets=[DocDataset("lightning-ir")], diff --git a/tests/test_data.py b/tests/test_data.py index 77601c2..f1fbfc9 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -2,10 +2,7 @@ import pytest import torch -from _pytest.fixtures import SubRequest -from lightning_ir.bi_encoder.config import BiEncoderConfig -from lightning_ir.cross_encoder.model import CrossEncoderConfig from lightning_ir.data.data import IndexBatch, SearchBatch, TrainBatch from lightning_ir.data.datamodule import LightningIRDataModule from lightning_ir.data.dataset import RunDataset, TupleDataset @@ -13,12 +10,8 @@ from .conftest import RUNS_DIR -@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"]) -def rank_run_datamodule( - model_name_or_path: str, - inference_datasets: Sequence[RunDataset], - request: SubRequest, -) -> LightningIRDataModule: +@pytest.fixture() +def rank_run_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule: train_dataset = RunDataset( RUNS_DIR / "lightning-ir.tsv", depth=5, @@ -27,8 +20,6 @@ def rank_run_datamodule( targets="rank", ) datamodule = LightningIRDataModule( - model_name_or_path=model_name_or_path, - config=request.param, num_workers=0, train_batch_size=2, inference_batch_size=2, @@ -39,12 +30,8 @@ def rank_run_datamodule( return datamodule -@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"]) -def relevance_run_datamodule( - model_name_or_path: str, - inference_datasets: Sequence[RunDataset], - request: SubRequest, -) -> LightningIRDataModule: +@pytest.fixture() +def relevance_run_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule: train_dataset = RunDataset( RUNS_DIR / "lightning-ir.tsv", depth=5, @@ -53,8 +40,6 @@ def relevance_run_datamodule( targets="relevance", ) datamodule = LightningIRDataModule( - model_name_or_path=model_name_or_path, - config=request.param, num_workers=0, train_batch_size=2, inference_batch_size=2, @@ -65,12 +50,8 @@ def relevance_run_datamodule( return datamodule -@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"]) -def single_relevant_run_datamodule( - model_name_or_path: str, - inference_datasets: Sequence[RunDataset], - request: SubRequest, -) -> LightningIRDataModule: +@pytest.fixture() +def single_relevant_run_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule: train_dataset = RunDataset( RUNS_DIR / "lightning-ir.tsv", depth=5, @@ -79,8 +60,6 @@ def single_relevant_run_datamodule( targets="relevance", ) datamodule = LightningIRDataModule( - model_name_or_path=model_name_or_path, - config=request.param, num_workers=0, train_batch_size=2, inference_batch_size=2, @@ -91,16 +70,10 @@ def single_relevant_run_datamodule( return datamodule -@pytest.fixture(params=[BiEncoderConfig(), CrossEncoderConfig()], ids=["BiEncoder", "CrossEncoder"]) -def tuples_datamodule( - model_name_or_path: str, - inference_datasets: Sequence[RunDataset], - request: SubRequest, -) -> LightningIRDataModule: +@pytest.fixture() +def tuples_datamodule(inference_datasets: Sequence[RunDataset]) -> LightningIRDataModule: train_dataset = TupleDataset("lightning-ir", targets="order", num_docs=2) datamodule = LightningIRDataModule( - model_name_or_path=model_name_or_path, - config=request.param, num_workers=0, train_batch_size=2, inference_batch_size=2, diff --git a/tests/test_model.py b/tests/test_model.py index 6532255..aa07233 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -78,6 +78,7 @@ def test_seralize_deserialize(module: LightningIRModule, tmp_path: Path): "_commit_hash", "transformers_version", "model_type", + "_attn_implementation_autoset", ): continue assert getattr(new_model.config, key) == value