diff --git a/tests/test_enc_dec_att.py b/tests/test_enc_dec_att.py
index 4ca16128..7ea35fe9 100644
--- a/tests/test_enc_dec_att.py
+++ b/tests/test_enc_dec_att.py
@@ -16,7 +16,7 @@ def test_additive_attention():
 
     # pass key as weight feedback just for testing
     context, weights = att(
-        key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu"
+        key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len
     )
     assert context.shape == (10, 5)
     assert weights.shape == (10, 20, 1)
@@ -42,7 +42,6 @@ def test_encoder_decoder_attention_model():
         output_dropout=0.1,
         zoneout_drop_c=0.0,
         zoneout_drop_h=0.0,
-        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]
@@ -69,7 +68,6 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
         output_dropout=0.1,
         zoneout_drop_c=zoneout_drop_c,
         zoneout_drop_h=zoneout_drop_h,
-        device="cpu",
    )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)
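
For context, a minimal self-contained sketch of an additive attention module consistent with the call signature and shape assertions exercised in test_additive_attention above. The class name AdditiveAttentionSketch and its internals are assumptions for illustration, not the repo's actual implementation; the point mirrored here is that tensor placement follows the inputs, so no explicit device argument is needed:

import torch
from torch import nn


class AdditiveAttentionSketch(nn.Module):
    """Hypothetical additive (MLP) attention; shapes match the asserts above."""

    def __init__(self, attention_dim: int, key_dim: int, query_dim: int):
        super().__init__()
        self.w_query = nn.Linear(query_dim, attention_dim, bias=False)
        self.w_feedback = nn.Linear(key_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, key, value, query, weight_feedback, enc_seq_len):
        # key/value/weight_feedback: [B, T, D], query: [B, D], enc_seq_len: [B]
        energies = self.v(
            torch.tanh(key + self.w_query(query).unsqueeze(1) + self.w_feedback(weight_feedback))
        )  # [B, T, 1]
        # mask positions beyond each sequence length; the mask tensor is created
        # on the device of the inputs, so no explicit device kwarg is required
        time_idx = torch.arange(key.shape[1], device=key.device)[None, :]  # [1, T]
        mask = time_idx >= enc_seq_len[:, None]  # [B, T]
        energies = energies.masked_fill(mask.unsqueeze(-1), float("-inf"))
        weights = torch.softmax(energies, dim=1)  # [B, T, 1]
        context = (weights * value).sum(dim=1)  # [B, D]
        return context, weights


att = AdditiveAttentionSketch(attention_dim=5, key_dim=5, query_dim=5)
key = torch.rand((10, 20, 5))
value = torch.rand((10, 20, 5))
query = torch.rand((10, 5))
enc_seq_len = torch.full((10,), 20, dtype=torch.int64)
context, weights = att(
    key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len
)
assert context.shape == (10, 5)
assert weights.shape == (10, 20, 1)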