diff --git a/README.md b/README.md
index f07338c..5ced777 100644
--- a/README.md
+++ b/README.md
@@ -46,8 +46,8 @@ For more details, see the [Usage](#usage) section.
 
 | Model Name                                                            | TREC DL 19 | TREC DL 20 |
 | --------------------------------------------------------------------- | ---------- | ---------- |
-| [monoelectra-base](https://huggingface.co/webis/monoelectra-base)     | 0.715      | 0.715      |
-| [monoelectra-large](https://huggingface.co/webis/monoelectra-large)   | 0.730      | 0.730      |
+| [monoelectra-base](https://huggingface.co/webis/monoelectra-base)     | 0.75       | 0.77       |
+| [monoelectra-large](https://huggingface.co/webis/monoelectra-large)   | 0.75       | 0.79       |
 | monoT5 (Coming soon)                                                  | --         | --         |
 
 ### Bi-encoders
diff --git a/lightning_ir/base/model.py b/lightning_ir/base/model.py
index 99b7eb8..45e0d46 100644
--- a/lightning_ir/base/model.py
+++ b/lightning_ir/base/model.py
@@ -7,9 +7,9 @@
     CONFIG_MAPPING,
     MODEL_MAPPING,
     AutoConfig,
+    BertModel,
     PretrainedConfig,
     PreTrainedModel,
-    BertModel,
 )
 from transformers.modeling_outputs import ModelOutput
 
diff --git a/lightning_ir/base/module.py b/lightning_ir/base/module.py
index a98eda8..2079756 100644
--- a/lightning_ir/base/module.py
+++ b/lightning_ir/base/module.py
@@ -224,9 +224,9 @@ def validate_loss(
             # NOTE skip in-batch losses because they can use a lot of memory
             if isinstance(loss_function, InBatchLossFunction):
                 continue
-            metrics[f"validation-{loss_function.__class__.__name__}"] = (
-                loss_function.compute_loss(scores, targets).item()
-            )
+            metrics[
+                f"validation-{loss_function.__class__.__name__}"
+            ] = loss_function.compute_loss(scores, targets).item()
         return metrics
 
     def on_validation_epoch_end(self) -> None:
diff --git a/lightning_ir/bi_encoder/model.py b/lightning_ir/bi_encoder/model.py
index 49b816b..e44f796 100644
--- a/lightning_ir/bi_encoder/model.py
+++ b/lightning_ir/bi_encoder/model.py
@@ -1,8 +1,8 @@
+import warnings
 from dataclasses import dataclass
-from string import punctuation
 from functools import wraps
-import warnings
-from typing import Literal, Sequence, Callable
+from string import punctuation
+from typing import Callable, Literal, Sequence
 
 import torch
 from transformers import BatchEncoding
diff --git a/lightning_ir/lightning_utils/lr_schedulers.py b/lightning_ir/lightning_utils/lr_schedulers.py
index ae35b46..5780f09 100644
--- a/lightning_ir/lightning_utils/lr_schedulers.py
+++ b/lightning_ir/lightning_utils/lr_schedulers.py
@@ -17,7 +17,8 @@ def __init__(
         self.num_training_steps = num_training_steps
         super().__init__(optimizer, self.lr_lambda, last_epoch, verbose)
 
-    def lr_lambda(self, current_step: int) -> float: ...
+    def lr_lambda(self, current_step: int) -> float:
+        ...
 
 
 class LinearLRSchedulerWithWarmup(WarmupLRScheduler):
diff --git a/lightning_ir/lightning_utils/schedulers.py b/lightning_ir/lightning_utils/schedulers.py
index 65a19a9..5b45653 100644
--- a/lightning_ir/lightning_utils/schedulers.py
+++ b/lightning_ir/lightning_utils/schedulers.py
@@ -7,7 +7,6 @@
 
 
 class LambdaWarmupScheduler(Callback, ABC):
-
     def __init__(
         self,
         keys: Sequence[str],
@@ -23,7 +22,8 @@ def __init__(
         self.values: Dict[str, float] = {}
 
     @abstractmethod
-    def lr_lambda(self, current_step: int) -> float: ...
+    def lr_lambda(self, current_step: int) -> float:
+        ...
 
     def step(self, key: str, current_step: int) -> float:
         value = self.values[key]
@@ -73,7 +73,6 @@ def lr_lambda(self, current_step: int) -> float:
 
 
 class ConstantSchedulerWithWarmup(LambdaWarmupScheduler):
-
     def lr_lambda(self, current_step: int) -> float:
         if current_step < self.num_delay_steps:
             return 0.0
diff --git a/lightning_ir/loss/loss.py b/lightning_ir/loss/loss.py
index 05ef78a..859cd4d 100644
--- a/lightning_ir/loss/loss.py
+++ b/lightning_ir/loss/loss.py
@@ -1,12 +1,13 @@
-from typing import Literal, Tuple
 from abc import ABC, abstractmethod
+from typing import Literal, Tuple
 
 import torch
 
 
 class LossFunction(ABC):
     @abstractmethod
-    def compute_loss(self, *args, **kwargs) -> torch.Tensor: ...
+    def compute_loss(self, *args, **kwargs) -> torch.Tensor:
+        ...
 
 
 class ScoringLossFunction(LossFunction):
@@ -15,7 +16,8 @@ def compute_loss(
         self,
         scores: torch.Tensor,
         targets: torch.Tensor,
-    ) -> torch.Tensor: ...
+    ) -> torch.Tensor:
+        ...
 
     def process_targets(
         self, scores: torch.Tensor, targets: torch.Tensor
@@ -31,7 +33,8 @@ def compute_loss(
         self,
         query_embeddings: torch.Tensor,
         doc_embeddings: torch.Tensor,
-    ) -> torch.Tensor: ...
+    ) -> torch.Tensor:
+        ...
 
 
 class PairwiseLossFunction(ScoringLossFunction):
@@ -346,7 +349,6 @@ def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None
 
 
 class L2Regularization(RegularizationLossFunction):
-
     def compute_loss(
         self,
         query_embeddings: torch.Tensor,
@@ -359,7 +361,6 @@ def compute_loss(
 
 
 class L1Regularization(RegularizationLossFunction):
-
     def compute_loss(
         self,
         query_embeddings: torch.Tensor,
diff --git a/lightning_ir/main.py b/lightning_ir/main.py
index 9a4d8a4..e6d9a4a 100644
--- a/lightning_ir/main.py
+++ b/lightning_ir/main.py
@@ -1,7 +1,7 @@
 import os
 import sys
 from pathlib import Path
-from typing import Any, Dict, List, Set, Mapping
+from typing import Any, Dict, List, Mapping, Set
 
 import torch
 from lightning import LightningDataModule, LightningModule, Trainer
diff --git a/lightning_ir/models/col/model.py b/lightning_ir/models/col/model.py
index bfcb108..4127290 100644
--- a/lightning_ir/models/col/model.py
+++ b/lightning_ir/models/col/model.py
@@ -5,7 +5,7 @@
 from transformers import BertModel
 from transformers.modeling_utils import load_state_dict
 
-from ...base import LightningIRModelClassFactory, LightningIRModel
+from ...base import LightningIRModel, LightningIRModelClassFactory
 from ...bi_encoder.model import BiEncoderModel
 from .config import ColConfig
 
diff --git a/lightning_ir/models/splade/model.py b/lightning_ir/models/splade/model.py
index 928f7ad..25e27a2 100644
--- a/lightning_ir/models/splade/model.py
+++ b/lightning_ir/models/splade/model.py
@@ -73,7 +73,6 @@ def get_output_embeddings(self):
     def from_mlm_checkpoint(
         cls, model_name_or_path: str | Path, *args, **kwargs
     ) -> "SpladeModel":
-
         config = AutoConfig.from_pretrained(model_name_or_path)
         BackboneModel = MODEL_MAPPING[config.__class__]
         cls = LightningIRModelClassFactory(BackboneModel, SpladeConfig)
diff --git a/lightning_ir/retrieve/faiss_indexer.py b/lightning_ir/retrieve/faiss_indexer.py
index c84a311..ce6046d 100644
--- a/lightning_ir/retrieve/faiss_indexer.py
+++ b/lightning_ir/retrieve/faiss_indexer.py
@@ -10,7 +10,6 @@
 
 
 class FaissIndexer(Indexer):
-
     INDEX_FACTORY: str
 
     def __init__(
@@ -42,13 +41,16 @@ def __init__(
         self.to_gpu()
 
     @abstractmethod
-    def to_gpu(self) -> None: ...
+    def to_gpu(self) -> None:
+        ...
 
     @abstractmethod
-    def to_cpu(self) -> None: ...
+    def to_cpu(self) -> None:
+        ...
 
     @abstractmethod
-    def set_verbosity(self, verbose: bool | None = None) -> None: ...
+    def set_verbosity(self, verbose: bool | None = None) -> None:
+        ...
 
     def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
         return embeddings
@@ -84,7 +86,6 @@ def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
 
 
 class FaissFlatIndexer(FaissIndexer):
-
     INDEX_FACTORY = "Flat"
 
     def __init__(
diff --git a/lightning_ir/retrieve/faiss_searcher.py b/lightning_ir/retrieve/faiss_searcher.py
index 913c235..eacdbd3 100644
--- a/lightning_ir/retrieve/faiss_searcher.py
+++ b/lightning_ir/retrieve/faiss_searcher.py
@@ -81,7 +81,6 @@ def candidate_retrieval(
     def gather_imputation(
         self, candidate_doc_idcs: torch.Tensor, query_lengths: torch.Tensor
     ) -> Tuple[BiEncoderEmbedding, torch.Tensor, List[int]]:
-
         # unique doc_idcs per query
         doc_idcs_per_query = [
             list(sorted(set(idcs.reshape(-1).tolist())))
@@ -212,7 +211,6 @@ def intra_ranking_imputation(
 
 
 class FaissSearchConfig(SearchConfig):
-
     search_class = FaissSearcher
 
     def __init__(
diff --git a/lightning_ir/retrieve/searcher.py b/lightning_ir/retrieve/searcher.py
index 86e0b22..c975420 100644
--- a/lightning_ir/retrieve/searcher.py
+++ b/lightning_ir/retrieve/searcher.py
@@ -41,12 +41,14 @@ def to_gpu(self) -> None:
 
     @property
     @abstractmethod
-    def num_embeddings(self) -> int: ...
+    def num_embeddings(self) -> int:
+        ...
 
     @abstractmethod
     def _search(
         self, query_embeddings: BiEncoderEmbedding
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: ...
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
+        ...
 
     def _filter_and_sort(
         self,
diff --git a/lightning_ir/retrieve/sparse_indexer.py b/lightning_ir/retrieve/sparse_indexer.py
index 913a1f4..9576271 100644
--- a/lightning_ir/retrieve/sparse_indexer.py
+++ b/lightning_ir/retrieve/sparse_indexer.py
@@ -1,5 +1,6 @@
 import array
 from pathlib import Path
+
 import torch
 
 from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
@@ -8,7 +9,6 @@
 
 
 class SparseIndexer(Indexer):
-
     def __init__(
         self,
         index_dir: Path,
diff --git a/lightning_ir/retrieve/sparse_searcher.py b/lightning_ir/retrieve/sparse_searcher.py
index a07ce52..a67a23d 100644
--- a/lightning_ir/retrieve/sparse_searcher.py
+++ b/lightning_ir/retrieve/sparse_searcher.py
@@ -13,7 +13,6 @@
 
 
 class SparseIndex:
-
     def __init__(
         self, index_dir: Path, similarity_function: Literal["dot", "cosine"]
     ) -> None:
@@ -48,7 +47,6 @@ def to_gpu(self) -> None:
 
 
 class SparseSearcher(Searcher):
-
     def __init__(
         self,
         index_dir: Path,
@@ -118,5 +116,4 @@ def _search(
 
 
 class SparseSearchConfig(SearchConfig):
-
     search_class = SparseSearcher
diff --git a/tests/test_models/test_splade.py b/tests/test_models/test_splade.py
index 98d037d..955c519 100644
--- a/tests/test_models/test_splade.py
+++ b/tests/test_models/test_splade.py
@@ -1,12 +1,11 @@
 import torch
-from splade.models.models_utils import get_model
 from omegaconf import DictConfig
+from splade.models.models_utils import get_model
 
 from lightning_ir import SpladeModel
 
 
 def test_same_as_splade():
-
     query = "What is the capital of France?"
     documents = [
         "Paris is the capital of France.",
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index 0bbe731..654a416 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -1,14 +1,14 @@
 from typing import Any
 
-import pytest
-from lightning import LightningModule, LightningDataModule, Trainer
-from torch.utils.data import Dataset, DataLoader
+import pytest
 import torch
+from lightning import LightningDataModule, LightningModule, Trainer
+from torch.utils.data import DataLoader, Dataset
 
 from lightning_ir.lightning_utils.schedulers import (
-    LinearSchedulerWithWarmup,
     ConstantSchedulerWithWarmup,
     LambdaWarmupScheduler,
+    LinearSchedulerWithWarmup,
 )
 
 
@@ -18,7 +18,6 @@ def __init__(self) -> None:
 
 
 class DummyModule(LightningModule):
-
     def __init__(self):
         super().__init__()
         self.param = torch.nn.Parameter(torch.tensor(0.0))
@@ -32,7 +31,6 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
 
 
 class DummyDataset(Dataset):
-
     def __len__(self) -> int:
         return 100
 
@@ -41,7 +39,6 @@ def __getitem__(self, index) -> Any:
 
 
 class DummyDataModule(LightningDataModule):
-
     def __init__(self) -> None:
         super().__init__()
 