add initial xtr model
fschlatt committed Jul 31, 2024
1 parent 0e37271 commit 9dc8ad7
Showing 6 changed files with 195 additions and 2 deletions.
6 changes: 5 additions & 1 deletion lightning_ir/__init__.py
@@ -61,7 +61,7 @@
    RankNet,
    SupervisedMarginMSE,
)
-from .models import ColConfig, ColModel, SpladeConfig, SpladeModel
+from .models import ColConfig, ColModel, SpladeConfig, SpladeModel, XTRConfig, XTRModel
from .retrieve import (
    FaissFlatIndexConfig,
    FaissFlatIndexer,
@@ -89,6 +89,8 @@
AutoModel.register(ColConfig, ColModel)
AutoConfig.register(SpladeConfig.model_type, SpladeConfig)
AutoModel.register(SpladeConfig, SpladeModel)
+AutoConfig.register(XTRConfig.model_type, XTRConfig)
+AutoModel.register(XTRConfig, XTRModel)

__version__ = "0.0.1"

@@ -160,4 +162,6 @@
    "TrainBatch",
    "TupleDataset",
    "WarmupLRScheduler",
+    "XTRConfig",
+    "XTRModel",
]
3 changes: 2 additions & 1 deletion lightning_ir/models/__init__.py
@@ -1,4 +1,5 @@
from .col import ColConfig, ColModel
from .splade import SpladeConfig, SpladeModel
+from .xtr import XTRConfig, XTRModel

-__all__ = ["ColConfig", "ColModel", "SpladeConfig", "SpladeModel"]
+__all__ = ["ColConfig", "ColModel", "SpladeConfig", "SpladeModel", "XTRConfig", "XTRModel"]
4 changes: 4 additions & 0 deletions lightning_ir/models/xtr/__init__.py
@@ -0,0 +1,4 @@
from .config import XTRConfig
from .model import XTRModel

__all__ = ["XTRConfig", "XTRModel"]
21 changes: 21 additions & 0 deletions lightning_ir/models/xtr/config.py
@@ -0,0 +1,21 @@
from typing import Literal

from ..col import ColConfig


class XTRConfig(ColConfig):
    model_type = "xtr"

    ADDED_ARGS = ColConfig.ADDED_ARGS.union({"token_retrieval_k", "fill_strategy", "normalization"})

    def __init__(
        self,
        token_retrieval_k: int | None = None,
        fill_strategy: Literal["zero", "min"] = "zero",
        normalization: Literal["Z"] | None = "Z",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
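        # Options for XTR-style training (the corresponding scoring logic is still marked TODO
        # in XTRScoringFunction):
        #   token_retrieval_k: top-k used to simulate XTR token retrieval during training (None disables it)
        #   fill_strategy: how similarities below the retrieval cut-off are filled ("zero" or "min")
        #   normalization: "Z" applies Z-normalization when aggregating token scores, None disables it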
        self.token_retrieval_k = token_retrieval_k
        self.fill_strategy = fill_strategy
        self.normalization = normalization
107 changes: 107 additions & 0 deletions lightning_ir/models/xtr/model.py
@@ -0,0 +1,107 @@
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from transformers.modeling_utils import load_state_dict

from ...base import LightningIRModelClassFactory
from ...bi_encoder.model import BiEncoderEmbedding, ScoringFunction
from ..col import ColModel
from .config import XTRConfig


class XTRScoringFunction(ScoringFunction):
    def __init__(self, config: XTRConfig) -> None:
        super().__init__(config)
        self.config: XTRConfig

    def compute_similarity(
        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
    ) -> torch.Tensor:
        similarity = super().compute_similarity(query_embeddings, doc_embeddings)

        # During training, XTR simulates token retrieval: similarities below the in-batch
        # top-k cut-off are filled according to fill_strategy (sketch below, not yet implemented).
        if self.training and self.config.token_retrieval_k is not None:
            pass
            # TODO implement simulated token retrieval

            # if not torch.all(num_docs == num_docs[0]):
            #     raise ValueError("XTR token retrieval does not support variable number of documents.")
            # query_embeddings = query_embeddings[:: num_docs[0]]
            # doc_embeddings = doc_embeddings.view(1, 1, -1, doc_embeddings.shape[-1])
            # ib_similarity = super().compute_similarity(
            #     query_embeddings,
            #     doc_embeddings,
            #     query_scoring_mask[:: num_docs[0]],
            #     doc_scoring_mask.view(1, -1),
            #     num_docs,
            # )
            # top_k_similarity = ib_similarity.topk(self.xtr_token_retrieval_k, dim=-1)
            # cut_off_similarity = top_k_similarity.values[..., [-1]].repeat_interleave(num_docs, dim=0)
            # if self.fill_strategy == "min":
            #     fill = cut_off_similarity.expand_as(similarity)[similarity < cut_off_similarity]
            # elif self.fill_strategy == "zero":
            #     fill = 0
            # similarity[similarity < cut_off_similarity] = fill
        return similarity

    # def aggregate(
    #     self,
    #     scores: torch.Tensor,
    #     mask: torch.Tensor,
    #     query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"],
    # ) -> torch.Tensor:
    #     if self.training and self.normalization == "Z":
    #         # Z-normalization
    #         mask = mask & (scores != 0)
    #     return super().aggregate(scores, mask, query_aggregation_function)


class XTRModel(ColModel):
    config_class = XTRConfig

    def __init__(self, config: XTRConfig, *args, **kwargs) -> None:
        super().__init__(config)
        self.scoring_function = XTRScoringFunction(config)
        self.config: XTRConfig

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "XTRModel":
        try:
            # probe for the dense projection weights shipped with original XTR checkpoints
            hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
        except Exception:
            # not an original XTR checkpoint; fall back to the default loading path
            return super().from_pretrained(model_name_or_path, *args, **kwargs)
        return cls.from_xtr_checkpoint(model_name_or_path)

    @classmethod
    def from_xtr_checkpoint(cls, model_name_or_path: Path | str) -> "XTRModel":
        from transformers import T5EncoderModel

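        # build the concrete Lightning IR model class around the T5 encoder backbone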
        cls = LightningIRModelClassFactory(T5EncoderModel, XTRConfig)
        config = cls.config_class.from_pretrained(model_name_or_path)
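        # fixed settings matching the original XTR checkpoints (e.g. google/xtr-base-en)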
        config.update(
            {
                "name_or_path": str(model_name_or_path),
                "similarity_function": "dot",
                "query_aggregation_function": "sum",
                "query_expansion": False,
                "doc_expansion": False,
                "doc_pooling_strategy": None,
                "doc_mask_scoring_tokens": None,
                "normalize": True,
                "sparsification": None,
                "add_marker_tokens": False,
                "embedding_dim": 128,
                "projection": "linear_no_bias",
            }
        )
        state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="model.safetensors")
        state_dict = load_state_dict(state_dict_path)
        linear_state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
        linear_state_dict = load_state_dict(linear_state_dict_path)
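        # remap the dense projection and shared embedding weights to the parameter names
        # expected by the Lightning IR model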
        linear_state_dict["projection.weight"] = linear_state_dict.pop("linear.weight")
        state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
        state_dict.update(linear_state_dict)
        model = cls(config=config)
        model.load_state_dict(state_dict)
        return model
56 changes: 56 additions & 0 deletions tests/test_models/test_xtr.py
@@ -0,0 +1,56 @@
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, T5EncoderModel

from lightning_ir.bi_encoder.tokenizer import BiEncoderTokenizer
from lightning_ir.models.xtr.model import XTRModel


class TestXTRModel(torch.nn.Module):
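    # Reference implementation for comparison: the original T5 encoder plus the dense
    # projection from the sentence-transformers checkpoint, with L2-normalized embeddings.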

    def __init__(self):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained("google/xtr-base-en")
        self.linear = torch.nn.Linear(self.model.config.hidden_size, 128, bias=False)
        linear_layer_path = hf_hub_download("google/xtr-base-en", filename="2_Dense/pytorch_model.bin")
        state_dict = torch.load(linear_layer_path)
        state_dict["weight"] = state_dict.pop("linear.weight")
        self.linear.load_state_dict(state_dict)

    def forward(self, **kwargs):
        encoded = self.model(**kwargs).last_hidden_state
        embeddings = self.linear(encoded)
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return embeddings


def test_same_as_xtr():
    model_name = "google/xtr-base-en"
    orig_model = TestXTRModel().eval()
    orig_tokenizer = AutoTokenizer.from_pretrained(model_name)

    query = "What is the capital of France?"
    documents = [
        "Paris is the capital of France.",
        "France is a country in Europe.",
        "The Eiffel Tower is in Paris.",
    ]
    orig_query_encoding = orig_tokenizer(query, return_tensors="pt")
    orig_doc_encoding = orig_tokenizer(documents, return_tensors="pt", padding=True, truncation=True)

    model = XTRModel.from_pretrained(model_name).eval()
    tokenizer = BiEncoderTokenizer.from_pretrained(model_name, **model.config.to_dict())
    query_encoding = tokenizer.tokenize_query(query, return_tensors="pt")
    doc_encoding = tokenizer.tokenize_doc(documents, return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        query_embedding = orig_model(**orig_query_encoding)
        doc_embedding = orig_model(**orig_doc_encoding)
        output = model.forward(query_encoding=query_encoding, doc_encoding=doc_encoding)

    assert torch.allclose(query_embedding, output.query_embeddings.embeddings, atol=1e-6)
    assert torch.allclose(
        doc_embedding[orig_doc_encoding.attention_mask.bool()],
        output.doc_embeddings.embeddings[doc_encoding.attention_mask.bool()],
        atol=1e-6,
    )
