Skip to content

Commit

Permalink
Merge pull request #4 from arrmansa/main
Browse files Browse the repository at this point in the history
Fixed dimension of normalized embeddings
  • Loading branch information
davidberenstein1957 authored Nov 14, 2022
2 parents aa90a62 + fdd0b97 commit aaff8fd
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion fast_sentence_transformers/FastSentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def encode(
embeddings = embeddings.cpu().detach().numpy()

if normalize_embeddings:
embeddings = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=False)
norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
embeddings = embeddings/np.where(norms<1e-12, 1e-12, norms)

all_embeddings.extend(embeddings)

Expand Down

0 comments on commit aaff8fd

Please sign in to comment.