diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 5e952ea4f4d0d..7908e1db701a2 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1117,7 +1117,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ // Control the number of threadblocks by adjusting the maximum number of // threads per multi-processor. These numbers better reflect the maximum // theoretical achievable threads per MP for the reduction operation. - if (iter.ndim() == 1) + if (iter.ndim() == 1 || iter.ndim() == 3) max_threads_per_mp = 512; if (iter.ndim() == 2) max_threads_per_mp = 256;