From 7ac83b09aa56de11242b55acb15c11da5fb1bcd6 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Thu, 29 Aug 2024 16:07:55 +0200 Subject: [PATCH] use lightningirtrainer --- tests/test_callbacks.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 44a0a76..af52fc4 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -4,14 +4,8 @@ import pandas as pd import pytest from _pytest.fixtures import SubRequest -from lightning import Trainer -from lightning_ir import ( - BiEncoderModule, - LightningIRDataModule, - LightningIRModule, - RunDataset, -) +from lightning_ir import BiEncoderModule, LightningIRDataModule, LightningIRModule, LightningIRTrainer, RunDataset from lightning_ir.lightning_utils.callbacks import IndexCallback, RankCallback from lightning_ir.retrieve import ( FaissFlatIndexConfig, @@ -56,7 +50,7 @@ def test_index_callback( index_callback = IndexCallback(index_dir, index_config) index_dir = index_dir / doc_datamodule.inference_datasets[0].docs_dataset_id - trainer = Trainer( + trainer = LightningIRTrainer( # devices=devices, logger=False, enable_checkpointing=False, @@ -96,7 +90,7 @@ def get_index( index_callback = IndexCallback(index_dir, index_config) - trainer = Trainer( + trainer = LightningIRTrainer( logger=False, enable_checkpointing=False, callbacks=[index_callback], @@ -127,7 +121,7 @@ def test_search_callback( save_dir = tmp_path / "runs" search_callback = RankCallback(save_dir) - trainer = Trainer( + trainer = LightningIRTrainer( logger=False, enable_checkpointing=False, callbacks=[search_callback], @@ -148,20 +142,12 @@ def test_search_callback( assert run_df["query_id"].nunique() == len(dataset) -def test_rerank_callback( - tmp_path: Path, - module: LightningIRModule, - inference_datasets: Sequence[RunDataset], -): +def test_rerank_callback(tmp_path: Path, module: LightningIRModule, inference_datasets: Sequence[RunDataset]): datamodule = run_datamodule(module, inference_datasets) save_dir = tmp_path / "runs" rerank_callback = RankCallback(save_dir) - trainer = Trainer( - logger=False, - enable_checkpointing=False, - callbacks=[rerank_callback], - ) - trainer.test(module, datamodule=datamodule) + trainer = LightningIRTrainer(logger=False, enable_checkpointing=False, callbacks=[rerank_callback]) + trainer.re_rank(module, datamodule) for dataloader in trainer.test_dataloaders: dataset = dataloader.dataset