diff --git a/tests/test_unsloth_executor.py b/tests/test_unsloth_executor.py index 15b1c7c673..b62eac7214 100644 --- a/tests/test_unsloth_executor.py +++ b/tests/test_unsloth_executor.py @@ -54,6 +54,8 @@ def test_unsloth_rope(): B, nh, T, hs = 2, 32, 64, 16 cos, sin = build_rope_cache(T, hs, device="cuda") + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) q = torch.rand((B, nh, T, hs), device="cuda", requires_grad=True) def foo(x, cos, sin):