Skip to content

Commit

Permalink
use lightningirtrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Aug 29, 2024
1 parent 5a59677 commit 7ac83b0
Showing 1 changed file with 7 additions and 21 deletions.
28 changes: 7 additions & 21 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,8 @@
import pandas as pd
import pytest
from _pytest.fixtures import SubRequest
from lightning import Trainer

from lightning_ir import (
BiEncoderModule,
LightningIRDataModule,
LightningIRModule,
RunDataset,
)
from lightning_ir import BiEncoderModule, LightningIRDataModule, LightningIRModule, LightningIRTrainer, RunDataset
from lightning_ir.lightning_utils.callbacks import IndexCallback, RankCallback
from lightning_ir.retrieve import (
FaissFlatIndexConfig,
Expand Down Expand Up @@ -56,7 +50,7 @@ def test_index_callback(
index_callback = IndexCallback(index_dir, index_config)
index_dir = index_dir / doc_datamodule.inference_datasets[0].docs_dataset_id

trainer = Trainer(
trainer = LightningIRTrainer(
# devices=devices,
logger=False,
enable_checkpointing=False,
Expand Down Expand Up @@ -96,7 +90,7 @@ def get_index(

index_callback = IndexCallback(index_dir, index_config)

trainer = Trainer(
trainer = LightningIRTrainer(
logger=False,
enable_checkpointing=False,
callbacks=[index_callback],
Expand Down Expand Up @@ -127,7 +121,7 @@ def test_search_callback(
save_dir = tmp_path / "runs"
search_callback = RankCallback(save_dir)

trainer = Trainer(
trainer = LightningIRTrainer(
logger=False,
enable_checkpointing=False,
callbacks=[search_callback],
Expand All @@ -148,20 +142,12 @@ def test_search_callback(
assert run_df["query_id"].nunique() == len(dataset)


def test_rerank_callback(
tmp_path: Path,
module: LightningIRModule,
inference_datasets: Sequence[RunDataset],
):
def test_rerank_callback(tmp_path: Path, module: LightningIRModule, inference_datasets: Sequence[RunDataset]):
datamodule = run_datamodule(module, inference_datasets)
save_dir = tmp_path / "runs"
rerank_callback = RankCallback(save_dir)
trainer = Trainer(
logger=False,
enable_checkpointing=False,
callbacks=[rerank_callback],
)
trainer.test(module, datamodule=datamodule)
trainer = LightningIRTrainer(logger=False, enable_checkpointing=False, callbacks=[rerank_callback])
trainer.re_rank(module, datamodule)

for dataloader in trainer.test_dataloaders:
dataset = dataloader.dataset
Expand Down

0 comments on commit 7ac83b0

Please sign in to comment.