Skip to content

Commit

Permalink
Add bucketing to DeepSparseSentenceTransformer (#1334)
Browse files Browse the repository at this point in the history
* Support for SentenceTransformer with `deepsparse.sentence_transformers.SentenceTransformer`

* Format

* Update

* Address comments

* Add bucketing to SentenceTransformer

* Actually add bucketing

* Missed a spot

* Add benchmarking script

* Cleanup

* Add colab

* Cleanup tokenization

* Add alias for DeepSparseSentenceTransformer

* Update benchmark_encoding.py

* Format
  • Loading branch information
mgoin authored Oct 26, 2023
1 parent e699c8f commit fd447cf
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 12 deletions.
14 changes: 14 additions & 0 deletions src/deepsparse/sentence_transformers/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

# DeepSparse SentenceTransformers

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1sfN8zDK7MIyatiSIbt2xWh0i6GnaBnTR?usp=sharing)

```python
from deepsparse.sentence_transformers import SentenceTransformer
```
Expand Down Expand Up @@ -38,6 +40,18 @@ for sentence, embedding in zip(sentences, embeddings):
print("")
```

## Benchmarking Performance

There is a `benchmark_encoding.py` script located in this directory that compares a standard model running in both SentenceTransformers and DeepSparse, with a sparsified model in DeepSparse. Here is an example run on an 8 core SPR CPU with the base model being `BAAI/bge-small-en-v1.5`:
```bash
python benchmark_encoding.py --base_model BAAI/bge-small-en-v1.5 --sparse_model zeroshot/bge-small-en-v1.5-quant

