From 5aa74f878cc0d8d7bbc623a3ced681dcb31955ec Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Sun, 15 Sep 2024 22:15:50 +0200 Subject: [PATCH] correct device for pos_emb (#60) --- i6_models/parts/conformer/mhsa_rel_pos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py index 12e174c1..c2cc6552 100644 --- a/i6_models/parts/conformer/mhsa_rel_pos.py +++ b/i6_models/parts/conformer/mhsa_rel_pos.py @@ -252,7 +252,7 @@ def _sinusoidal_pe(pos_seq: torch.Tensor, embed_dim: int): sinusoid_input = torch.outer(pos_seq, inv_freq) - pos_emb = torch.zeros(pos_seq.shape[0], embed_dim) + pos_emb = torch.zeros(pos_seq.shape[0], embed_dim, device=pos_seq.device) pos_emb[:, 0::2] = sinusoid_input.sin() pos_emb[:, 1::2] = sinusoid_input.cos()