Skip to content

Commit

Permalink
Fix shape auto-inference for matmul.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 29, 2023
1 parent 601bb6f commit c9d2a6c
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions lib/nnc/cmd/blas/ccv_nnc_blas.c
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ static void _ccv_nnc_gemm_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const
outputs[0].dim[nd - 2] = b_rows;
outputs[0].dim[nd - 1] = b_cols;
int i;
if (a_nd > w_nd)
for (i = 0; i < nd - 3; i++)
outputs[0].dim[i] = inputs[0].dim[i];
else
for (i = 0; i < nd - 3; i++)
outputs[0].dim[i] = inputs[1].dim[i];
for (i = 0; i < nd - 3; i++)
{
const int a_idx = a_nd - nd + i;
const int w_idx = w_nd - nd + i;
outputs[0].dim[i] = ccv_max(a_idx >= 0 ? inputs[0].dim[a_idx] : 1, w_idx >= 0 ? inputs[1].dim[w_idx] : 1);
}
}
}

Expand Down

0 comments on commit c9d2a6c

Please sign in to comment.