Commit
Showing 6 changed files with 195 additions and 2 deletions.
@@ -1,4 +1,5 @@
 from .col import ColConfig, ColModel
 from .splade import SpladeConfig, SpladeModel
+from .xtr import XTRConfig, XTRModel

-__all__ = ["ColConfig", "ColModel", "SpladeConfig", "SpladeModel"]
+__all__ = ["ColConfig", "ColModel", "SpladeConfig", "SpladeModel", "XTRConfig", "XTRModel"]
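With the re-export in place, the XTR classes can be imported next to the existing Col and Splade models. A minimal import sketch; the top-level package path lightning_ir.models is inferred from the test file further below, not stated in this hunk:

# Import sketch; the package path is an inference from the tests below.
from lightning_ir.models import ColModel, SpladeModel, XTRConfig, XTRModel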
@@ -0,0 +1,4 @@
from .config import XTRConfig
from .model import XTRModel

__all__ = ["XTRConfig", "XTRModel"]
@@ -0,0 +1,21 @@
from typing import Literal

from ..col import ColConfig


class XTRConfig(ColConfig):
    model_type = "xtr"

    ADDED_ARGS = ColConfig.ADDED_ARGS.union({"token_retrieval_k", "fill_strategy", "normalization"})

    def __init__(
        self,
        token_retrieval_k: int | None = None,
        fill_strategy: Literal["zero", "min"] = "zero",
        normalization: Literal["Z"] | None = "Z",
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.token_retrieval_k = token_retrieval_k
        self.fill_strategy = fill_strategy
        self.normalization = normalization
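For illustration, the new options can be set directly when constructing the config. A minimal sketch, assuming the inherited ColConfig arguments all have defaults; the value chosen for token_retrieval_k is arbitrary and not taken from the commit:

from lightning_ir.models.xtr import XTRConfig

# Hypothetical usage; token_retrieval_k enables simulated token retrieval during training,
# fill_strategy controls how pruned similarities are filled, normalization selects Z-normalization.
config = XTRConfig(token_retrieval_k=1000, fill_strategy="zero", normalization="Z")
assert config.model_type == "xtr"
assert {"token_retrieval_k", "fill_strategy", "normalization"} <= XTRConfig.ADDED_ARGS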
@@ -0,0 +1,107 @@
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from transformers.modeling_utils import load_state_dict

from ...base import LightningIRModelClassFactory
from ...bi_encoder.model import BiEncoderEmbedding, ScoringFunction
from ..col import ColModel
from .config import XTRConfig


class XTRScoringFunction(ScoringFunction):
    def __init__(self, config: XTRConfig) -> None:
        super().__init__(config)
        self.config: XTRConfig

    def compute_similarity(
        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
    ) -> torch.Tensor:
        similarity = super().compute_similarity(query_embeddings, doc_embeddings)
        if self.training and self.config.token_retrieval_k is not None:
            # TODO: implement simulated token retrieval
            pass

            # if not torch.all(num_docs == num_docs[0]):
            #     raise ValueError("XTR token retrieval does not support variable number of documents.")
            # query_embeddings = query_embeddings[:: num_docs[0]]
            # doc_embeddings = doc_embeddings.view(1, 1, -1, doc_embeddings.shape[-1])
            # ib_similarity = super().compute_similarity(
            #     query_embeddings,
            #     doc_embeddings,
            #     query_scoring_mask[:: num_docs[0]],
            #     doc_scoring_mask.view(1, -1),
            #     num_docs,
            # )
            # top_k_similarity = ib_similarity.topk(self.config.token_retrieval_k, dim=-1)
            # cut_off_similarity = top_k_similarity.values[..., [-1]].repeat_interleave(num_docs, dim=0)
            # if self.config.fill_strategy == "min":
            #     fill = cut_off_similarity.expand_as(similarity)[similarity < cut_off_similarity]
            # elif self.config.fill_strategy == "zero":
            #     fill = 0
            # similarity[similarity < cut_off_similarity] = fill
        return similarity

    # def aggregate(
    #     self,
    #     scores: torch.Tensor,
    #     mask: torch.Tensor,
    #     query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"],
    # ) -> torch.Tensor:
    #     if self.training and self.normalization == "Z":
    #         # Z-normalization
    #         mask = mask & (scores != 0)
    #     return super().aggregate(scores, mask, query_aggregation_function)


class XTRModel(ColModel):
    config_class = XTRConfig

    def __init__(self, config: XTRConfig, *args, **kwargs) -> None:
        super().__init__(config)
        self.scoring_function = XTRScoringFunction(config)
        self.config: XTRConfig

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "XTRModel":
        try:
            hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
        except Exception:
            # No sentence-transformers dense layer found: not an original XTR checkpoint,
            # so defer to the default loading logic.
            return super().from_pretrained(model_name_or_path, *args, **kwargs)
        return cls.from_xtr_checkpoint(model_name_or_path)

    @classmethod
    def from_xtr_checkpoint(cls, model_name_or_path: Path | str) -> "XTRModel":
        from transformers import T5EncoderModel

        cls = LightningIRModelClassFactory(T5EncoderModel, XTRConfig)
        config = cls.config_class.from_pretrained(model_name_or_path)
        config.update(
            {
                "name_or_path": str(model_name_or_path),
                "similarity_function": "dot",
                "query_aggregation_function": "sum",
                "query_expansion": False,
                "doc_expansion": False,
                "doc_pooling_strategy": None,
                "doc_mask_scoring_tokens": None,
                "normalize": True,
                "sparsification": None,
                "add_marker_tokens": False,
                "embedding_dim": 128,
                "projection": "linear_no_bias",
            }
        )
        state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="model.safetensors")
        state_dict = load_state_dict(state_dict_path)
        linear_state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
        linear_state_dict = load_state_dict(linear_state_dict_path)
        linear_state_dict["projection.weight"] = linear_state_dict.pop("linear.weight")
        state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
        state_dict.update(linear_state_dict)
        model = cls(config=config)
        model.load_state_dict(state_dict)
        return model
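The commented-out block in compute_similarity outlines XTR's simulated in-batch token retrieval: for each query token, only the top-k similarities over all in-batch document tokens survive, and everything below that cut-off is filled with zero or with the cut-off value. The following standalone sketch illustrates that idea on a plain similarity tensor; the tensor layout and the helper name are assumptions for illustration, not the committed implementation:

import torch


def simulated_token_retrieval(similarity: torch.Tensor, token_retrieval_k: int, fill_strategy: str = "zero") -> torch.Tensor:
    # Assumed layout: (num_queries, num_docs, query_len, doc_len) token-level similarities.
    num_queries, num_docs, query_len, doc_len = similarity.shape
    # Pool all in-batch document tokens per query token and take the k-th largest similarity as the cut-off.
    flat = similarity.permute(0, 2, 1, 3).reshape(num_queries, query_len, num_docs * doc_len)
    k = min(token_retrieval_k, flat.shape[-1])
    cut_off = flat.topk(k, dim=-1).values[..., -1:].unsqueeze(1)  # broadcasts over the document axis
    below = similarity < cut_off
    filled = similarity.clone()
    if fill_strategy == "min":
        # Fill pruned entries with the cut-off similarity itself.
        filled[below] = cut_off.expand_as(similarity)[below]
    else:
        # Default "zero" strategy: pruned entries contribute nothing to the score.
        filled[below] = 0.0
    return filled

A ColBERT-style aggregation would then take the maximum over document tokens and, per the query_aggregation_function set above, sum over query tokens.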
@@ -0,0 +1,56 @@
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, T5EncoderModel

from lightning_ir.bi_encoder.tokenizer import BiEncoderTokenizer
from lightning_ir.models.xtr.model import XTRModel


class TestXTRModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained("google/xtr-base-en")
        self.linear = torch.nn.Linear(self.model.config.hidden_size, 128, bias=False)
        linear_layer_path = hf_hub_download("google/xtr-base-en", filename="2_Dense/pytorch_model.bin")
        state_dict = torch.load(linear_layer_path)
        state_dict["weight"] = state_dict.pop("linear.weight")
        self.linear.load_state_dict(state_dict)

    def forward(self, **kwargs):
        encoded = self.model(**kwargs).last_hidden_state
        embeddings = self.linear(encoded)
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return embeddings


def test_same_as_xtr():
    model_name = "google/xtr-base-en"
    orig_model = TestXTRModel().eval()
    orig_tokenizer = AutoTokenizer.from_pretrained(model_name)

    query = "What is the capital of France?"
    documents = [
        "Paris is the capital of France.",
        "France is a country in Europe.",
        "The Eiffel Tower is in Paris.",
    ]
    orig_query_encoding = orig_tokenizer(query, return_tensors="pt")
    orig_doc_encoding = orig_tokenizer(documents, return_tensors="pt", padding=True, truncation=True)

    model = XTRModel.from_pretrained(model_name).eval()
    tokenizer = BiEncoderTokenizer.from_pretrained(model_name, **model.config.to_dict())
    query_encoding = tokenizer.tokenize_query(query, return_tensors="pt")
    doc_encoding = tokenizer.tokenize_doc(documents, return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        query_embedding = orig_model(**orig_query_encoding)
        doc_embedding = orig_model(**orig_doc_encoding)
        output = model.forward(query_encoding=query_encoding, doc_encoding=doc_encoding)

    assert torch.allclose(query_embedding, output.query_embeddings.embeddings, atol=1e-6)
    assert torch.allclose(
        doc_embedding[orig_doc_encoding.attention_mask.bool()],
        output.doc_embeddings.embeddings[doc_encoding.attention_mask.bool()],
        atol=1e-6,
    )
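For context, once the embeddings agree, relevance scores follow from the late-interaction setup configured in from_xtr_checkpoint (dot-product similarity, maximum over document tokens, sum over query tokens). A minimal sketch on plain tensors, not the library's ScoringFunction API:

import torch


def late_interaction_score(query_emb: torch.Tensor, doc_emb: torch.Tensor, doc_mask: torch.Tensor) -> torch.Tensor:
    # query_emb: (query_len, dim), doc_emb: (doc_len, dim), doc_mask: (doc_len,) bool padding mask.
    sim = query_emb @ doc_emb.T                      # token-level dot products ("dot" similarity)
    sim = sim.masked_fill(~doc_mask, float("-inf"))  # ignore padding tokens
    return sim.max(dim=-1).values.sum()              # max over doc tokens, sum over query tokens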