diff --git a/fast_sentence_transformers/FastSentenceTransformer.py b/fast_sentence_transformers/FastSentenceTransformer.py index fce9559..f9194a7 100644 --- a/fast_sentence_transformers/FastSentenceTransformer.py +++ b/fast_sentence_transformers/FastSentenceTransformer.py @@ -289,7 +289,8 @@ def encode( embeddings = embeddings.cpu().detach().numpy() if normalize_embeddings: - embeddings = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=False) + norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) + embeddings = embeddings/np.where(norms<1e-12, 1e-12, norms) all_embeddings.extend(embeddings)