Skip to content

Commit

Permalink
fix for biencoder embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Dec 9, 2024
1 parent 91eadd3 commit 56f2f08
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def test_in_batch_loss_func(loss_func: InBatchLossFunction, output: LightningIRO
def test_regularization_loss_func(loss_func: RegularizationLossFunction, embeddings: torch.Tensor):
loss = loss_func.compute_loss(
BiEncoderOutput(
None, BiEncoderEmbedding(embeddings, torch.empty(0)), BiEncoderEmbedding(embeddings, torch.empty(0))
None,
BiEncoderEmbedding(embeddings, torch.empty(0), None),
BiEncoderEmbedding(embeddings, torch.empty(0), None),
)
)
assert loss >= 0
Expand Down

0 comments on commit 56f2f08

Please sign in to comment.