Commit

Merge branch 'webis-de:main' into main
TheMrSheldon authored Jul 17, 2024
2 parents dba6925 + d418b00 commit 6b59ce0
Showing 17 changed files with 38 additions and 44 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -46,8 +46,8 @@ For more details, see the [Usage](#usage) section.

| Model Name | TREC DL 19 | TREC DL 20 |
| ------------------------------------------------------------------- | ---------- | ---------- |
- | [monoelectra-base](https://huggingface.co/webis/monoelectra-base) | 0.715 | 0.715 |
- | [monoelectra-large](https://huggingface.co/webis/monoelectra-large) | 0.730 | 0.730 |
+ | [monoelectra-base](https://huggingface.co/webis/monoelectra-base) | 0.75 | 0.77 |
+ | [monoelectra-large](https://huggingface.co/webis/monoelectra-large) | 0.75 | 0.79 |
| monoT5 (Coming soon) | -- | -- |

### Bi-encoders
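The table lists the models' effectiveness on TREC DL 19 and 20. As a usage illustration, here is a minimal reranking sketch; it assumes the checkpoints load through the plain transformers sequence-classification interface, which may differ from Lightning IR's own loading API.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumption: the checkpoint exposes a standard sequence-classification head.
model_name = "webis/monoelectra-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

query = "What is the capital of France?"
docs = ["Paris is the capital of France.", "Berlin is the capital of Germany."]

# Score every (query, document) pair, then sort documents by descending score.
inputs = tokenizer([query] * len(docs), docs, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**inputs).logits.squeeze(-1)
ranking = [docs[i] for i in scores.argsort(descending=True).tolist()]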
2 changes: 1 addition & 1 deletion lightning_ir/base/model.py
@@ -7,9 +7,9 @@
CONFIG_MAPPING,
MODEL_MAPPING,
AutoConfig,
+ BertModel,
PretrainedConfig,
PreTrainedModel,
- BertModel,
)
from transformers.modeling_outputs import ModelOutput

6 changes: 3 additions & 3 deletions lightning_ir/base/module.py
@@ -224,9 +224,9 @@ def validate_loss(
# NOTE skip in-batch losses because they can use a lot of memory
if isinstance(loss_function, InBatchLossFunction):
continue
- metrics[f"validation-{loss_function.__class__.__name__}"] = (
- loss_function.compute_loss(scores, targets).item()
- )
+ metrics[
+ f"validation-{loss_function.__class__.__name__}"
+ ] = loss_function.compute_loss(scores, targets).item()
return metrics

def on_validation_epoch_end(self) -> None:
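For context, a self-contained sketch of the pattern this hunk reformats: one scalar metric per validation loss function, with in-batch losses skipped because they can materialize large score matrices. The classes below are stand-ins, not the library's actual implementations.

from typing import Dict, Sequence

import torch


class InBatchLossFunction:  # stand-in marker class for memory-heavy losses
    pass


class MSELossFunction:  # stand-in for a cheap scoring loss
    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.mse_loss(scores, targets)


def validate_loss(loss_functions: Sequence, scores, targets) -> Dict[str, float]:
    metrics: Dict[str, float] = {}
    for loss_function in loss_functions:
        if isinstance(loss_function, InBatchLossFunction):
            continue  # NOTE: in-batch losses can use a lot of memory
        metrics[
            f"validation-{loss_function.__class__.__name__}"
        ] = loss_function.compute_loss(scores, targets).item()
    return metrics


metrics = validate_loss([MSELossFunction()], torch.rand(8), torch.rand(8))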
6 changes: 3 additions & 3 deletions lightning_ir/bi_encoder/model.py
@@ -1,8 +1,8 @@
+ import warnings
from dataclasses import dataclass
- from string import punctuation
from functools import wraps
- import warnings
- from typing import Literal, Sequence, Callable
+ from string import punctuation
+ from typing import Callable, Literal, Sequence

import torch
from transformers import BatchEncoding
3 changes: 2 additions & 1 deletion lightning_ir/lightning_utils/lr_schedulers.py
@@ -17,7 +17,8 @@ def __init__(
self.num_training_steps = num_training_steps
super().__init__(optimizer, self.lr_lambda, last_epoch, verbose)

- def lr_lambda(self, current_step: int) -> float: ...
+ def lr_lambda(self, current_step: int) -> float:
+ ...


class LinearLRSchedulerWithWarmup(WarmupLRScheduler):
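The ... marks an abstract stub. For orientation, a sketch of the linear warmup-then-decay multiplier such a subclass typically implements, usable with torch's LambdaLR; the exact shape here is an assumption, not the file's actual code.

import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_warmup_decay(num_warmup_steps: int, num_training_steps: int):
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)  # ramp 0 -> 1
        remaining = num_training_steps - current_step  # then decay 1 -> 0
        return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

    return lr_lambda


params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = LambdaLR(optimizer, linear_warmup_decay(10, 100))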
5 changes: 2 additions & 3 deletions lightning_ir/lightning_utils/schedulers.py
@@ -7,7 +7,6 @@


class LambdaWarmupScheduler(Callback, ABC):
-
def __init__(
self,
keys: Sequence[str],
@@ -23,7 +22,8 @@ def __init__(
self.values: Dict[str, float] = {}

@abstractmethod
- def lr_lambda(self, current_step: int) -> float: ...
+ def lr_lambda(self, current_step: int) -> float:
+ ...

def step(self, key: str, current_step: int) -> float:
value = self.values[key]
@@ -73,7 +73,6 @@ def lr_lambda(self, current_step: int) -> float:


class ConstantSchedulerWithWarmup(LambdaWarmupScheduler):
-
def lr_lambda(self, current_step: int) -> float:
if current_step < self.num_delay_steps:
return 0.0
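A sketch of the multiplier shape suggested by this hunk's context: 0.0 during an initial delay, a linear ramp over the warmup steps, then a constant 1.0. The actual ConstantSchedulerWithWarmup may differ; this is an illustration only.

def constant_with_warmup(num_delay_steps: int, num_warmup_steps: int):
    def lr_lambda(current_step: int) -> float:
        if current_step < num_delay_steps:
            return 0.0  # delayed start
        warmup_progress = (current_step - num_delay_steps) / max(1, num_warmup_steps)
        return min(1.0, warmup_progress)  # ramp, then hold at 1.0

    return lr_lambda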
13 changes: 7 additions & 6 deletions lightning_ir/loss/loss.py
@@ -1,12 +1,13 @@
- from typing import Literal, Tuple
from abc import ABC, abstractmethod
+ from typing import Literal, Tuple

import torch


class LossFunction(ABC):
@abstractmethod
- def compute_loss(self, *args, **kwargs) -> torch.Tensor: ...
+ def compute_loss(self, *args, **kwargs) -> torch.Tensor:
+ ...


class ScoringLossFunction(LossFunction):
@@ -15,7 +16,8 @@ def compute_loss(
self,
scores: torch.Tensor,
targets: torch.Tensor,
- ) -> torch.Tensor: ...
+ ) -> torch.Tensor:
+ ...

def process_targets(
self, scores: torch.Tensor, targets: torch.Tensor
@@ -31,7 +33,8 @@ def compute_loss(
self,
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
- ) -> torch.Tensor: ...
+ ) -> torch.Tensor:
+ ...


class PairwiseLossFunction(ScoringLossFunction):
@@ -346,7 +349,6 @@ def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None


class L2Regularization(RegularizationLossFunction):
-
def compute_loss(
self,
query_embeddings: torch.Tensor,
@@ -359,7 +361,6 @@ def compute_loss(


class L1Regularization(RegularizationLossFunction):
-
def compute_loss(
self,
query_embeddings: torch.Tensor,
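One plausible implementation of the regularization interface above, shown for orientation; the library's actual L2Regularization may differ in weighting or reduction.

import torch


def l2_regularization(
    query_embeddings: torch.Tensor,  # (num_queries, dim)
    doc_embeddings: torch.Tensor,    # (num_docs, dim)
    query_weight: float = 1e-4,
    doc_weight: float = 1e-4,
) -> torch.Tensor:
    # Mean squared L2 norm per embedding, weighted separately for queries and docs.
    query_loss = query_embeddings.pow(2).sum(-1).mean()
    doc_loss = doc_embeddings.pow(2).sum(-1).mean()
    return query_weight * query_loss + doc_weight * doc_loss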
2 changes: 1 addition & 1 deletion lightning_ir/main.py
@@ -1,7 +1,7 @@
import os
import sys
from pathlib import Path
- from typing import Any, Dict, List, Set, Mapping
+ from typing import Any, Dict, List, Mapping, Set

import torch
from lightning import LightningDataModule, LightningModule, Trainer
2 changes: 1 addition & 1 deletion lightning_ir/models/col/model.py
@@ -5,7 +5,7 @@
from transformers import BertModel
from transformers.modeling_utils import load_state_dict

- from ...base import LightningIRModelClassFactory, LightningIRModel
+ from ...base import LightningIRModel, LightningIRModelClassFactory
from ...bi_encoder.model import BiEncoderModel
from .config import ColConfig

1 change: 0 additions & 1 deletion lightning_ir/models/splade/model.py
@@ -73,7 +73,6 @@ def get_output_embeddings(self):
def from_mlm_checkpoint(
cls, model_name_or_path: str | Path, *args, **kwargs
) -> "SpladeModel":
-
config = AutoConfig.from_pretrained(model_name_or_path)
BackboneModel = MODEL_MAPPING[config.__class__]
cls = LightningIRModelClassFactory(BackboneModel, SpladeConfig)
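Standalone illustration of the config-to-backbone resolution used above: transformers' MODEL_MAPPING maps a config class to its base model class. The checkpoint name is an arbitrary example.

from transformers import MODEL_MAPPING, AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
BackboneModel = MODEL_MAPPING[config.__class__]  # e.g. BertModel for a BERT config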
11 changes: 6 additions & 5 deletions lightning_ir/retrieve/faiss_indexer.py
@@ -10,7 +10,6 @@


class FaissIndexer(Indexer):
-
INDEX_FACTORY: str

def __init__(
@@ -42,13 +41,16 @@ def __init__(
self.to_gpu()

@abstractmethod
- def to_gpu(self) -> None: ...
+ def to_gpu(self) -> None:
+ ...

@abstractmethod
- def to_cpu(self) -> None: ...
+ def to_cpu(self) -> None:
+ ...

@abstractmethod
- def set_verbosity(self, verbose: bool | None = None) -> None: ...
+ def set_verbosity(self, verbose: bool | None = None) -> None:
+ ...

def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
return embeddings
@@ -84,7 +86,6 @@ def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:


class FaissFlatIndexer(FaissIndexer):
-
INDEX_FACTORY = "Flat"

def __init__(
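A sketch of the FAISS pattern behind this indexer: build an index from an index-factory string and move it between CPU and GPU. The GPU calls require faiss-gpu; dimensions and data here are illustrative assumptions.

import faiss
import numpy as np

dim = 128
index = faiss.index_factory(dim, "Flat")        # cf. INDEX_FACTORY = "Flat"
embeddings = np.random.rand(1000, dim).astype("float32")
index.add(embeddings)                           # FAISS expects float32 arrays

gpu_index = faiss.index_cpu_to_all_gpus(index)  # cf. to_gpu
cpu_index = faiss.index_gpu_to_cpu(gpu_index)   # cf. to_cpu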
2 changes: 0 additions & 2 deletions lightning_ir/retrieve/faiss_searcher.py
@@ -81,7 +81,6 @@ def candidate_retrieval(
def gather_imputation(
self, candidate_doc_idcs: torch.Tensor, query_lengths: torch.Tensor
) -> Tuple[BiEncoderEmbedding, torch.Tensor, List[int]]:
-
# unique doc_idcs per query
doc_idcs_per_query = [
list(sorted(set(idcs.reshape(-1).tolist())))
@@ -212,7 +211,6 @@ def intra_ranking_imputation(


class FaissSearchConfig(SearchConfig):
-
search_class = FaissSearcher

def __init__(
6 changes: 4 additions & 2 deletions lightning_ir/retrieve/searcher.py
@@ -41,12 +41,14 @@ def to_gpu(self) -> None:

@property
@abstractmethod
- def num_embeddings(self) -> int: ...
+ def num_embeddings(self) -> int:
+ ...

@abstractmethod
def _search(
self, query_embeddings: BiEncoderEmbedding
- ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: ...
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
+ ...

def _filter_and_sort(
self,
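A sketch of the contract suggested by the _search signature above: given query embeddings, return scores, document indices, and the number of hits per query. This brute-force version is an illustration, not the library's searcher.

from typing import List, Tuple

import torch


def brute_force_search(
    query_embeddings: torch.Tensor,  # (num_queries, dim)
    doc_embeddings: torch.Tensor,    # (num_docs, dim)
    k: int = 10,
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    similarities = query_embeddings @ doc_embeddings.T  # dot-product scores
    scores, doc_idcs = similarities.topk(k, dim=-1)
    num_docs = [k] * query_embeddings.shape[0]  # k hits per query
    return scores.reshape(-1), doc_idcs.reshape(-1), num_docs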
2 changes: 1 addition & 1 deletion lightning_ir/retrieve/sparse_indexer.py
@@ -1,5 +1,6 @@
import array
from pathlib import Path
+
import torch

from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
@@ -8,7 +9,6 @@


class SparseIndexer(Indexer):
-
def __init__(
self,
index_dir: Path,
3 changes: 0 additions & 3 deletions lightning_ir/retrieve/sparse_searcher.py
@@ -13,7 +13,6 @@


class SparseIndex:
-
def __init__(
self, index_dir: Path, similarity_function: Literal["dot", "cosine"]
) -> None:
@@ -48,7 +47,6 @@ def to_gpu(self) -> None:


class SparseSearcher(Searcher):
-
def __init__(
self,
index_dir: Path,
@@ -118,5 +116,4 @@ def _search(


class SparseSearchConfig(SearchConfig):
-
search_class = SparseSearcher
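A sketch of the two similarity functions SparseIndex names ("dot" and "cosine") over sparse document embeddings; shapes and data are illustrative assumptions, and a real index would build the sparse tensor directly rather than densifying first.

import torch

docs = torch.rand(100, 30522)   # e.g. vocabulary-sized SPLADE vectors
docs_sparse = docs.to_sparse()  # store only non-zero entries
query = torch.rand(1, 30522)

dot = torch.sparse.mm(docs_sparse, query.T).squeeze(-1)         # "dot"
cosine = dot / (docs.norm(dim=-1) * query.norm(dim=-1) + 1e-9)  # "cosine"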
3 changes: 1 addition & 2 deletions tests/test_models/test_splade.py
@@ -1,12 +1,11 @@
import torch
- from splade.models.models_utils import get_model
from omegaconf import DictConfig
+ from splade.models.models_utils import get_model

from lightning_ir import SpladeModel


def test_same_as_splade():
-
query = "What is the capital of France?"
documents = [
"Paris is the capital of France.",
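The test runs the same inputs through Lightning IR's SpladeModel and the reference SPLADE implementation. In miniature, the parity pattern is:

import torch


def assert_same_scores(scores_a: torch.Tensor, scores_b: torch.Tensor) -> None:
    # Two implementations agree if their outputs match within tolerance.
    assert torch.allclose(scores_a, scores_b, atol=1e-5)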
11 changes: 4 additions & 7 deletions tests/test_schedulers.py
@@ -1,14 +1,14 @@
from typing import Any
- import pytest

- from lightning import LightningModule, LightningDataModule, Trainer
- from torch.utils.data import Dataset, DataLoader
+ import pytest
import torch
+ from lightning import LightningDataModule, LightningModule, Trainer
+ from torch.utils.data import DataLoader, Dataset

from lightning_ir.lightning_utils.schedulers import (
- LinearSchedulerWithWarmup,
ConstantSchedulerWithWarmup,
LambdaWarmupScheduler,
+ LinearSchedulerWithWarmup,
)


@@ -18,7 +18,6 @@ def __init__(self) -> None:


class DummyModule(LightningModule):
-
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.tensor(0.0))
@@ -32,7 +31,6 @@ def configure_optimizers(self) -> torch.optim.Optimizer:


class DummyDataset(Dataset):
-
def __len__(self) -> int:
return 100

@@ -41,7 +39,6 @@ def __getitem__(self, index) -> Any:


class DummyDataModule(LightningDataModule):
-
def __init__(self) -> None:
super().__init__()

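A sketch of how such a schedule can also be unit-tested without a full Trainer: step an optimizer under LambdaLR and check the learning rate directly. The warmup function and step counts are assumptions for illustration.

import torch
from torch.optim.lr_scheduler import LambdaLR

param = torch.nn.Parameter(torch.tensor(0.0))
optimizer = torch.optim.SGD([param], lr=1.0)
scheduler = LambdaLR(optimizer, lambda step: min(1.0, step / 10))  # 10 warmup steps

for expected_step in range(1, 5):
    optimizer.step()
    scheduler.step()
    # After n steps, the warmup multiplier should be n / 10.
    assert abs(optimizer.param_groups[0]["lr"] - expected_step / 10) < 1e-8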
