Skip to content

Commit

Permalink
Fix issue with mul_back that not check input size.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 5, 2024
1 parent de59d3e commit c8d8187
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions lib/nnc/cmd/blas/mps/ccv_nnc_mul_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ static int _ccv_nnc_mul_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,

const float p = cmd.info.blas.a[0];
const ccv_nnc_tensor_view_t* const g = (const ccv_nnc_tensor_view_t*)inputs[0] ? : 0;
const ccv_nnc_tensor_view_t* const b1 = (const ccv_nnc_tensor_view_t*)inputs[2];
ccv_nnc_tensor_view_t* const b2 = (ccv_nnc_tensor_view_t*)inputs[1];
const ccv_nnc_tensor_view_t* const b1 = (input_size >= 3) ? (const ccv_nnc_tensor_view_t*)inputs[2] : 0;
ccv_nnc_tensor_view_t* const b2 = (input_size >= 2) ? (ccv_nnc_tensor_view_t*)inputs[1] : 0;

ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)outputs[0];
ccv_nnc_tensor_view_t* const h = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
const int b2_nd = ccv_nnc_tensor_nd(b1->info.dim);
const int b1_nd = ccv_nnc_tensor_nd(b2->info.dim);
const int b2_nd = b2 ? ccv_nnc_tensor_nd(b2->info.dim) : 0;
const int b1_nd = b1 ? ccv_nnc_tensor_nd(b1->info.dim) : 0;
const int g_nd = ccv_max(b2_nd, b1_nd);
const int offset = CCV_NNC_MAX_DIM + 2 - g_nd;
if (a)
Expand All @@ -131,9 +131,10 @@ static int _ccv_nnc_mul_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,

@autoreleasepool {
NSMutableArray<NSNumber*>* mps_g_shape = [[NSMutableArray new] autorelease];
for (int i = offset; i < CCV_NNC_MAX_DIM + 2; i++){
for (int i = offset; i < CCV_NNC_MAX_DIM + 2; i++)
{
[mps_g_shape addObject:@(gdim[i])]; // still need mps_g_shape for target broadcast shape
gdim[i-offset] = gdim[i]; // move forward to align info.dim format
gdim[i - offset] = gdim[i]; // move forward to align info.dim format
}
const int* gdim_a = gdim;

Expand Down

0 comments on commit c8d8187

Please sign in to comment.