From bbc43f69d56d2651f8715e242ed91bf3c6dcb08f Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Fri, 5 Jan 2024 20:49:39 -0500
Subject: [PATCH] Update test_matmul.py

---
 test/ops/test_matmul.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/test/ops/test_matmul.py b/test/ops/test_matmul.py
index b9001d3fc..42a0ab91c 100644
--- a/test/ops/test_matmul.py
+++ b/test/ops/test_matmul.py
@@ -39,7 +39,8 @@ def test_segment_matmul_autograd(dtype, device):
 
 @withCUDA
 @pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
-def test_grouped_matmul_autograd(dtype, device):
+@pytest.mark.parametrize('transposed', [True, False])
+def test_grouped_matmul_autograd(dtype, transposed, device):
     if device.type == 'cuda' and dtype == torch.bfloat16:
         pytest.skip('CUDA does not support bfloat16')
 
@@ -48,11 +49,19 @@ def test_grouped_matmul_autograd(dtype, device):
         torch.randn(6, 9, device=device, requires_grad=True),
         torch.randn(3, 32, device=device, requires_grad=True),
     ]
-    others = [
-        torch.randn(16, 48, device=device, requires_grad=True),
-        torch.randn(9, 42, device=device, requires_grad=True),
-        torch.randn(32, 64, device=device, requires_grad=True),
-    ]
+    if transposed:
+        others_origin = [
+            torch.randn(48, 16, device=device, requires_grad=True),
+            torch.randn(42, 9, device=device, requires_grad=True),
+            torch.randn(64, 32, device=device, requires_grad=True),
+        ]
+        others = [other.t() for other in others_origin]
+    else:
+        others = [
+            torch.randn(16, 48, device=device, requires_grad=True),
+            torch.randn(9, 42, device=device, requires_grad=True),
+            torch.randn(32, 64, device=device, requires_grad=True),
+        ]
 
     biases = [
         torch.randn(48, device=device, requires_grad=True),
@@ -70,4 +79,7 @@ def test_grouped_matmul_autograd(dtype, device):
     sum([out.sum() for out in outs]).backward()
 
     for i in range(len(outs)):
-        assert others[i].grad.size() == others[i].size()
+        if transposed:
+            assert others_origin[i].grad.size() == others_origin[i].size()
+        else:
+            assert others[i].grad.size() == others[i].size()