From 9dc8ad7cff153f0c94d71e3a247fb538f75ec603 Mon Sep 17 00:00:00 2001
From: Ferdinand Schlatt
Date: Wed, 31 Jul 2024 12:57:37 +0200
Subject: [PATCH] add initial xtr model

---
 lightning_ir/__init__.py            |   6 +-
 lightning_ir/models/__init__.py     |   3 +-
 lightning_ir/models/xtr/__init__.py |   4 ++
 lightning_ir/models/xtr/config.py   |  21 ++++++
 lightning_ir/models/xtr/model.py    | 107 ++++++++++++++++++++++++++++
 tests/test_models/test_xtr.py       |  56 +++++++++++++++
 6 files changed, 195 insertions(+), 2 deletions(-)
 create mode 100644 lightning_ir/models/xtr/__init__.py
 create mode 100644 lightning_ir/models/xtr/config.py
 create mode 100644 lightning_ir/models/xtr/model.py
 create mode 100644 tests/test_models/test_xtr.py

diff --git a/lightning_ir/__init__.py b/lightning_ir/__init__.py
index 855bac4..16aab17 100644
--- a/lightning_ir/__init__.py
+++ b/lightning_ir/__init__.py
@@ -61,7 +61,7 @@
     RankNet,
     SupervisedMarginMSE,
 )
-from .models import ColConfig, ColModel, SpladeConfig, SpladeModel
+from .models import ColConfig, ColModel, SpladeConfig, SpladeModel, XTRConfig, XTRModel
 from .retrieve import (
     FaissFlatIndexConfig,
     FaissFlatIndexer,
@@ -89,6 +89,8 @@
 AutoModel.register(ColConfig, ColModel)
 AutoConfig.register(SpladeConfig.model_type, SpladeConfig)
 AutoModel.register(SpladeConfig, SpladeModel)
+AutoConfig.register(XTRConfig.model_type, XTRConfig)
+AutoModel.register(XTRConfig, XTRModel)
 
 __version__ = "0.0.1"
 
@@ -160,4 +162,6 @@
     "TrainBatch",
     "TupleDataset",
     "WarmupLRScheduler",
+    "XTRConfig",
+    "XTRModel",
 ]
diff --git a/lightning_ir/models/__init__.py b/lightning_ir/models/__init__.py
index 55da680..82e53b6 100644
--- a/lightning_ir/models/__init__.py
+++ b/lightning_ir/models/__init__.py
@@ -1,4 +1,5 @@
 from .col import ColConfig, ColModel
 from .splade import SpladeConfig, SpladeModel
+from .xtr import XTRConfig, XTRModel
 
-__all__ = ["ColConfig", "ColModel", "SpladeConfig", "SpladeModel"]
+__all__ = ["ColConfig", "ColModel", "SpladeConfig", "SpladeModel", "XTRConfig", "XTRModel"]
diff --git a/lightning_ir/models/xtr/__init__.py b/lightning_ir/models/xtr/__init__.py
new file mode 100644
index 0000000..2193b9a
--- /dev/null
+++ b/lightning_ir/models/xtr/__init__.py
@@ -0,0 +1,4 @@
+from .config import XTRConfig
+from .model import XTRModel
+
+__all__ = ["XTRConfig", "XTRModel"]
diff --git a/lightning_ir/models/xtr/config.py b/lightning_ir/models/xtr/config.py
new file mode 100644
index 0000000..3eddfe3
--- /dev/null
+++ b/lightning_ir/models/xtr/config.py
@@ -0,0 +1,21 @@
+from typing import Literal
+
+from ..col import ColConfig
+
+
+class XTRConfig(ColConfig):
+    model_type = "xtr"
+
+    ADDED_ARGS = ColConfig.ADDED_ARGS.union({"token_retrieval_k", "fill_strategy", "normalization"})
+
+    def __init__(
+        self,
+        token_retrieval_k: int | None = None,
+        fill_strategy: Literal["zero", "min"] = "zero",
+        normalization: Literal["Z"] | None = "Z",
+        **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.token_retrieval_k = token_retrieval_k
+        self.fill_strategy = fill_strategy
+        self.normalization = normalization
diff --git a/lightning_ir/models/xtr/model.py b/lightning_ir/models/xtr/model.py
new file mode 100644
index 0000000..20d9797
--- /dev/null
+++ b/lightning_ir/models/xtr/model.py
@@ -0,0 +1,107 @@
+from pathlib import Path
+
+import torch
+from huggingface_hub import hf_hub_download
+from transformers.modeling_utils import load_state_dict
+
+from ...base import LightningIRModelClassFactory
+from ...bi_encoder.model import BiEncoderEmbedding, ScoringFunction
+from ..col import ColModel
+from .config import XTRConfig
+
+
+class XTRScoringFunction(ScoringFunction):
+    def __init__(self, config: XTRConfig) -> None:
+        super().__init__(config)
+        self.config: XTRConfig
+
+    def compute_similarity(
+        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
+    ) -> torch.Tensor:
+        similarity = super().compute_similarity(query_embeddings, doc_embeddings)
+
+        if self.training and self.config.token_retrieval_k is not None:
+            pass
+            # TODO implement simulated token retrieval
+
+            # if not torch.all(num_docs == num_docs[0]):
+            #     raise ValueError("XTR token retrieval does not support variable number of documents.")
+            # query_embeddings = query_embeddings[:: num_docs[0]]
+            # doc_embeddings = doc_embeddings.view(1, 1, -1, doc_embeddings.shape[-1])
+            # ib_similarity = super().compute_similarity(
+            #     query_embeddings,
+            #     doc_embeddings,
+            #     query_scoring_mask[:: num_docs[0]],
+            #     doc_scoring_mask.view(1, -1),
+            #     num_docs,
+            # )
+            # top_k_similarity = ib_similarity.topk(self.xtr_token_retrieval_k, dim=-1)
+            # cut_off_similarity = top_k_similarity.values[..., [-1]].repeat_interleave(num_docs, dim=0)
+            # if self.fill_strategy == "min":
+            #     fill = cut_off_similarity.expand_as(similarity)[similarity < cut_off_similarity]
+            # elif self.fill_strategy == "zero":
+            #     fill = 0
+            # similarity[similarity < cut_off_similarity] = fill
+        return similarity
+
+    # def aggregate(
+    #     self,
+    #     scores: torch.Tensor,
+    #     mask: torch.Tensor,
+    #     query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"],
+    # ) -> torch.Tensor:
+    #     if self.training and self.normalization == "Z":
+    #         # Z-normalization
+    #         mask = mask & (scores != 0)
+    #     return super().aggregate(scores, mask, query_aggregation_function)
+
+
+class XTRModel(ColModel):
+    config_class = XTRConfig
+
+    def __init__(self, config: XTRConfig, *args, **kwargs) -> None:
+        super().__init__(config)
+        self.scoring_function = XTRScoringFunction(config)
+        self.config: XTRConfig
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "XTRModel":
+        try:
+            hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
+        except Exception:
+            return super().from_pretrained(model_name_or_path, *args, **kwargs)
+        else:
+            return cls.from_xtr_checkpoint(model_name_or_path)
+
+    @classmethod
+    def from_xtr_checkpoint(cls, model_name_or_path: Path | str) -> "XTRModel":
+        from transformers import T5EncoderModel
+
+        cls = LightningIRModelClassFactory(T5EncoderModel, XTRConfig)
+        config = cls.config_class.from_pretrained(model_name_or_path)
+        config.update(
+            {
+                "name_or_path": str(model_name_or_path),
+                "similarity_function": "dot",
+                "query_aggregation_function": "sum",
+                "query_expansion": False,
+                "doc_expansion": False,
+                "doc_pooling_strategy": None,
+                "doc_mask_scoring_tokens": None,
+                "normalize": True,
+                "sparsification": None,
+                "add_marker_tokens": False,
+                "embedding_dim": 128,
+                "projection": "linear_no_bias",
+            }
+        )
+        state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="model.safetensors")
+        state_dict = load_state_dict(state_dict_path)
+        linear_state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
+        linear_state_dict = load_state_dict(linear_state_dict_path)
+        linear_state_dict["projection.weight"] = linear_state_dict.pop("linear.weight")
+        state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
+        state_dict.update(linear_state_dict)
+        model = cls(config=config)
+        model.load_state_dict(state_dict)
+        return model
diff --git a/tests/test_models/test_xtr.py b/tests/test_models/test_xtr.py
new file mode 100644
index 0000000..00fddca
--- /dev/null
+++ b/tests/test_models/test_xtr.py
@@ -0,0 +1,56 @@
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer, T5EncoderModel
+
+from lightning_ir.bi_encoder.tokenizer import BiEncoderTokenizer
+from lightning_ir.models.xtr.model import XTRModel
+
+
+class TestXTRModel(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.model = T5EncoderModel.from_pretrained("google/xtr-base-en")
+        self.linear = torch.nn.Linear(self.model.config.hidden_size, 128, bias=False)
+        linear_layer_path = hf_hub_download("google/xtr-base-en", filename="2_Dense/pytorch_model.bin")
+        state_dict = torch.load(linear_layer_path)
+        state_dict["weight"] = state_dict.pop("linear.weight")
+        self.linear.load_state_dict(state_dict)
+
+    def forward(self, **kwargs):
+        encoded = self.model(**kwargs).last_hidden_state
+        embeddings = self.linear(encoded)
+        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
+        return embeddings
+
+
+def test_same_as_xtr():
+    model_name = "google/xtr-base-en"
+    orig_model = TestXTRModel().eval()
+    orig_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    query = "What is the capital of France?"
+    documents = [
+        "Paris is the capital of France.",
+        "France is a country in Europe.",
+        "The Eiffel Tower is in Paris.",
+    ]
+    orig_query_encoding = orig_tokenizer(query, return_tensors="pt")
+    orig_doc_encoding = orig_tokenizer(documents, return_tensors="pt", padding=True, truncation=True)
+
+    model = XTRModel.from_pretrained(model_name).eval()
+    tokenizer = BiEncoderTokenizer.from_pretrained(model_name, **model.config.to_dict())
+    query_encoding = tokenizer.tokenize_query(query, return_tensors="pt")
+    doc_encoding = tokenizer.tokenize_doc(documents, return_tensors="pt", padding=True, truncation=True)
+
+    with torch.no_grad():
+        query_embedding = orig_model(**orig_query_encoding)
+        doc_embedding = orig_model(**orig_doc_encoding)
+        output = model.forward(query_encoding=query_encoding, doc_encoding=doc_encoding)
+
+    assert torch.allclose(query_embedding, output.query_embeddings.embeddings, atol=1e-6)
+    assert torch.allclose(
+        doc_embedding[orig_doc_encoding.attention_mask.bool()],
+        output.doc_embeddings.embeddings[doc_encoding.attention_mask.bool()],
+        atol=1e-6,
+    )
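
Usage sketch (not part of the patch): a minimal example of how the new XTR classes might be loaded and called. It uses only the calls exercised in tests/test_models/test_xtr.py above; the checkpoint name, tokenizer methods, and output attributes come from that test, and anything beyond those calls should be treated as an assumption rather than documented API.

from lightning_ir.bi_encoder.tokenizer import BiEncoderTokenizer
from lightning_ir.models.xtr.model import XTRModel

# Load the converted XTR checkpoint and a matching bi-encoder tokenizer, as the test does.
model = XTRModel.from_pretrained("google/xtr-base-en").eval()
tokenizer = BiEncoderTokenizer.from_pretrained("google/xtr-base-en", **model.config.to_dict())

query_encoding = tokenizer.tokenize_query("What is the capital of France?", return_tensors="pt")
doc_encoding = tokenizer.tokenize_doc(
    ["Paris is the capital of France.", "France is a country in Europe."],
    return_tensors="pt",
    padding=True,
    truncation=True,
)

# forward returns an output object exposing per-token query and document embeddings,
# which the test compares against the reference T5-encoder-plus-linear-projection model.
output = model.forward(query_encoding=query_encoding, doc_encoding=doc_encoding)
print(output.query_embeddings.embeddings.shape, output.doc_embeddings.embeddings.shape)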