[Standard SentenceTransformer] Encoded 100 sentences of length 700 in 10.42 seconds.
[DeepSparse] Encoded 100 sentences of length 700 in 4.04 seconds.
[DeepSparse Optimized] Encoded 100 sentences of length 700 in 1.82 seconds.
```


## Accuracy Validation with MTEB

DeepSparse's efficiency doesn't compromise its accuracy, thanks to testing with the Multilingual Text Embedding Benchmark (MTEB). This process validates the model's performance against standard tasks, ensuring its reliability.
Expand Down
2 changes: 1 addition & 1 deletion src/deepsparse/sentence_transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@
)


from .sentence_transformer import SentenceTransformer
from .sentence_transformer import DeepSparseSentenceTransformer, SentenceTransformer
105 changes: 105 additions & 0 deletions src/deepsparse/sentence_transformers/benchmark_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import random
import string
import time

import sentence_transformers
from deepsparse.sentence_transformers import DeepSparseSentenceTransformer


def generate_random_sentence(length=700):
    """Return a random "sentence" of exactly `length` characters.

    Characters are drawn (with replacement) from ASCII letters, digits,
    punctuation, and whitespace, so the result exercises the tokenizer
    with realistic mixed content.
    """
    char_pool = (
        string.ascii_letters + string.digits + string.punctuation + string.whitespace
    )
    return "".join(random.choices(char_pool, k=length))


def benchmark_model(model, sentences):
    """Return the wall-clock seconds `model.encode(sentences)` takes.

    The encoded output is discarded; only the elapsed time is reported.
    """
    started = time.time()
    model.encode(sentences)
    return time.time() - started


def main(args):
    """Run the three-way encoding benchmark described by the parsed CLI args.

    Compares a standard sentence-transformers model, the same model run
    through DeepSparse, and a sparsified model run through DeepSparse,
    printing the elapsed encode time for each.
    """
    # Generate a list of random sentences
    sentences = [
        generate_random_sentence(args.length) for _ in range(args.num_sentences)
    ]

    # Load the models
    standard_model = sentence_transformers.SentenceTransformer(args.base_model)
    # export=True converts the HF model to ONNX for DeepSparse execution
    deepsparse_model = DeepSparseSentenceTransformer(args.base_model, export=True)
    deepsparse_opt_model = DeepSparseSentenceTransformer(args.sparse_model)

    # Benchmark the standard sentence_transformers model
    standard_time = benchmark_model(standard_model, sentences)
    print(
        f"[Standard SentenceTransformer] Encoded {args.num_sentences} sentences "
        f"of length {args.length} in {standard_time:.2f} seconds."
    )

    # Benchmark the dense model under deepsparse.sentence_transformers
    deepsparse_time = benchmark_model(deepsparse_model, sentences)
    print(
        f"[DeepSparse] Encoded {args.num_sentences} sentences of length "
        f"{args.length} in {deepsparse_time:.2f} seconds."
    )

    # Benchmark the sparsified model under deepsparse.sentence_transformers
    deepsparse_opt_time = benchmark_model(deepsparse_opt_model, sentences)
    print(
        # Fixed missing space after "]" so output matches the README example
        f"[DeepSparse Optimized] Encoded {args.num_sentences} sentences of length "
        f"{args.length} in {deepsparse_opt_time:.2f} seconds."
    )


if __name__ == "__main__":
    # CLI entry point: collect benchmark options, then hand off to main().
    arg_parser = argparse.ArgumentParser(
        description="Benchmark Sentence Transformer Models."
    )
    # Dense baseline model evaluated with both backends.
    arg_parser.add_argument(
        "--base_model",
        type=str,
        default="BAAI/bge-small-en-v1.5",
        help="Name of the standard model.",
    )
    # Sparsified/quantized counterpart evaluated with DeepSparse only.
    arg_parser.add_argument(
        "--sparse_model",
        type=str,
        default="zeroshot/bge-small-en-v1.5-quant",
        help="Name of the sparse model.",
    )
    arg_parser.add_argument(
        "--num_sentences",
        type=int,
        default=100,
        help="Number of sentences to generate.",
    )
    arg_parser.add_argument(
        "--length",
        type=int,
        default=700,
        help="Length of each sentence.",
    )

    main(arg_parser.parse_args())
101 changes: 90 additions & 11 deletions src/deepsparse/sentence_transformers/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from tqdm.autonotebook import trange
Expand All @@ -28,7 +28,7 @@
DEFAULT_MODEL_NAME = "zeroshot/bge-small-en-v1.5-quant"


class SentenceTransformer:
class DeepSparseSentenceTransformer:
"""
Loads or creates a SentenceTransformer-compatible model that can be used to map
text to embeddings.
Expand All @@ -42,6 +42,8 @@ class SentenceTransformer:
this should be set to 512 for most models. Any text that exceeds this
token length will be truncated.
:param use_auth_token: HuggingFace authentication token to download private models.
:param buckets: Create static buckets less than max_seq_length automatically if True,
manually specified if a List of lengths is passed in, or fully dynamic if False
"""

def __init__(
Expand All @@ -50,16 +52,51 @@ def __init__(
export: bool = False,
max_seq_length: int = 512,
use_auth_token: Union[bool, str, None] = None,
buckets: Union[bool, List[int]] = True,
):

self.model_name_or_path = model_name_or_path
self.model = DeepSparseModelForFeatureExtraction.from_pretrained(
model_name_or_path, export=export, use_auth_token=use_auth_token
)
self.model.compile(batch_size=0)
self.tokenizer = get_preprocessor(model_name_or_path)

self._max_seq_length = max_seq_length
# TODO: support faster bulk execution with batch size > 1
self._static_batch_size = 1

self.dyn_model = DeepSparseModelForFeatureExtraction.from_pretrained(
model_name_or_path,
export=export,
use_auth_token=use_auth_token,
)
self.dyn_model.reshape(input_shapes="[0,0]")
self.dyn_model.compile(batch_size=0)

if buckets:
# Initialize a model for each bucket
self.buckets = [int(self._max_seq_length / 4 * i) for i in range(1, 5)]
self.models = {}
for bucket in self.buckets:
self.models[
bucket
] = DeepSparseModelForFeatureExtraction.from_pretrained(
model_name_or_path,
export=export,
use_auth_token=use_auth_token,
)
self.models[bucket].reshape(
input_shapes=f"[{self._static_batch_size},{bucket}]"
)
self.models[bucket].compile(batch_size=self._static_batch_size)
else:
self.buckets = None
self.models = None

def _select_bucket(self, seq_length: int) -> int:
    """
    Selects the appropriate model based on the input sequence length.

    Returns the smallest bucket length that can hold `seq_length` tokens,
    so the input can be padded to that static shape and run on the
    pre-compiled model for that bucket.

    NOTE(review): assumes self.buckets is sorted ascending — appears true
    for buckets built from range(1, 5) in __init__, but confirm if buckets
    are ever user-supplied.

    :param seq_length: token length of the input sequence
    :return: the chosen bucket size, or the maximum sequence length if the
        input is longer than every bucket
    """
    for bucket in self.buckets:
        if seq_length <= bucket:
            return bucket
    # default to the maximum if seq_length exceeds all buckets
    return self._max_seq_length

def encode(
self,
Expand Down Expand Up @@ -124,7 +161,25 @@ def encode(
sentences_batch = sentences_sorted[start_index : start_index + batch_size]

model_inputs = self.tokenize(sentences_batch)
model_output = self.model(**model_inputs)

if self.buckets and batch_size == 1:
# Use bucketing for batch size 1
# Select the model based on the bucketing logic
# TODO: tokenize ahead of time and simply add padding
seq_length = len(model_inputs[0])
selected_bucket = self._select_bucket(seq_length)

# Tokenize using the selected bucket size
model_inputs = self.tokenize(
sentences_batch, target_length=selected_bucket
)
model = self.models[selected_bucket]
else:
# Use dynamic shape
model = self.dyn_model

# Run the inference
model_output = model(**model_inputs)

out_features = {}
out_features["sentence_embedding"] = self.mean_pooling(
Expand Down Expand Up @@ -189,11 +244,31 @@ def _text_length(self, text: Union[List[int], List[List[int]]]) -> int:
else:
return sum([len(t) for t in text]) # Sum of length of individual strings

def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
def tokenize(
self,
texts: Union[List[str], List[Dict], List[Tuple[str, str]]],
target_length: Optional[int] = None,
) -> List[torch.Tensor]:
"""
Tokenizes the texts
"""
return self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
if target_length:
# Make sure to pad the tokens to the specified length
return self.tokenizer(
texts,
max_length=target_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
else:
# No padding needed
return self.tokenizer(
texts,
truncation=True,
max_length=self._max_seq_length,
return_tensors="pt",
)

def mean_pooling(
self, model_output: torch.Tensor, attention_mask: torch.Tensor
Expand All @@ -214,3 +289,7 @@ def mean_pooling(
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)


# for backwards compatibility
SentenceTransformer = DeepSparseSentenceTransformer

0 comments on commit fd447cf

Please sign in to comment.