From cc3c29a81a140f7b97045718fb88eb0664c37bd7 Mon Sep 17 00:00:00 2001 From: Yujia Zhai Date: Wed, 9 Oct 2024 12:33:27 -0700 Subject: [PATCH] CUTLASS 3.6.0 (#1850) * v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai Co-authored-by: Haicheng Wu --- CHANGELOG.md | 20 + CMakeLists.txt | 66 +- PUBLICATIONS.md | 2 + README.md | 70 +- cmake/CTestTestfile.configure.cmake | 2 - cmake/CTestTestfile.test.configure.cmake | 8 +- cmake/googletest.cmake | 3 +- examples/35_gemm_softmax/gemm_softmax.cu | 8 +- .../48_hopper_warp_specialized_gemm.cu | 12 +- .../gather_gemm.hpp | 2 +- .../53_hopper_gemm_permute/permute_traits.hpp | 21 +- .../55_hopper_int4_fp8_gemm.cu | 701 + .../55_hopper_mixed_dtype_gemm.cu | 154 +- .../55_hopper_mixed_dtype_gemm/CMakeLists.txt | 13 +- examples/55_hopper_mixed_dtype_gemm/README.md | 2 + .../packed_scale.hpp | 132 + .../56_hopper_ptr_array_batched_gemm.cu | 3 +- .../57_hopper_grouped_gemm.cu | 6 +- .../CMakeLists.txt | 4 + .../61_hopper_gemm_with_topk_and_softmax.cu | 534 + .../CMakeLists.txt | 32 + .../62_hopper_sparse_gemm.cu | 596 + examples/62_hopper_sparse_gemm/CMakeLists.txt | 36 + .../63_hopper_gemm_with_weight_prefetch.cu | 500 + .../CMakeLists.txt | 36 + .../README.md | 82 + .../collective/builder.hpp | 215 + .../collective/dispatch_policy_extra.hpp | 61 + ..._gmma_ss_warpspecialized_with_prefetch.hpp | 867 + .../gemm_with_weight_prefetch_commandline.hpp | 117 + ...gemm_tma_warpspecialized_with_prefetch.hpp | 561 + .../pipeline/prefetch_pipeline_sm90.hpp | 161 + examples/CMakeLists.txt | 3 + examples/cute/tutorial/tiled_copy.cu | 4 +- include/cute/algorithm/clear.hpp | 6 +- include/cute/algorithm/cooperative_copy.hpp | 12 +- include/cute/algorithm/cooperative_gemm.hpp | 4 +- include/cute/algorithm/copy.hpp | 12 +- include/cute/algorithm/functional.hpp | 7 +- include/cute/algorithm/prefetch.hpp | 8 +- include/cute/algorithm/tuple_algorithms.hpp | 130 +- include/cute/arch/cluster_sm90.hpp | 4 +- include/cute/arch/config.hpp | 50 + include/cute/arch/copy_sm50.hpp | 30 +- include/cute/arch/copy_sm90.hpp | 15 +- include/cute/arch/copy_sm90_desc.hpp | 53 +- include/cute/arch/copy_sm90_tma.hpp | 95 +- include/cute/arch/mma.hpp | 6 +- include/cute/arch/mma_sm90.hpp | 3825 +- include/cute/arch/mma_sm90_desc.hpp | 7 +- include/cute/arch/mma_sm90_gmma.hpp | 3420 +- include/cute/arch/mma_sm90_gmma_sparse.hpp | 53789 ++++++++++++++++ include/cute/arch/util.hpp | 23 +- include/cute/atom/copy_atom.hpp | 95 +- include/cute/atom/copy_traits_sm50.hpp | 19 +- include/cute/atom/copy_traits_sm90_im2col.hpp | 12 +- include/cute/atom/copy_traits_sm90_tma.hpp | 29 +- .../atom/copy_traits_sm90_tma_swizzle.hpp | 22 + include/cute/atom/mma_atom.hpp | 215 +- include/cute/atom/mma_traits.hpp | 113 +- include/cute/atom/mma_traits_sm90.hpp | 18 +- include/cute/atom/mma_traits_sm90_gmma.hpp | 4320 +- .../cute/atom/mma_traits_sm90_gmma_sparse.hpp | 16915 +++++ include/cute/config.hpp | 13 - include/cute/container/alignment.hpp | 20 +- include/cute/container/array_aligned.hpp | 4 +- include/cute/container/array_subbyte.hpp | 14 + include/cute/container/bit_field.hpp | 4 +- include/cute/container/cuda_types.hpp | 8 +- include/cute/container/tuple.hpp | 13 +- include/cute/container/type_list.hpp | 3 +- include/cute/int_tuple.hpp | 166 +- include/cute/layout.hpp | 160 +- include/cute/layout_composed.hpp | 6 +- include/cute/numeric/arithmetic_tuple.hpp | 12 +- include/cute/numeric/complex.hpp | 6 +- include/cute/numeric/int.hpp | 14 +- include/cute/numeric/integral_constant.hpp | 25 +- include/cute/numeric/integral_ratio.hpp | 9 +- include/cute/numeric/math.hpp | 18 +- include/cute/numeric/numeric_types.hpp | 72 +- include/cute/numeric/real.hpp | 18 + include/cute/pointer.hpp | 29 +- include/cute/pointer_base.hpp | 7 +- include/cute/pointer_flagged.hpp | 63 +- include/cute/pointer_sparse.hpp | 172 + include/cute/pointer_swizzle.hpp | 20 +- include/cute/stride.hpp | 131 +- include/cute/swizzle.hpp | 21 +- include/cute/swizzle_layout.hpp | 42 +- include/cute/tensor.hpp | 3 + include/cute/tensor_impl.hpp | 82 +- include/cute/tensor_predicate.hpp | 5 +- include/cute/tensor_zip.hpp | 243 + include/cute/underscore.hpp | 9 +- include/cute/util/print.hpp | 34 +- include/cute/util/type_traits.hpp | 11 +- include/cutlass/arch/barrier.h | 47 + include/cutlass/arch/config.h | 81 + .../cutlass/arch/grid_dependency_control.h | 84 + include/cutlass/arch/memory_sm80.h | 9 + include/cutlass/arch/mma_sm90.h | 25 +- include/cutlass/arch/reg_reconfig.h | 6 +- include/cutlass/arch/synclog.hpp | 1324 + include/cutlass/array.h | 215 +- include/cutlass/bfloat16.h | 155 +- include/cutlass/cluster_launch.hpp | 1 + ..._implicit_gemm_gmma_ss_warpspecialized.hpp | 184 +- include/cutlass/conv/convnd_problem_shape.hpp | 78 +- include/cutlass/conv/detail.hpp | 137 + .../conv/device/conv_universal_adapter.hpp | 17 +- .../cutlass/conv/device/direct_convolution.h | 2 + .../conv/device/implicit_gemm_convolution.h | 19 +- .../device/implicit_gemm_convolution_fusion.h | 1 + include/cutlass/conv/dispatch_policy.hpp | 6 +- .../cutlass/conv/kernel/conv_universal.hpp | 2 + ...sm90_implicit_gemm_tma_warpspecialized.hpp | 369 +- include/cutlass/cuda_host_adapter.hpp | 2 + include/cutlass/cutlass.h | 1 + include/cutlass/detail/collective.hpp | 1 - include/cutlass/detail/layout.hpp | 21 +- include/cutlass/detail/mma.hpp | 5 + include/cutlass/device_kernel.h | 12 + .../collective/builders/sm90_builder.inl | 36 +- .../collective/collective_builder.hpp | 5 +- .../collective/collective_epilogue.hpp | 8 + .../cutlass/epilogue/collective/detail.hpp | 11 +- .../collective/sm70_epilogue_vectorized.hpp | 366 +- .../sm70_epilogue_vectorized_array.hpp | 412 + ...m90_epilogue_array_tma_warpspecialized.hpp | 86 +- .../sm90_epilogue_tma_warpspecialized.hpp | 59 +- ...e_tma_warpspecialized_bias_elementwise.hpp | 9 +- include/cutlass/epilogue/dispatch_policy.hpp | 3 +- .../cutlass/epilogue/fusion/operations.hpp | 66 +- .../sm90_callbacks_tma_warpspecialized.hpp | 343 +- ...90_visitor_compute_tma_warpspecialized.hpp | 31 +- .../sm90_visitor_load_tma_warpspecialized.hpp | 470 +- ...sm90_visitor_store_tma_warpspecialized.hpp | 428 +- .../sm90_visitor_tma_warpspecialized.hpp | 12 +- .../fusion/sm90_visitor_topk_softmax.hpp | 759 + include/cutlass/epilogue/thread/activation.h | 63 +- .../linear_combination_bias_elementwise.h | 119 +- .../threadblock/default_epilogue_tensor_op.h | 38 + include/cutlass/float8.h | 12 + include/cutlass/functional.h | 78 +- .../gemm/collective/builders/sm90_common.inl | 73 +- .../collective/builders/sm90_gmma_builder.inl | 67 +- .../builders/sm90_sparse_config.inl | 268 + .../builders/sm90_sparse_gmma_builder.inl | 388 + .../gemm/collective/collective_builder.hpp | 1 + .../gemm/collective/collective_mma.hpp | 1 + ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 5 +- ...mma_multistage_gmma_rs_warpspecialized.hpp | 2 +- ...mma_multistage_gmma_ss_warpspecialized.hpp | 2 +- .../sm90_mma_tma_gmma_rs_warpspecialized.hpp | 2 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 719 +- .../sm90_mma_tma_gmma_ss_warpspecialized.hpp | 2 +- ...90_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 2 +- ...sparse_mma_tma_gmma_ss_warpspecialized.hpp | 724 + include/cutlass/gemm/device/base_grouped.h | 1 + .../gemm/device/default_gemm_configuration.h | 90 +- include/cutlass/gemm/device/ell_gemm.h | 1 + include/cutlass/gemm/device/gemm.h | 1 + include/cutlass/gemm/device/gemm_array.h | 1 + include/cutlass/gemm/device/gemm_batched.h | 1 + include/cutlass/gemm/device/gemm_complex.h | 1 + include/cutlass/gemm/device/gemm_sparse.h | 1 + .../gemm/device/gemm_sparse_with_absmax.h | 1 + .../gemm/device/gemm_splitk_parallel.h | 1 + .../gemm/device/gemm_universal_adapter.h | 60 +- .../cutlass/gemm/device/gemm_universal_base.h | 2 + include/cutlass/gemm/device/gemv.h | 1 + include/cutlass/gemm/device/rank_2k.h | 1 + include/cutlass/gemm/device/rank_k.h | 1 + include/cutlass/gemm/device/symm.h | 1 + include/cutlass/gemm/device/trmm.h | 1 + include/cutlass/gemm/dispatch_policy.hpp | 30 +- include/cutlass/gemm/kernel/sm70_gemm.hpp | 2 + ..._array_tma_warpspecialized_cooperative.hpp | 55 +- ...emm_array_tma_warpspecialized_pingpong.hpp | 31 +- include/cutlass/gemm/kernel/sm90_gemm_tma.hpp | 2 + .../kernel/sm90_gemm_tma_warpspecialized.hpp | 162 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 104 +- ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 44 +- .../gemm/kernel/sm90_gemm_warpspecialized.hpp | 4 +- .../sm90_gemm_warpspecialized_cooperative.hpp | 17 +- .../sm90_gemm_warpspecialized_pingpong.hpp | 7 +- .../gemm/kernel/sm90_tile_scheduler.hpp | 18 +- .../gemm/kernel/sm90_tile_scheduler_group.hpp | 17 +- .../kernel/sm90_tile_scheduler_stream_k.hpp | 99 +- .../gemm/kernel/static_tile_scheduler.hpp | 112 +- .../cutlass/gemm/kernel/tile_scheduler.hpp | 40 +- .../gemm/kernel/tile_scheduler_params.h | 9 +- .../gemm/warp/default_mma_tensor_op_sm80.h | 35 +- .../gemm/warp/mma_mixed_input_tensor_op.h | 12 +- include/cutlass/gemm/warp/mma_tensor_op.h | 2 +- include/cutlass/kernel_launch.h | 68 + include/cutlass/numeric_conversion.h | 604 +- include/cutlass/pipeline/sm90_pipeline.hpp | 24 +- include/cutlass/platform/platform.h | 1 - .../cutlass/reduction/device/reduce_split_k.h | 27 +- .../device/tensor_reduce_affine_contiguous.h | 1 + .../device/tensor_reduce_affine_strided.h | 1 + include/cutlass/subbyte_reference.h | 6 +- include/cutlass/tensor_ref.h | 3 +- .../device/transform_universal_adapter.hpp | 167 +- .../kernel/filter_format_transformer.hpp | 28 +- .../kernel/sm90_sparse_gemm_compressor.hpp | 578 + .../kernel/sparse_gemm_compressor.hpp | 284 + include/cutlass/uint128.h | 7 +- include/cutlass/version.h | 4 +- media/docs/dependent_kernel_launch.md | 32 + media/docs/profiler.md | 89 +- media/docs/programming_guidelines.md | 7 +- media/docs/utilities.md | 48 + pyproject.toml | 2 +- python/cutlass/__init__.py | 2 +- python/cutlass/backend/epilogue.py | 6 +- .../cutlass/backend/evt/backend/sm90_nodes.py | 4 +- python/cutlass/emit/pytorch.py | 4 +- python/cutlass_library/conv3x_emitter.py | 3 + python/cutlass_library/gemm_operation.py | 20 +- python/cutlass_library/generator.py | 1683 +- python/cutlass_library/library.py | 2 + python/cutlass_library/manifest.py | 17 +- python/cutlass_library/sm90_shapes.py | 212 + python/cutlass_library/sm90_utils.py | 601 + python/setup_library.py | 2 +- python/setup_pycute.py | 2 +- test/self_contained_includes/CMakeLists.txt | 95 + test/unit/CMakeLists.txt | 2 +- test/unit/conv/cache_testbed_output.h | 2 +- test/unit/conv/device/conv2d_testbed.h | 14 +- test/unit/conv/device/conv3d_testbed.h | 8 +- .../conv/device_3x/conv_problem_sizes.hpp | 145 +- ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 18 +- ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 19 + ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 22 +- ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 17 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 17 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 16 + ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 18 +- ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 16 + ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 16 + test/unit/conv/device_3x/testbed_conv.hpp | 15 +- ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 17 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + test/unit/core/fast_numeric_conversion.cu | 3 + test/unit/core/functional.cu | 26 +- test/unit/core/numeric_conversion.cu | 126 + test/unit/cute/ampere/cooperative_copy.cu | 2 + test/unit/cute/ampere/cooperative_gemm.cu | 2 + test/unit/cute/ampere/tiled_cp_async.cu | 1 + test/unit/cute/core/CMakeLists.txt | 3 +- test/unit/cute/core/composition.cpp | 9 +- test/unit/cute/core/domain_distribute.cpp | 7 +- test/unit/cute/core/int_tuple.cpp | 178 +- test/unit/cute/core/inverse_left.cpp | 6 +- test/unit/cute/core/inverse_right.cpp | 5 +- test/unit/cute/core/math.cpp | 10 + test/unit/cute/core/swizzle_layout.cpp | 116 + test/unit/cute/hopper/cooperative_gemm.cu | 2 + test/unit/cute/layout/layout_operator.cu | 3 + test/unit/cute/volta/cooperative_gemm.cu | 2 + test/unit/gemm/device/CMakeLists.txt | 28 +- test/unit/gemm/device/gemm_testbed_3x.hpp | 805 +- test/unit/gemm/device/gemm_testbed_3x_evt.hpp | 1060 +- .../gemm/device/gemm_testbed_3x_ptr_array.hpp | 51 +- .../gemm_testbed_3x_tensor_broadcast.hpp | 8 +- ...8n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 16 +- ...6n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 16 +- test/unit/gemm/device/sm90_evt_operations.hpp | 248 +- ...er_warpspecialized_cooperative_aux_load.cu | 54 + ...r_warpspecialized_cooperative_aux_store.cu | 64 + ...ster_warpspecialized_cooperative_reduce.cu | 6 +- ...cluster_warpspecialized_pingpong_reduce.cu | 6 +- ...mm_f16_f16_f16_tensor_op_f32_group_gemm.cu | 126 + ...6_f16_tensor_op_f32_group_gemm_pingpong.cu | 65 +- ...16_f16_tensor_op_f32_ptr_array_pingpong.cu | 2 +- ..._rs_cluster_warpspecialized_cooperative.cu | 12 +- .../sm90_gemm_f32_f32_f32_tensor_op_f32.cu | 6 +- ..._rs_cluster_warpspecialized_cooperative.cu | 12 +- .../sm90_gemm_f8_f8_f8_tensor_op_fp32_evt.cu | 58 + .../device/sm90_gemm_stream_k_scheduler.cu | 2 +- ...0_sparse_gemm_f16_f16_f32_tensor_op_f32.cu | 255 + ...m90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu | 216 + ...m90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu | 216 + ...sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu | 216 + test/unit/gemm/device/testbed.h | 96 +- .../gemm/device/testbed_gemm_with_broadcast.h | 2 +- .../gemm/device/testbed_gemm_with_reduction.h | 2 +- test/unit/gemm/device/testbed_universal.h | 1 + test/unit/gemm/threadblock/mma_multistage.cu | 14 - test/unit/gemm/warp/gemm_sm80.cu | 209 - test/unit/transform/device/CMakeLists.txt | 58 + .../device/sm90_sparse_gemm_compressor_f16.cu | 95 + .../device/sm90_sparse_gemm_compressor_f32.cu | 95 + .../device/sm90_sparse_gemm_compressor_f8.cu | 95 + .../sm90_sparse_gemm_compressor_legacy.hpp | 480 + .../device/testbed_sparse_gemm_compressor.hpp | 876 + tools/library/CMakeLists.txt | 4 +- .../include/cutlass/library/arch_mappings.h | 6 + .../library/include/cutlass/library/library.h | 12 +- tools/library/src/conv_operation_3x.hpp | 4 +- tools/library/src/gemm_operation_3x.hpp | 12 +- .../src/reference/gemm_fp_mixed_input.cu | 4 +- tools/library/src/reference/gemm_fp_other.cu | 8 + .../src/reference/gemm_int_mixed_input.cu | 4 +- tools/library/src/reference/gemm_s8_s8_s32.cu | 146 + ...mm_int8_canonical.cu => gemm_u8_u8_s32.cu} | 86 +- .../initialize_reference_operations.cu | 7 +- .../library/src/sparse_gemm_operation_3x.hpp | 445 + tools/library/src/util.cu | 1 + .../include/cutlass/profiler/cublas_helpers.h | 122 +- .../cutlass/profiler/cutlass_profiler.h | 7 +- .../cutlass/profiler/device_allocation.h | 53 +- .../include/cutlass/profiler/device_context.h | 46 +- .../include/cutlass/profiler/options.h | 31 +- .../profiler/src/conv2d_operation_profiler.cu | 229 +- .../profiler/src/conv3d_operation_profiler.cu | 208 +- tools/profiler/src/cublas_helpers.cu | 275 +- tools/profiler/src/cutlass_profiler.cu | 13 - tools/profiler/src/device_allocation.cu | 408 +- tools/profiler/src/device_context.cu | 79 +- tools/profiler/src/gemm_operation_profiler.cu | 113 +- tools/profiler/src/operation_profiler.cu | 12 +- tools/profiler/src/options.cu | 180 +- tools/profiler/src/performance_report.cpp | 42 +- .../src/rank_2k_operation_profiler.cu | 114 +- .../profiler/src/rank_k_operation_profiler.cu | 113 +- .../src/sparse_gemm_operation_profiler.cu | 123 +- tools/profiler/src/symm_operation_profiler.cu | 127 +- tools/profiler/src/trmm_operation_profiler.cu | 128 +- .../util/include/cutlass/util/device_memory.h | 43 +- .../util/include/cutlass/util/distribution.h | 12 +- tools/util/include/cutlass/util/host_tensor.h | 11 +- .../include/cutlass/util/packed_stride.hpp | 2 + .../util/reference/device/tensor_fill.h | 127 +- .../cutlass/util/reference/host/conv.hpp | 3 +- .../cutlass/util/reference/host/gett.hpp | 6 +- .../cutlass/util/reference/host/tensor_fill.h | 243 +- 354 files changed, 105914 insertions(+), 8174 deletions(-) create mode 100644 examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu create mode 100644 examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp create mode 100644 examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu create mode 100644 examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt create mode 100644 examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu create mode 100644 examples/62_hopper_sparse_gemm/CMakeLists.txt create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/README.md create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp create mode 100644 include/cute/arch/config.hpp create mode 100644 include/cute/arch/mma_sm90_gmma_sparse.hpp create mode 100644 include/cute/atom/mma_traits_sm90_gmma_sparse.hpp create mode 100644 include/cute/pointer_sparse.hpp create mode 100644 include/cute/tensor_zip.hpp create mode 100644 include/cutlass/arch/config.h create mode 100644 include/cutlass/arch/grid_dependency_control.h create mode 100644 include/cutlass/arch/synclog.hpp create mode 100644 include/cutlass/conv/detail.hpp create mode 100644 include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp create mode 100644 include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp create mode 100644 include/cutlass/gemm/collective/builders/sm90_sparse_config.inl create mode 100644 include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl create mode 100644 include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp create mode 100644 include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp create mode 100644 include/cutlass/transform/kernel/sparse_gemm_compressor.hpp create mode 100644 media/docs/dependent_kernel_launch.md create mode 100644 python/cutlass_library/sm90_shapes.py create mode 100644 python/cutlass_library/sm90_utils.py create mode 100644 test/unit/cute/core/swizzle_layout.cpp create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu create mode 100644 test/unit/transform/device/CMakeLists.txt create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_f16.cu create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_f32.cu create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_f8.cu create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp create mode 100644 test/unit/transform/device/testbed_sparse_gemm_compressor.hpp create mode 100644 tools/library/src/reference/gemm_s8_s8_s32.cu rename tools/library/src/reference/{gemm_int8_canonical.cu => gemm_u8_u8_s32.cu} (65%) create mode 100644 tools/library/src/sparse_gemm_operation_3x.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index c784107be9..c98cdb515f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # NVIDIA CUTLASS Changelog +## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03) + +- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). + + [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu) + + [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu) + + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) +- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. +- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. +- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). +- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. +- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). +- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) +- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). + Various improvements and fixed from the community and CUTLASS team. Thanks to everyone who submitted PRs! + ## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25) - [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7419bdf5e5..e61b66a877 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -134,7 +134,6 @@ set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUT set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests") set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest") - set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.") if (CUTLASS_USE_PACKED_TUPLE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1) @@ -234,7 +233,6 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.") -set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API.") ################################################################################ # @@ -271,6 +269,7 @@ set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of opera set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.") set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.") set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).") +set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") ################################################################################ @@ -318,6 +317,8 @@ if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) endif() +set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace") + # # NOTE: running with asan and CUDA requires the following environment variable: # @@ -345,6 +346,10 @@ if(CUTLASS_NVCC_EMBED_PTX) list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all) endif() +if (CUTLASS_SKIP_REDUCTION_INIT) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SKIP_REDUCTION_INIT=1) +endif() + if (CUTLASS_ENABLE_TENSOR_CORE_MMA) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() @@ -354,6 +359,18 @@ if (CUTLASS_PROFILER_DISABLE_REFERENCE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_PROFILER_DISABLE_REFERENCE=1) endif() +if (CUTLASS_ENABLE_GDC_FOR_SM90) + message(STATUS "Grid Dependency Control (GDC) is enabled for SM90 kernels (required for programmatic dependent launches).") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1) +endif() + +set(CUTLASS_ENABLE_SYNCLOG OFF CACHE BOOL "Enable synchronization event logging for race condition debugging. WARNING: This redefines __syncthreads() and __syncwarp() in all downstream code!") + +if (CUTLASS_ENABLE_SYNCLOG) + set(CMAKE_CUDA_SEPARABLE_COMPILATION ON) + string(APPEND CMAKE_CXX_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1") + string(APPEND CMAKE_CUDA_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1") +endif() @@ -880,12 +897,27 @@ function(cutlass_add_executable_tests NAME TARGET) set(TEST_GROUP_NAME ${NAME}) + # To run the tests from an install package with tests enabled, we need to generate test files + # that don't rely on the current directory structure in build. + + set(TEST_NAME c${NAME}) + set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) + file(MAKE_DIRECTORY ${TEST_GEN_DIR}) + + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT ON) + configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) + + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format. + configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY) + foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS) if (CMD_COUNT GREATER 1) - string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TEST_NAME) + string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TESTCASE_NAME) else() - string(TOLOWER "${NAME}" TEST_NAME) + string(TOLOWER "${NAME}" TESTCASE_NAME) endif() # The following rigmarole is needed to deal with spaces and possible quotes in @@ -899,7 +931,7 @@ function(cutlass_add_executable_tests NAME TARGET) separate_arguments(TEST_COMMAND_OPTIONS) add_custom_target( - ${TEST_NAME} + ${TESTCASE_NAME} COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ ${TEST_COMMAND_OPTIONS} DEPENDS @@ -907,34 +939,20 @@ function(cutlass_add_executable_tests NAME TARGET) ) if (CMD_COUNT GREATER 1) - add_dependencies(${NAME} ${TEST_NAME}) + add_dependencies(${NAME} ${TESTCASE_NAME}) endif() foreach(DEPENDEE ${__DEPENDEES}) - add_dependencies(${DEPENDEE} ${TEST_NAME}) + add_dependencies(${DEPENDEE} ${TESTCASE_NAME}) endforeach() - set(TEST_NAME c${TEST_NAME}) + set(TESTCASE_NAME c${TESTCASE_NAME}) string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY) - string(APPEND _INLINE_PER_TEST_CODE "${_TEST_CODE}") + file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" "${_TEST_CODE}") + file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" "${_TEST_CODE}") endforeach() - # To run the tests from an install package with tests enabled, we need to generate test files - # that don't rely on the current directory structure in build. - - set(TEST_NAME c${NAME}) - set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) - file(MAKE_DIRECTORY ${TEST_GEN_DIR}) - - set(TEST_EXE_PATH $) - set(TEST_USE_EXTENDED_FORMAT ON) - configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) - - set(TEST_EXE_PATH $) - set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format. - configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY) - # The following line imports the tests for immediate run via `make test`. include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake) diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index 04d4cd0a14..b7425f251d 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -26,6 +26,8 @@ - ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023. +- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, Feburary 2023. + - ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023. ## 2022 diff --git a/README.md b/README.md index 1426e8a42e..e61335f240 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.5.1 +# CUTLASS 3.6.0 -_CUTLASS 3.5.1 - July 2024_ +_CUTLASS 3.6.0 - October 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -42,48 +42,26 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -# What's New in CUTLASS 3.5 - -CUTLASS 3.5.1 is an update to CUTLASS adding: - -- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu). -- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) -- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and -[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). -- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). -- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inference. -- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs. -- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py). -- Support for residual add (beta != 0) in convolution kernels. -- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. -- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). -- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). -- Better support for MSVC as a host compiler. -- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. -- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. -- NOTICE: - + Upcoming CUTLASS 3.6 release will include a breaking refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` API to bring it in line with `gemm::GemmUniversal`. After this, the 3.x convolution API will no longer be considered as a beta API. - + Upcoming CUTLASS 3.6 release will include a breaking refactor to the Hopper TMA pointer array batched epilogue in order to support grouped GEMMs. - -CUTLASS 3.5.0 is an update to CUTLASS adding: - -- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp). - + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). - + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp). - + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms. - + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. - + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! -- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. -- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x. - + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. - + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. -- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. -- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). -- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). -- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. -- Fixes to greatly reduce build warnings. -- Updates and bugfixes from the community (thanks!) -- CUTLASS 3.5.1 is a minor update to CUTLASS containing small bug fixes and improvements, including fixes for FlashAttention-2 builds. +# What's New in CUTLASS 3.6 + +CUTLASS 3.6.0 is an update to CUTLASS adding: + +- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). + + [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu) + + [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu) + + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) +- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. +- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. +- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). +- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. +- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). +- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) +- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). Minimum requirements: @@ -163,7 +141,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability). -The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12 or 11.8, the kernel is expected to fail with a runtime error. +The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CUDA Toolkit 12 or 11.8, the kernel is expected to fail with a runtime error. ``` cmake .. -DCUTLASS_NVCC_ARCHS="90a" @@ -191,6 +169,8 @@ CUTLASS is described in the following documents and the accompanying - [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory - [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application - [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilate rapid development +- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent +kernels in the same stream, and how it is used in CUTLASS. # Resources We have also described the structure of an efficient GEMM in our talk at the diff --git a/cmake/CTestTestfile.configure.cmake b/cmake/CTestTestfile.configure.cmake index 94394a5000..611b3d181f 100644 --- a/cmake/CTestTestfile.configure.cmake +++ b/cmake/CTestTestfile.configure.cmake @@ -50,5 +50,3 @@ if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) else() set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@) endif() - -@_INLINE_PER_TEST_CODE@ diff --git a/cmake/CTestTestfile.test.configure.cmake b/cmake/CTestTestfile.test.configure.cmake index fa2ceeb9bd..31dba54498 100644 --- a/cmake/CTestTestfile.test.configure.cmake +++ b/cmake/CTestTestfile.test.configure.cmake @@ -30,14 +30,14 @@ if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT) # The longform/extended format allows generator expressions to be # expanded property and is useful in contexts where the files need # to be immediately included into being-processed cmake code. - add_test(NAME @TEST_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) + add_test(NAME @TESTCASE_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) else() - add_test(@TEST_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) + add_test(@TESTCASE_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) endif() if (TEST_EXE_WORKING_DIRECTORY) - set_tests_properties(@TEST_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}") + set_tests_properties(@TESTCASE_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}") endif() -set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) +set_tests_properties(@TESTCASE_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 0350fb2dd1..d220cfadc2 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -34,9 +34,10 @@ if(GOOGLETEST_DIR) set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") endif() +set(GTEST_REPOSITORY "https://github.com/google/googletest.git" CACHE STRING "GoogleTest repo to fetch") FetchContent_Declare( googletest - GIT_REPOSITORY https://github.com/google/googletest.git + GIT_REPOSITORY ${GTEST_REPOSITORY} GIT_TAG v1.14.0 ) diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 27156ea02d..731e37b4d9 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -42,7 +42,8 @@ #include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/gemm/device/gemm_complex.h" - +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_size.h" #include "cutlass/util/command_line.h" #include "cutlass/util/host_tensor.h" @@ -56,6 +57,7 @@ #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/host/error_metrics.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes #include "cutlass/layout/matrix.h" #include "cutlass/epilogue/thread/linear_combination.h" @@ -657,7 +659,9 @@ struct Testbed { } int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; - int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); + int64_t bytes = cutlass::bits_to_bytes( + (cutlass::sizeof_bits::value * 2 + cutlass::sizeof_bits::value) * + options.problem_size.m() * options.problem_size.n()); double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9); double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30); diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index f26f4da37d..164c785e01 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -303,14 +303,14 @@ bool initialize_block( int bits_input = cutlass::sizeof_bits::value; if (bits_input == 1) { - scope_max = 2; - scope_min = 0; + scope_max = Element(2); + scope_min = Element(0); } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; + scope_max = Element(2); + scope_min = Element(-2); } else { - scope_max = 8; - scope_min = -8; + scope_max = Element(8); + scope_min = Element(-8); } cutlass::reference::device::BlockFillRandomUniform( diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp index 57053b0f9a..c71109aa79 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -111,7 +111,7 @@ class GemmGather EpilogueTensorStorage epilogue; } tensors; - struct PipelineStorage : cute::aligned_struct<16> { + struct PipelineStorage : cute::aligned_struct<16, _2> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; diff --git a/examples/53_hopper_gemm_permute/permute_traits.hpp b/examples/53_hopper_gemm_permute/permute_traits.hpp index 96fcc64cf9..4c5baccac5 100644 --- a/examples/53_hopper_gemm_permute/permute_traits.hpp +++ b/examples/53_hopper_gemm_permute/permute_traits.hpp @@ -50,7 +50,7 @@ struct PermuteTraits {}; using X = Underscore; // Reshape a rank-2 shape into a multidimensional shape. -// Input: +// Input: // shape = (A, B, ...) // target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...) // Output: @@ -76,12 +76,12 @@ reshape(Shape const& shape, TargetShape const& target_shape) // - sub-modes corresponding to the implied multidimensional shape of the source tensor // - strides accounting for the permutation operation being performed template -constexpr auto +constexpr auto make_permute_layout(Layout const& layout) { static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); if constexpr (Transpose) { // Deal with tensor B by transposing appropriately before and after computing the permute layout. - // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. return select<1,0,2>(make_permute_layout(select<1,0,2>(layout))); } else { @@ -129,23 +129,24 @@ inverse(Permutation const & perm) { template using inverse_t = decltype(inverse(T{})); -// Given a rank-2 layout of tensor that is assumed to have been permuted, +// Given a rank-2 layout of tensor that is assumed to have been permuted, // compute the original rank-2 layout of the tensor prior to the permutation. -// This is needed to form the correct input to the standalone permutation kernel. +// This is needed to form the correct input to the standalone permutation kernel. template -constexpr auto +constexpr auto make_original_layout(Layout const& layout) { static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); if constexpr (Transpose) { // Deal with tensor B by transposing appropriately before and after computing the permute layout. - // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. return select<1,0,2>(make_original_layout(select<1,0,2>(layout))); } else { using ShapeProfile = typename PermuteTraits::ShapeProfile; + auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{})); using IndexOrder = typename PermuteTraits::IndexOrder; + auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get(re_shape); }); using OrigOrder = conditional_t(), seq<0,1,2>, seq<1,0,2>>; - auto orig_shape = select(flatten(reshape(layout.shape(), ShapeProfile{})), IndexOrder{}); // print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n"); // print("Original shape: "); print(orig_shape); print("\n"); return make_ordered_layout(product_each(orig_shape), OrigOrder{}); @@ -202,7 +203,7 @@ struct PermuteTraits> }; template -struct PermuteTraits> +struct PermuteTraits> { static constexpr bool kBatched = true; using ShapeProfile = Shape>, Shape, Shape>; @@ -222,7 +223,7 @@ struct PermuteTraits> }; template -struct PermuteTraits> +struct PermuteTraits> { static constexpr bool kBatched = true; using ShapeProfile = Shape, Shape>, Shape>; diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu new file mode 100644 index 0000000000..138f7a0402 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -0,0 +1,701 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example shows how to perform INT4 x FP8 GEMM and scale up the INT4 weight during dequantization. It uses a look-up table to avoid the multiplications + between INT4 and FP8. To trigger this method, use cutlass::Array as the scale type in the collective's arguments. + + However, this algorithm requires changes to the encoding of INT4 weights and scale factors. These changes must happen before launching the GEMM. See the helper functions + `unify_quant_encoding`, `initialize_packed_scale`, and header `fp8_packed_scale.hpp` for details. + + In a nutshell, the positive values of INT4 weights need to be encoded in the same way as negative values except for the sign bit. For each scale factor, + 8 negative results (-8 x scale, -7 x scale, ... -1 x scale) are packed together, forming a cutlass::Array value. + + The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap + A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue, + as illustrated in this example. + + Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). + + Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. + + If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k]. + + The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size + equal to the gemm problem K. + + Limitations: + 1) Only supports INT4 x { FP8, INT8, UINT8 }. The scales must be the same as mma Type. Scale with zero-point mode is not supported. + 2) The INT4 weights and scale factors have additional encoding requirements. + 3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 4) The scales must have the same layout and groupsize. + 5) The groupsize must be greater or equal to the tile shape k. + 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. + We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space). + + Examples: + + Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0) + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0 + + Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire + matrix (group size is the same as the gemm k dimension). + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "helper.h" +#include "unfused_weight_dequantize.hpp" +#include "packed_scale.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using ElementScale = MmaType; +using ElementZero = ElementScale; // only for verify +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple >, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; +using StrideC = typename GemmKernelScaleOnly::StrideC; +using StrideD = typename GemmKernelScaleOnly::StrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideC_ref stride_C_ref; +StrideD stride_D; +StrideD_ref stride_D_ref; +uint64_t seed; + +using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_modified; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation> block_scale_packed; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 10; + int m = 5120, n = 4096, k = 4096; + int g = 128; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("g", g); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_warp_specialized_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). +bool unify_quant_encoding( + cutlass::DeviceAllocation const& block_in, + cutlass::DeviceAllocation& block_out) { + + using StorageType = cutlass::int4b_t::Storage; + + if (block_in.size() != block_out.size()) { + std::cerr << "block_in and block_out must have same size.\n"; + return false; + } + constexpr int pack = sizeof_bits_v / 4; + std::vector data(block_in.size() / pack); + cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack); + + for (auto&& d : data) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; ++i) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2); + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + Options const& options) { + + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + float const max_dequant_val = 4.f; + float const min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + return true; +} + +bool initialize_packed_scale( + cutlass::DeviceAllocation const& block_in, + cutlass::DeviceAllocation > & block_out) { + + std::vector data_in(block_in.size()); + std::vector > data_out(block_in.size()); + try { + block_in.copy_to_host(data_in.data()); + } catch (cutlass::cuda_exception const& e) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + for (size_t i = 0; i < block_in.size(); ++i) + { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + // std::cout << data_in[i] << ":" << std::hex << static_cast(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast((-data_in[i]).storage) << std::endl; + } + try { + block_out.copy_from_host(data_out.data()); + } catch (cutlass::cuda_exception const& e) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + Options const& options) { + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options const& options) { + + auto shape_b = cute::make_shape(options.n, options.k, options.l); + int const scale_k = (options.k + options.g - 1) / options.g; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); + // Reverse stride here due to swap and transpose + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_modified.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); + + block_scale.reset(scale_k * options.l * options.n); + block_scale_packed.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); + + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + unify_quant_encoding(block_B, block_B_modified); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_packed_scale(block_scale, block_scale_packed); + initialize_zero(block_zero, options); + + auto layout_B = make_layout(shape_b, stride_B); + + auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); + stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); + stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); + auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +Args args_from_options(Options const& options) +{ +// Swap the A and B tensors, as well as problem shapes here. + + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; +} + +bool verify(Options const& options) { + // + // Compute reference output + // + + // In this example, we use the GPU default kernels as a reference (unfused scale). + // This avoids numerical differences due to different accumulation order. + + // Again, due to numerical differences, we must use fast acc here when the mma type is + // FP8 as the fused implementation only supports fast acc at the moment. + constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." << std::endl; + } + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index 28baae260c..8a99cc2754 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -53,14 +53,18 @@ equal to the gemm problem K. Limitations: - 1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}. - 2) The narrow type must always be in K-major format. - 3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. - 4) The scales and the zeros must have the same layout and groupsize. + 1) The narrow type must always be in K-major format. + 2) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 3) The scales and the zeros must have the same layout and groupsize. + 4) The groupsize must be greater or equal to tile shape k. 5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format. 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space). + 2) Try avoid using scale or zero mode cause the computations will be the bottleneck. Examples: @@ -94,11 +98,8 @@ #include "cutlass/util/host_tensor.h" #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" #include "helper.h" #include "unfused_weight_dequantize.hpp" @@ -117,8 +118,8 @@ enum GemmMode { ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// -using MmaType = cutlass::float_e4m3_t; -using QuantType = cutlass::int4b_t; +using MmaType = cutlass::half_t; +using QuantType = cutlass::float_e4m3_t; constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; // A matrix configuration @@ -154,8 +155,8 @@ using ElementAccumulator = float; // E using ElementCompute = float; // Element type for epilogue computation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_128,_256,cute::Int>; // Threadblock-level tile size -using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; @@ -268,14 +269,14 @@ using StrideS_ref = cutlass::detail::TagToStrideB_t; StrideS stride_S; StrideS_ref stride_S_ref; -cutlass::HostTensor tensor_A; -cutlass::HostTensor tensor_B; -cutlass::HostTensor tensor_B_dq; -cutlass::HostTensor tensor_scale; -cutlass::HostTensor tensor_zero; -cutlass::HostTensor tensor_C; -cutlass::HostTensor tensor_D; -cutlass::HostTensor tensor_ref_D; +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) @@ -290,7 +291,7 @@ struct Options { float alpha = 1.0f; float beta = 0.0f; - int iterations = 1000; + int iterations = 10; int mode = 2; int m = 5120, n = 4096, k = 4096; int g = 128; @@ -368,9 +369,9 @@ struct Result ///////////////////////////////////////////////////////////////////////////////////////////////// /// Helper to initialize a block of device data -template +template bool initialize_tensor( - cutlass::TensorView view, + cutlass::DeviceAllocation& block, uint64_t seed=2023) { double scope_max, scope_min; @@ -393,34 +394,35 @@ bool initialize_tensor( scope_max = 8; scope_min = -8; } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); return true; } -template +template bool initialize_quant_tensor( - cutlass::TensorView view, + cutlass::DeviceAllocation& block, uint64_t seed=2023) { float scope_min = float(cutlass::platform::numeric_limits::lowest()); float scope_max = float(cutlass::platform::numeric_limits::max()); - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); return true; } -template +template bool initialize_scale( - cutlass::TensorView view, - const Options &options) { + cutlass::DeviceAllocation& block, + Options const& options) { if (options.mode == GemmMode::ConvertOnly) { // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. - cutlass::reference::host::TensorFill(view, Element(1.0f)); + std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); } else { float elt_max_f = float(cutlass::platform::numeric_limits::max()); @@ -430,32 +432,33 @@ bool initialize_scale( float scope_max(max_dequant_val / elt_max_f); float scope_min(min_dequant_val / elt_max_f); - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); } return true; } -template +template bool initialize_zero( - cutlass::TensorView view, - const Options &options) { + cutlass::DeviceAllocation& block, + Options const& options) { if (options.mode == GemmMode::ScaleWithZeroPoint) { - cutlass::reference::host::TensorFillRandomUniform( - view, seed, 2.0f, -2.0f); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); } else { // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. - cutlass::reference::host::TensorFill(view, Element(0.0f)); + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); } return true; } /// Initialize operands to be used in the GEMM and reference GEMM -void initialize(const Options &options) { +void initialize(Options const& options) { auto shape_b = cute::make_shape(options.n, options.k, options.l); - const int scale_k = (options.k + options.g - 1) / options.g; + int const scale_k = (options.k + options.g - 1) / options.g; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); // Reverse stride here due to swap and transpose @@ -469,27 +472,21 @@ void initialize(const Options &options) { auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); - tensor_A.resize(a_coord); - tensor_B.resize(b_coord); - tensor_B_dq.resize(b_coord); - tensor_C.resize(c_coord); - tensor_D.resize(c_coord); - tensor_ref_D.resize(c_coord); - - tensor_scale.resize({scale_k * options.l, options.n}); - tensor_zero.resize({scale_k * options.l, options.n}); + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); - initialize_tensor(tensor_A.host_view(), seed + 2022); - initialize_quant_tensor(tensor_B.host_view(), seed + 2021); - initialize_tensor(tensor_C.host_view(), seed + 2020); - initialize_scale(tensor_scale.host_view(), options); - initialize_zero(tensor_zero.host_view(), options); + block_scale.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_scale.sync_device(); - tensor_zero.sync_device(); + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); auto layout_B = make_layout(shape_b, stride_B); @@ -498,37 +495,36 @@ void initialize(const Options &options) { stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); - dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g); - tensor_B_dq.sync_host(); + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); } /// Populates a Gemm::Arguments structure from the given commandline options template -Args args_from_options(const Options &options) +Args args_from_options(Options const& options) { // Swap the A and B tensors, as well as problem shapes here. if (options.mode == GemmMode::ConvertOnly) { return Args { cutlass::gemm::GemmUniversalMode::kGemm, {options.n, options.m, options.k, options.l}, - {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + {block_B.get(), stride_B, block_A.get(), stride_A}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; } else if (options.mode == GemmMode::ScaleOnly) { return Args { cutlass::gemm::GemmUniversalMode::kGemm, {options.n, options.m, options.k, options.l}, - {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; } else if (options.mode == GemmMode::ScaleWithZeroPoint) { return Args { cutlass::gemm::GemmUniversalMode::kGemm, {options.n, options.m, options.k, options.l}, - {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; } else { std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; @@ -542,7 +538,7 @@ bool verify(const Options &options) { // // In this example, we use the GPU default kernels as a reference (unfused scale) - // This is to avoid numerical differences from different accumulation order. + // This avoids numerical differences due to different accumulation order. // Again, due to numerical differences, we must use fast acc here when the mma type is // FP8 as the fused implementation only supports fast acc at the moment. @@ -581,8 +577,8 @@ bool verify(const Options &options) { typename GemmRef::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k, options.l}, - {tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref} + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} }; // Run the gemm where the scaling is performed outside of the kernel. @@ -594,11 +590,9 @@ bool verify(const Options &options) { CUTLASS_CHECK(gemm_ref.run()); // compare_reference - tensor_D.sync_host(); - tensor_ref_D.sync_host(); - const ElementD epsilon(1e-2f); - const ElementD non_zero_floor(1e-4f); - bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor); + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); return passed; } diff --git a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt index 5ddfbd2e6e..a9753ed100 100644 --- a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt +++ b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt @@ -55,5 +55,16 @@ cutlass_example_add_executable( TEST_SCALE_ZERO_GROUPED TEST_SCALE_RESIDUE TEST_SCALE_ZERO_RESIDUE - TEST_ALPHA_BETA + # TEST_ALPHA_BETA + ) + +cutlass_example_add_executable( + 55_hopper_int4_fp8_gemm + 55_hopper_int4_fp8_gemm.cu + TEST_COMMAND_OPTIONS + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL + TEST_SCALE_GROUP + TEST_SCALE_RESIDUE + # TEST_ALPHA_BETA ) diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index 8c393a6b75..07265f0d7e 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -11,6 +11,8 @@ This first version only supports mixed type GEMMs using TMA. While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type. +The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now. + We are currently optimizing the following cases: 1. Memory bound cases for all types diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp new file mode 100644 index 0000000000..294d426135 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include "cutlass/float8.h" + +namespace cutlass +{ +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. + storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; +} diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index 5181678ca7..51ce970dbd 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -32,7 +32,7 @@ /*! \file \brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. - This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA + This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA warp-specialized cooperative kernel. The new feature showcased in this example is on-the-fly modification of TMA descriptors to move between batches (represented by l). @@ -547,3 +547,4 @@ int main(int argc, char const **args) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index a26d904dcc..d57e1deea5 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -91,9 +91,9 @@ using namespace cute; using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group -using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand -using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand -using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using ElementC = cutlass::half_t; // Element type for C and D matrix operands #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) diff --git a/examples/59_ampere_gather_scatter_conv/CMakeLists.txt b/examples/59_ampere_gather_scatter_conv/CMakeLists.txt index e7f164003d..ce22cd1f37 100644 --- a/examples/59_ampere_gather_scatter_conv/CMakeLists.txt +++ b/examples/59_ampere_gather_scatter_conv/CMakeLists.txt @@ -26,6 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +if (NOT MSVC) + cutlass_example_add_executable( 59_ampere_gather_scatter_conv ampere_gather_scatter_conv.cu @@ -34,3 +36,5 @@ cutlass_example_add_executable( if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND) target_link_libraries(59_ampere_gather_scatter_conv PRIVATE OpenMP::OpenMP_CXX) endif() + +endif() diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu new file mode 100644 index 0000000000..8bb14b4556 --- /dev/null +++ b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu @@ -0,0 +1,534 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM + Top-K + Softmax fusion + + This example illustrates how to use the LinCombTopKSoftmaxCol EVT node to fuse + Top-K and Softmax into the GEMM epilogue, with certain assumptions made. + + Those assumptions are as: + 1. Fusion is over the N dimension. + 2. Top-K is either 2 or 4 elements, and the value is static (meaning two kernels have to be + compiled to support both.) + 3. The GEMM tile shape along N is greater than or equal to problem size + along N. + + + The example runs the fused GEMM kernel, along with a standard unfused host reference, and + manually performs Top-K and softmax, and compares the error between tensors. + + Note that some numerical error (smaller than 1e-5) is to be expected, but this is true + in most efficient reduction kernels, because floating point addition is not necessarily + associative. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +static constexpr int TopK = 2; +static constexpr bool EnableTopKSoftmax = TopK > 1; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = void; +using LayoutC = cutlass::layout::RowMajor; +constexpr int AlignmentC = 1; + +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for C and D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for output +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of output in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + +// Top-K + Softmax fusion operation +using FusionOperation = std::conditional_t, + typename cutlass::epilogue::fusion::LinearCombination +>; + +// The fusion op only allows for epilogue tiles matching the mainloop tile. +using EpilogueTileType = decltype(cute::take<0,2>(TileShape{})); + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideD stride_D; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + int iterations = 1000; + int m = 16, n = 8, k = 64, l = 1; + double eps = 1e-5; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("eps", eps); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "61_hopper_gemm_with_topk_and_softmax\n\n" + << " Hopper FP8 GEMM with Top-K and softmax fusion.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --eps= Threshold of numerical verification. Default: 1e-5.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "61_hopper_gemm_with_topk_and_softmax" << " --m=16 --n=8 --k=1024 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + float alpha() const { + return 1.f / static_cast(k); + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + cutlass::reference::host::TensorFillRandomUniform( + view, seed, /* max = */ 1, /* min = */ -1, /* bits = */ 2); + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_D.sync_device(); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {options.alpha(), 0.f}, // alpha, beta + nullptr, stride_D, + tensor_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + unused_t, + decltype(D), + unused_t, // bias + unused_t, // aux + unused_t, // valpha + unused_t // vbeta + > epilogue_params; + + epilogue_params.D = D; + epilogue_params.alpha = options.alpha(); + epilogue_params.beta = 0.f; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + if constexpr (EnableTopKSoftmax) { + // top-K + softmax + for (int i = 0; i < options.m; ++i) { + + // Find Top-K + cutlass::Array top_k; + top_k.fill(-cutlass::platform::numeric_limits::infinity()); + for (int j = 0; j < options.n; ++j) { + auto val = static_cast(tensor_ref_D.host_view().ref().at({i, j})); + for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) { + if (val > top_k[top_k_idx]) { + // Shift down + for (int l = TopK - 1; l > top_k_idx; --l) { + top_k[l] = top_k[l - 1]; + } + top_k[top_k_idx] = val; + break; + } + } + } + + // This formulation of top-K + softmax only works when it is + // guaranteed that none of the top-K elements are repeated! + // If this is the case, the device kernel can also make mistakes, because + // A. Once the top-K values are reduced, and the operation is being applied, + // there is no way to tell repeated elements apart, so none are masked. + // B. The softmax sum of exps will be incorrect (because the repeated elements + // are not repeated in it.) + + ElementAccumulator max = top_k[0]; + ElementAccumulator sum = ElementAccumulator(0.f); + for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) { + sum = sum + cutlass::fast_exp(top_k[top_k_idx] - max); + } + + for (int j=0; j < options.n; ++j) { + auto val = tensor_ref_D.host_view().ref().at({i, j}); + if (val < top_k[TopK - 1]) { + tensor_ref_D.host_view().ref().at({i, j}) = static_cast(0.f); + } else { + // Softmax + auto softmax_val = cutlass::fast_exp(val - max) / sum; + tensor_ref_D.host_view().ref().at({i, j}) = static_cast(softmax_val); + } + } + } + } + + // compare_reference + tensor_D.sync_host(); + + double err = cutlass::reference::host::TensorRelativeErrorMetric( + tensor_D.host_view(), + tensor_ref_D.host_view()); + bool passed = err < options.eps; + + if (options.m <= 32 && options.n <= 32) { + std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n"; + std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n"; + } + + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl; + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt b/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt new file mode 100644 index 0000000000..7d9160a733 --- /dev/null +++ b/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 61_hopper_gemm_with_topk_and_softmax + 61_hopper_gemm_with_topk_and_softmax.cu + ) diff --git a/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu new file mode 100644 index 0000000000..c3f1ce709a --- /dev/null +++ b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu @@ -0,0 +1,596 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Sparse GEMM example. + + This example demonstrates how to construct and run a structured sparse GEMM kernel + on NVIDIA Hopper architecture. + +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size for sparse kernel +using TileShapeRef = Shape<_128,_128, _64>; // Threadblock-level tile size for reference (dense) kernel +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; // Kernel schedule policy +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue schedule policy + +using ProblemShape = Shape; + +// Sparse kernel setup + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference (dense) kernel setup + +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShapeRef, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShapeRef, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopRef, + CollectiveEpilogue +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +// Layouts +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// Layouts for reference (non-sparse) tensors +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; + +using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +// Offline compressor kernel +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm90>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +ProblemShape problem_shape; + +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; + +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_A_compressed; +cutlass::DeviceAllocation block_E; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref; + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(5120), n(4096), k(16384), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "62_hopper_sparse_gemm\n\n" + << " Hopper Sparse GEMM example.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM (batch size)\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "62_hopper_sparse_gemm" << " --m=4096 --n=5120 --k=8192 --l=1 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Make A structured sparse by replacing elements with 0 and compress it +bool sparsify_and_compress() +{ + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + block_A_compressed.reset(M * KC * L); + block_E.reset(ME * KE * L); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + + // Random sparsification is performed on host + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast(seed + 2024)); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments { + problem_shape, + { block_A.get(), + stride_A, + block_A_compressed.get(), + block_E.get() }, + {hw_info} }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(Options const& options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + auto [M, N, K, L] = problem_shape; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // Allocate memory for tensors + block_A.reset(M * K * L); + block_B.reset(N * K * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_D_ref.reset(M * N * L); + + // Fill input tensors with data + initialize_block(block_A, seed + 2021); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2023); + + // Replace 0 in A with 1 to avoid metadata changes + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + for (size_t i = 0; i < block_A.size(); ++i) if (block_A_host[i] == ElementA(0)) block_A_host[i] = ElementA(1.0); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + if (!sparsify_and_compress()) { + return false; + }; + + // Build the compressed/metadata layouts + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + + return true; +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments make_args(Options const& options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D.get(), stride_D } + }; + + return arguments; +} + +typename GemmRef::Arguments make_args_ref(Options const& options) +{ + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A.get(), stride_A, block_B.get(), stride_B }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D_ref.get(), stride_D } + }; + + return arguments; +} + +template +void print_device_tensor(cute::Tensor const& t) +{ + // Assumes size = cosize, i.e. compact tensor + std::vector data_host(t.size()); + cutlass::device_memory::copy_to_host(data_host.data(), t.data(), t.size()); + auto t_host = cute::make_tensor(data_host.data(), t.layout()); + cute::print_tensor(t_host); +} + +bool verify(Options const& options) { + CUDA_CHECK(cudaDeviceSynchronize()); + + bool passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size()); + +#if 0 + if (!passed) { + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + cute::print("A (original): "); print_device_tensor(make_tensor(block_A.get(), make_shape(M, K, L), stride_A)); + cute::print("A (compressed): "); print_device_tensor(make_tensor(block_A_compressed.get(), make_shape(M, KC, L), stride_A_compressed)); + cute::print("E (physical): "); print_device_tensor(make_tensor(block_E.get(), make_shape(ME, KE, L), stride_E)); + cute::print("E (logical): "); print_device_tensor(make_tensor(block_E.get(), upcast(layout_E))); + cute::print("B: "); print_device_tensor(make_tensor(block_B.get(), make_shape(N, K, L), stride_B)); + cute::print("C: "); print_device_tensor(make_tensor(block_C.get(), make_shape(M, N, L), stride_C)); + cute::print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M, N, L), stride_D)); + cute::print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M, N, L), stride_D)); + } +#endif + + return passed; +} + +template +struct Runner +{ + using Arguments = typename Gemm::Arguments; + + Runner(Arguments args): arguments(args) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + workspace.reset(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + } + + void run() { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + + void benchmark(Options const& options) { + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + run(); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double gflops = options.gflops(avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << gflops << std::endl; + } + } + + Gemm gemm; + Arguments arguments; + cutlass::device_memory::allocation workspace; +}; + +/// Execute the example (verification and timing) +void run(Options &options) { + bool init = initialize(options); + if (!init) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + Runner gemm(make_args(options)); + Runner gemm_ref(make_args_ref(options)); + + gemm.run(); + gemm_ref.run(); + + bool passed = verify(options); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + exit(EXIT_FAILURE); + } + + std::cout << "Sparse GEMM:" << std::endl; + gemm.benchmark(options); + + std::cout << "Dense GEMM:" << std::endl; + gemm_ref.benchmark(options); +} + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.2 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) { + std::cerr << "This example requires CUDA 12.2 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + run(options); +#endif + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/62_hopper_sparse_gemm/CMakeLists.txt b/examples/62_hopper_sparse_gemm/CMakeLists.txt new file mode 100644 index 0000000000..cf55da4552 --- /dev/null +++ b/examples/62_hopper_sparse_gemm/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Sparse kernel in this example triggers an ICE in gcc 7.5 +if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) +cutlass_example_add_executable( + 62_hopper_sparse_gemm + 62_hopper_sparse_gemm.cu + ) +endif() diff --git a/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu new file mode 100644 index 0000000000..03c54a8ee9 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper FP8 GEMM + L2 Weight Prefetch + + This example implements a non-persistent warp-specialized GEMM kernel for the Hopper + architecture with programmatic dependent launch (PDL) enabling prefetching weights into + L2 cache. + + For more information about dependent launch refer to the CUDA programming guide: + https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization + + In some cases, PDL can result in a window where a previous kernel is not actively utilizing + DRAM, and the next kernel sits idle until the previous finishes. During this window, the next + kernel can begin loading a non-dependent operand (i.e. weights in a linear projection are + typically static) and cache it in L2. + + The kernel and collective mainloop assume operand `A` corresponds to weights and operand `B` + corresponds to activations (so we can have very small batch/token count). + After initialization, the prefetch warp starts loading K tiles of `A` into an unused portion + of shared memory, and loads up to half of all K tiles that the same CTA would eventually load. + The exact number of K tiles loaded is determined by `args.mainloop.prefetch_ratio` \in + [0.0, 1.0]. Smaller values result in less prefetching, and larger values result in more. + Negative values result in a "best-effort" prefetch, meaning prefetcher will stop issuing weight + loads as soon as the activation DMA warp starts loading (as soon as it is signaled that the + previous kernel has flushed its memory.) + + The DMA warp responsible for loading `A` will also begin loading K tiles until it fills up + the available shared memory. + The DMA warp responsible for loading `B` will wait until activations are flushed to global + memory by the preceding kernel. + + Another mainloop parameter, `args.mainloop.overlap_ratio` \in [0.0, 1.0] determines how early + the next kernel (the one doing the prefetch) is launched. Smaller values result in greater + overlap, and larger values result in smaller overlap. Negative values disable PDL completely, + meaning there will be no overlap. This will make prefetch ineffective. + + These two runtime parameters should be tuned per problem size and GEMM config combination, and + if feasible, per-operation in an entire layer or model. + + NOTE: you must build this target with the following flag to enable Grid Dependency Control + instructions (GDC) in CUTLASS: + - CUTLASS_ENABLE_GDC_FOR_SM90 + + To lock persistence mode, power (350W), clocks (1005MHz) for evaluation (assumes device 0 and H100) + + $ sudo nvidia-smi -pm 1 -i 0 + + $ sudo nvidia-smi -i 0 -pl 350 + + $ sudo nvidia-smi -i 0 -lgc 1005 + + Example: + + $ mkdir build && cd build + + $ cmake .. -DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_GDC_FOR_SM90=1 + + $ cd examples/63_hopper_gemm_with_weight_prefetch + + $ make + + $ ./63_hopper_gemm_with_weight_prefetch --p=0.5 --o=0.5 +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "collective/dispatch_policy_extra.hpp" +#include "collective/builder.hpp" +#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp" + +#include "helper.h" +#include "gemm_with_weight_prefetch_commandline.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size +// Cluster_N > 1 is not supported yet. +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; +cutlass::HostTensor scalar_alpha; +cutlass::HostTensor scalar_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + double eff_bw; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + double eff_bw = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), eff_bw(eff_bw), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + initialize_tensor(tensor_C.host_view(), seed + 2024); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = scalar_alpha.device_data(); + fusion_args.beta_ptr = scalar_beta.device_data(); + + arguments.mainloop.overlap_ratio = options.overlap_ratio; + arguments.mainloop.prefetch_ratio = options.prefetch_ratio; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + unused_t, // aux + unused_t, // valpha + unused_t // vbeta + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0)); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0)); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0); + result.gflops = options.gflops(avg_runtime_s); + result.eff_bw = options.effective_bandwidth(avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD)); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + std::cout << " Effective bandwidth: " << result.eff_bw << " GB/s" << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} diff --git a/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt new file mode 100644 index 0000000000..f48673241a --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include_directories( + . +) + +cutlass_example_add_executable( + 63_hopper_gemm_with_weight_prefetch + 63_hopper_gemm_with_weight_prefetch.cu + ) diff --git a/examples/63_hopper_gemm_with_weight_prefetch/README.md b/examples/63_hopper_gemm_with_weight_prefetch/README.md new file mode 100644 index 0000000000..5dac1cc6c2 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/README.md @@ -0,0 +1,82 @@ +# GEMM with L2 weight prefetch + +A non-persistent warp specialized GEMM directed at low latency inference. + +The kernel can optionally prefetch a portion of weights (operand `A`) into L2 cache while the +rest of the warps are waiting on the previous kernel to finish writing and flush its memory. +An example of this is normalization or reduction kernels that are immediately followed by a GEMM. + +It exposes two runtime parameters: +1. `overlap_ratio`: how early `griddepcontrol.launch_dependent_grids` is issued. + Default is `0.5`, meaning after approximately half of K tiles are loaded by DMA warps. +2. `prefetch_ratio`: what percentage of K tiles to prefetch. + Default is `-1.0`, meaning prefetching will stop as soon as other DMA warps are past + `griddepcontrol`. + +It is highly recommended to auto-tune these parameters per GEMM and according to some end to end +runtime (either an entire transformer layer or multiple, but probably not the entire model.) + +TMA loads use non-default cache hints: `A` (weights) are loaded with `EvictFirst`, and `B` (activation) +is loaded with `EvictLast`. + +## Getting started +To use this kernel in your own target, add this directory to your includes, and include the +following headers from this example: + +```cxx +#include "collective/dispatch_policy_extra.hpp" +#include "collective/builder.hpp" +#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp" +``` + +And then use either one of the new kernel schedules: + +```cxx +// Without separate warps for A and B +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetch; + +// With separate warps for A and B +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA; +``` + +The kernel with separate warps for A and B ( +`KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA`) +is expected to be more performant than the other, especially since it allows the kernel to load +weights into shmem ahead of the `griddepcontrol`. + +As for other GEMM parameters, Thread Block Cluster larger than 1 CTA are not yet supported, and +obviously the kernel layer implementation is warp specialized and uses the TMA, and other kernel +layers or collectives require reimplementation. + +## Example + +Using the example is mostly straightforward. +Just build, and run with your choice of `MNK`: + +```bash +./63_hopper_gemm_with_weight_prefetch --m=8192 --n=1 --k=8192 +``` + +You can also disable the overlap or try different overlap and prefetch ratios and see the +difference: + +```bash +echo "Without overlap and prefetch" +./63_hopper_gemm_with_weight_prefetch --o=-1.0 --p=-1.0 + +echo "Overlap ratio of 0.5, best effort prefetch" +./63_hopper_gemm_with_weight_prefetch --o=0.5 --p=-1.0 + +echo "Overlap ratio of 0.8, prefetch ratio of 0.7" +./63_hopper_gemm_with_weight_prefetch --o=0.8 --p=0.7 +``` + +However, note that the example still runs a single GEMM, and most of the performance improvement +is expected in end to end applications. + + +## Limitations +* The parameter defaults are typically not good choices, especially `prefetch_ratio`. + When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a + memory barrier before issuing every single TMA load, and in many cases this will slow down + prefetching to the point of being almost ineffective. diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp new file mode 100644 index 0000000000..57365a8b36 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "dispatch_policy_extra.hpp" +#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp" + +namespace cutlass::gemm::collective { + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp new file mode 100644 index 0000000000..37369176f9 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +namespace cutlass::gemm { + +// Standard non-persistent kernel with a single producer warp, and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp is waiting on griddepcontrol. +// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and +// according to prefetch ratio. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { }; + +// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not +// wait on griddepcontrol and loads immediately. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { }; + +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch +> +struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; + +} // namespace cutlass::gemm diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000..710224d78c --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -0,0 +1,867 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/arch/grid_dependency_control.h" + +#include "dispatch_policy_extra.hpp" + +#include "../pipeline/prefetch_pipeline_sm90.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedWithPrefetch, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1"); + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + static constexpr int PrefetchStages = 4; + static constexpr int PrefetchInitialStages = 1; + // This determines how much shmem we set aside for prefetch. + // We don't reuse anything loaded by prefetcher, so we can keep + // loading into the same place -- there will be a conflict when + // writing, but it doesn't affect performance as much as the doors + // that this opens. + static constexpr int PrefetchStagesActual = 1; + using PrefetcherPipeline = cutlass::PrefetchPipeline; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages); + static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages); + + using PrefetchSmemLayoutA = decltype(make_layout(make_shape( + cute::Int(SmemLayoutA{})>{}, + cute::Int(SmemLayoutA{})>{}, + cute::Int{}))); + + static constexpr auto prefetch_smem_size = cute::cosize_v; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + // Defined outside the class where it's used, to work around MSVC issues + using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned smem_prefetch; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + PrefetcherPipelineStorage prefetcher_pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.overlap_ratio, + args.prefetch_ratio + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + if (args.overlap_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + if (args.prefetch_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + return true; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // We have to wait on dependent grids because of B. + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_MK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + // Don't wait on dependent grids when loading `A`, because + // we assume `A` (weights) are static. + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_NK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + prefetch_MK( + Params const& mainloop_params, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0; + float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio; + int prefetch_iters = static_cast(static_cast(k_tile_count) * 0.5 * prefetch_ratio); + prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages); + + Tensor sA = make_tensor( + make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + uint32_t prefetcher_stage = 0; + uint32_t prefetcher_phase = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) { + + if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) { + break; + } + + prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages); + using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType; + BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage); + + int write_stage = 0; + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + ++k_tile_iter; + + prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase); + } + prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp b/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp new file mode 100644 index 0000000000..6be87768ee --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float overlap_ratio = 0.5f, prefetch_ratio = 0.5f; + int iterations = 1000; + int n = 64, m = 1280, k = 8192, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f); + cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "63_hopper_gemm_with_weight_prefetch\n\n" + << " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. \n" + << " For more details please refer to the source file.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --p= Prefetch ratio\n" + << " --o= Overlap ratio\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "63_hopper_gemm_with_weight_prefetch" << + " --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + /// Compute effective bandwidth in GB/sec + double effective_bandwidth( + double runtime_s, + size_t bytes_a, + size_t bytes_b, + size_t bytes_c, + size_t bytes_d + ) const + { + static double const kBytesPerGiB = double(1ull << 30); + + double bytes_in = + (double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A + (double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B + (beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C + double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D + + double gb_total = (bytes_in + bytes_out) / kBytesPerGiB; + return gb_total / runtime_s; + } +}; diff --git a/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000..6e33d8fc62 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp @@ -0,0 +1,561 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" + +#include "../collective/dispatch_policy_extra.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +// GEMM + Prefetch for the A tensor + (optional) split DMA warps +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v + > +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr bool SplitWarps = cute::is_same_v; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "TMA warp-specialized kernel does not support specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Kernel level shared memory storage + struct SharedStorage { + // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) PrefetcherPipelineStorage prefetcher; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + // Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK. + // Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused. + // Both modes use Warp1 to prefetch. + enum class ProducerWarpRole { + Warp0 = 0, + PrefetchMK = 1, + Warp2 = 2, + UnusedWarp = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + if (warp_group_role == WarpGroupRole::Producer && ( + producer_warp_role == ProducerWarpRole::Warp0 || + producer_warp_role == ProducerWarpRole::Warp2)) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + bool should_prefetch = params.mainloop.prefetch_ratio > 0; + using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline; + typename PrefetcherPipeline::Params prefetcher_pipeline_params; + prefetcher_pipeline_params.num_prefetchers = 1; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + prefetcher_pipeline_params.should_prefetch = should_prefetch; + prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk; + } + PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + // Non-prefetcher warps arrive and wait, + // Prefetcher warp can go ahead without waiting. + cute::cluster_arrive_relaxed(); + if (warp_group_role != WarpGroupRole::Producer || + producer_warp_role != ProducerWarpRole::PrefetchMK) { + cute::cluster_wait(); + } + return [] () {}; + } + else { + // __syncthreads() but only for non prefetcher warps + if (should_prefetch) { + + // Use a named barrier to let the prefetcher warp start loading into the L2 + // without waiting to sync with all other warps. + // All other warps need to sync because the mainloop pipeline init + // should be visible to all of them. + // Prefetcher has its own barriers, and the only warps it would need to sync + // with would be the DMA warps. + using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier; + auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z, + /*reserved_named_barriers_*/ 14); + // Prefetcher warp doesn't arrive on this barrier. + auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp, + /*reserved_named_barriers_*/ 15); + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + __syncwarp(); + prefetcher_arrive_barrier.arrive(); + } + else if (warp_group_role == WarpGroupRole::Producer) { + prefetcher_arrive_barrier.arrive_and_wait(); + cluster_arrive_barrier.arrive_and_wait(); + } + else { + prefetcher_arrive_barrier.arrive(); + cluster_arrive_barrier.arrive_and_wait(); + } + } else { + __syncthreads(); + } + return [] () {}; + } + } (); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + if (producer_warp_role == ProducerWarpRole::Warp0) { + if constexpr(SplitWarps) { + collective_mainloop.load_NK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + else { + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + // Ensure warp is converged before issuing epilogue loads + __syncwarp(); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) { + collective_mainloop.load_MK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) { + collective_mainloop.prefetch_MK( + params.mainloop, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); + } +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp b/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp new file mode 100644 index 0000000000..7abd39ccfc --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/container/array.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// MSVC work-around +template +struct PrefetcherPipelineSharedStorage { + using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier; + using Barrier = cutlass::arch::ClusterBarrier; + + TransactionBarrier tma_barrier[Stages]; + Barrier producer_ready_barrier; +}; + +} // end namespace detail + +using namespace cute; + +// Prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction +// barrier providing control over the number of concurrent outstanding TMA loads. +// There is also an additional cluster barrier which is only used when `prefetch_ratio` is unset. +// `prefetch_ratio` determines how many K tiles get loaded, and when unset, the prefetcher checks +// whether DMA warps are done waiting on griddepcontrol, and if so, stops issuing more TMA loads. +template +class PrefetchPipeline { +public : + static constexpr uint32_t Stages = Stages_; + using SharedStorage = detail::PrefetcherPipelineSharedStorage; + + using TransactionBarrier = typename SharedStorage::TransactionBarrier; + using Barrier = typename SharedStorage::Barrier; + using PrefetcherBarrierType = typename TransactionBarrier::ValueType; + + struct Params { + uint32_t transaction_bytes = 0; + uint32_t num_prefetchers = 1; + bool should_prefetch = false; + }; + + // Constructor + CUTLASS_DEVICE + PrefetchPipeline(SharedStorage& storage, Params params) + : params_(params) + , tma_barrier_ptr_(&storage.tma_barrier[0]) + , producer_ready_barrier_ptr_(&storage.producer_ready_barrier) { + + int lane_predicate = cute::elect_one_sync(); + if (params.should_prefetch && lane_predicate) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; ++i) { + tma_barrier_ptr_[i].init(params.num_prefetchers); + } + producer_ready_barrier_ptr_[0].init(1); + } + } + + CUTLASS_DEVICE + void producer_arrive() { + if (params_.should_prefetch) { + producer_ready_barrier_ptr_[0].arrive(); + } + } + + CUTLASS_DEVICE + bool have_producers_arrived() { + if (params_.should_prefetch) { + uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0); + auto barrier_status = static_cast(barrier_status_); + if (barrier_status == BarrierStatus::WaitDone) { + return true; // exit prefetcher loop + } + return false; + } + return true; + } + + CUTLASS_DEVICE + void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) { + if (params_.should_prefetch) { + if (should_wait) { + tma_barrier_ptr_[stage].wait(phase ^ 1); + } + tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); + } + } + + CUTLASS_DEVICE + void advance_prefetcher_state(uint32_t& stage, uint32_t& phase) { + if (params_.should_prefetch) { + stage++; + if (stage == Stages) { + stage = 0; + phase ^= 1; + } + } + } + + CUTLASS_DEVICE + void prefetcher_tail(uint32_t stage, uint32_t phase) { + if (params_.should_prefetch) { + // Wait on any already-issued loads + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < stage; ++i) { + tma_barrier_ptr_[i].wait(phase); + } + } + } + + CUTLASS_DEVICE + PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) { + return reinterpret_cast(&tma_barrier_ptr_[stage]); + } + +private : + TransactionBarrier* tma_barrier_ptr_ = nullptr; + Barrier* producer_ready_barrier_ptr_ = nullptr; + Params params_; + +}; + +} // end namespace cutlass diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 9cb125d988..6486d71435 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -140,6 +140,9 @@ foreach(EXAMPLE 57_hopper_grouped_gemm 58_ada_fp8_gemm 59_ampere_gather_scatter_conv + 61_hopper_gemm_with_topk_and_softmax + 62_hopper_sparse_gemm + 63_hopper_gemm_with_weight_prefetch ) add_subdirectory(${EXAMPLE}) diff --git a/examples/cute/tutorial/tiled_copy.cu b/examples/cute/tutorial/tiled_copy.cu index d370320b1b..87ad873ce6 100644 --- a/examples/cute/tutorial/tiled_copy.cu +++ b/examples/cute/tutorial/tiled_copy.cu @@ -186,8 +186,8 @@ int main(int argc, char** argv) return -1; } // Equivalent check to the above - if (not weakly_compatible(block_shape, tensor_shape)) { - std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl; + if (not evenly_divides(tensor_shape, block_shape)) { + std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl; return -1; } diff --git a/include/cute/algorithm/clear.hpp b/include/cute/algorithm/clear.hpp index 1c7dd5a334..0b3a8eaa1d 100644 --- a/include/cute/algorithm/clear.hpp +++ b/include/cute/algorithm/clear.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::fill namespace cute { diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index b2be11717f..9d080116da 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -31,12 +31,14 @@ #pragma once #include - -#include -#include - -#include +#include +#include // cute::logical_divide +#include // cute::Swizzle +#include // cute::get_nonswizzle_portion +#include // cute::Tensor #include +#include +#include namespace cute { diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index da03bfbd11..2c91ce6f45 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -434,8 +434,8 @@ cooperative_gemm(uint32_t thread_idx, static_assert(is_convertible_v>, TypeC>, "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); - static constexpr bool compat = weakly_compatible(tile_shape(TiledMMA{}), - make_shape(size<0>(sA), size<0>(sB), size<1>(sA))); + static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); if constexpr (compat) { detail::cooperative_gemm_no_predication( thread_idx, tiled_mma, alpha, sA, sB, beta, sC, diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index 2a37995eea..c2decd15d7 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -30,14 +30,10 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::TrivialPredTensor +#include // cute::Copy_Atom namespace cute { diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp index 8e7a58a5bc..ef80d018d7 100644 --- a/include/cute/algorithm/functional.hpp +++ b/include/cute/algorithm/functional.hpp @@ -30,10 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::max, cute::min +#include // cute::conj /** C++14 extensions */ diff --git a/include/cute/algorithm/prefetch.hpp b/include/cute/algorithm/prefetch.hpp index 0d638ab58f..c39f63acdd 100644 --- a/include/cute/algorithm/prefetch.hpp +++ b/include/cute/algorithm/prefetch.hpp @@ -30,11 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::Copy_Atom namespace cute { diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index 616960a54a..c87ce682d1 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -44,7 +44,7 @@ /// Code guidelines and style preferences: /// /// For perfect forwarding, don't use std::forward, because it may not -/// be defined in device code when compiling with NVRTC. Instead, use +/// be defined in device code when compiling with NVRTC. Instead, use /// `static_cast(parameter_name)`. /// /// CuTe generally does not bother forwarding functions, as @@ -52,24 +52,9 @@ /// /// Throughout CUTLASS, cute::make_tuple always needs to be called /// namespace-qualified, EVEN If inside the cute namespace and/or in -/// scope of a "using namespace cute" declaration. Otherwise, the +/// scope of a "using namespace cute" declaration. Otherwise, the /// compiler may select std::make_tuple instead of cute::make_tuple, -/// due to argument-dependent lookup. Two problems may result from -/// that. -/// -/// 1. Functions have an unexpected return type (std::tuple instead of -/// cute::tuple), so functions that take cute::tuple parameters -/// fail to compile (generally inside functions that have template -/// parameters expected to be cute::tuple). -/// -/// 2. std::tuple does not have the required __host__ __device__ -/// markings, so the CUDA compiler complains if you use it in -/// device code. -/// -/// cute::make_tuple will occur more often than std::make_tuple would -/// in modern C++ code, because cute::tuple's design deprioritizes -/// correct operation of CTAD (constructor template argument -/// deduction) in favor of implementation simplicity. +/// due to argument-dependent lookup. namespace cute { @@ -145,6 +130,8 @@ transform_apply(T&& t, F&& f, G&& g) } else { return g(f(static_cast(t))); } + + CUTE_GCC_UNREACHABLE; } template @@ -157,6 +144,8 @@ transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) } else { return g(f(static_cast(t0), static_cast(t1))); } + + CUTE_GCC_UNREACHABLE; } template @@ -169,6 +158,8 @@ transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) } else { return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); } + + CUTE_GCC_UNREACHABLE; } // @@ -401,71 +392,36 @@ filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) namespace detail { -// This impl compiles much faster than cute::apply and variadic args -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&&, V&& v, F&&, seq<>) -{ - return v; -} - -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return f(static_cast(v), get(static_cast(t))); -} - -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))); -} - -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))); -} +template +struct FoldAdaptor { + template + CUTE_HOST_DEVICE constexpr auto operator|(X&& x) { + auto r = fn_(val_, static_cast(x)); + return FoldAdaptor{fn_, r}; + } + Fn fn_; + Val val_; +}; -template +template CUTE_HOST_DEVICE constexpr auto -fold(T&& t, V&& v, F&& f, seq) +fold(T&& t, V const& v, F&& f, seq) { - return f(f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))); + return (FoldAdaptor{f,v} | ... | get(static_cast(t))).val_; } -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return fold(static_cast(t), - f(f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), - f, - seq{}); -} } // end namespace detail template CUTE_HOST_DEVICE constexpr auto -fold(T&& t, V&& v, F&& f) +fold(T&& t, V const& v, F&& f) { if constexpr (is_tuple>::value) { - return detail::fold(static_cast(t), - static_cast(v), - f, - tuple_seq{}); + return detail::fold(static_cast(t), v, f, tuple_seq{}); } else { - return f(static_cast(v), static_cast(t)); + return f(v, static_cast(t)); } CUTE_GCC_UNREACHABLE; @@ -477,10 +433,7 @@ auto fold_first(T&& t, F&& f) { if constexpr (is_tuple>::value) { - return detail::fold(static_cast(t), - get<0>(static_cast(t)), - f, - make_range<1,tuple_size>::value>{}); + return detail::fold(static_cast(t), get<0>(t), f, make_range<1,tuple_size>::value>{}); } else { return t; } @@ -536,13 +489,23 @@ CUTE_HOST_DEVICE constexpr auto take(T const& t) { - return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); + if constexpr (E == -1) { + if constexpr (is_tuple::value) { + return take::value>(t); + } else { + return take(t); + } + } else + if constexpr (B <= E) { + return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; } -// // Select tuple elements with given indices. -// - template CUTE_HOST_DEVICE constexpr auto @@ -551,19 +514,6 @@ select(T const& t) return cute::make_tuple(get(t)...); } -template -CUTE_HOST_DEVICE constexpr -auto -select(T const& t, Indices const& indices) -{ - if constexpr (is_tuple::value) { - return cute::transform(indices, [&t](auto i) { return select(t, i); }); - } else { - static_assert(is_static::value, "Order must be static"); - return get(t); - } -} - // Wrap non-tuples into rank-1 tuples or forward template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp index 27a34d7773..8fff51be8e 100644 --- a/include/cute/arch/cluster_sm90.hpp +++ b/include/cute/arch/cluster_sm90.hpp @@ -150,7 +150,7 @@ CUTE_DEVICE dim3 cluster_shape() } // Get 1D ctaid in a cluster. -CUTLASS_DEVICE uint32_t block_rank_in_cluster() +CUTE_DEVICE uint32_t block_rank_in_cluster() { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t rank; @@ -162,7 +162,7 @@ CUTLASS_DEVICE uint32_t block_rank_in_cluster() } // Set the destination block-ID in cluster for a given SMEM Address -CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) +CUTE_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t result; diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp new file mode 100644 index 0000000000..84d7779a34 --- /dev/null +++ b/include/cute/arch/config.hpp @@ -0,0 +1,50 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTLASS_ARCH_MMA_SMxx_ENABLED + +// TMA instructions +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + +// STSM +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cute/arch/copy_sm50.hpp b/include/cute/arch/copy_sm50.hpp index 9cf0efcdf5..925d9ebe37 100644 --- a/include/cute/arch/copy_sm50.hpp +++ b/include/cute/arch/copy_sm50.hpp @@ -40,8 +40,8 @@ namespace cute { - -struct SM50_Shuffle_U32_2x2Trans +// Shuffle data between thread pair (0, 1), (2, 3), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR1 { using SRegisters = uint32_t[2]; using DRegisters = uint32_t[2]; @@ -68,5 +68,31 @@ struct SM50_Shuffle_U32_2x2Trans } }; +// Shuffle data between thread pair (0, 4), (1, 5), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR4 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = threadIdx.x & 4 ? src0 : src1; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 4); + + // Replace detination register with shuffle result. + if (threadIdx.x & 0x4) { + dst0 = y0; + } + else { + dst1 = y0; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + } // end namespace cute diff --git a/include/cute/arch/copy_sm90.hpp b/include/cute/arch/copy_sm90.hpp index e5684ec469..bcb3b7d19c 100644 --- a/include/cute/arch/copy_sm90.hpp +++ b/include/cute/arch/copy_sm90.hpp @@ -30,21 +30,10 @@ **************************************************************************************************/ #pragma once -#include - +#include // CUTE_HOST_DEVICE +#include // CUTE_ARCH_TMA_SMxx_ENABLED #include -// Config -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -# define CUTE_ARCH_STSM_SM90_ENABLED -# define CUTE_ARCH_TMA_SM90_ENABLED -#endif - -#if defined(CUTE_ARCH_TMA_SM90_ENABLED) && \ - ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) -# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED -#endif - namespace cute { diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 21e473ede9..25a252a8e7 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -30,6 +30,8 @@ **************************************************************************************************/ #pragma once +#include "cutlass/numeric_types.h" + #if !defined(__CUDACC_RTC__) #include #include @@ -37,6 +39,8 @@ #include +#include // cute::cast_smem_ptr_to_uint +#include // CUTE_ARCH_TMA_SMxx_ENABLED #include #include @@ -134,6 +138,10 @@ enum class SmemSwizzleBits : uint8_t { B128 = 3, }; +enum class SmemSwizzleBase : uint8_t { + SWIZZLE_BASE_16B = 0, +}; + enum class OOBFill : uint8_t { ZERO = 0, CONSTANT = 1, @@ -201,13 +209,21 @@ to_CUtensorMapDataType() { } inline CUtensorMapSwizzle -to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { +to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { switch (t) { - default: assert(false && "Unknown SmemSwizzleBits!"); - case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE; - case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B; - case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B; - case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; + default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBits::DISABLE: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_NONE; + case SmemSwizzleBits::B32: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 32B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_32B; + case SmemSwizzleBits::B64: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_64B; + case SmemSwizzleBits::B128: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 128B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_128B; } } @@ -282,7 +298,7 @@ tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr, "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;" :: "l"(gmem_int_desc), "l"(new_desc_addr)); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -295,15 +311,11 @@ tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc, #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); - uint64_t const smem_int64_desc = 0; - asm volatile ( - "cvt.u64.u32 %0, %1;" - :: "l"(smem_int64_desc), "r"(smem_int_desc)); asm volatile ( "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" - :: "l"(smem_int64_desc), "l"(new_desc_addr)); + :: "r"(smem_int_desc), "l"(new_desc_addr)); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -331,7 +343,6 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor :: "l"(smem_int64_desc), "r"(prob_shape[2])); // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) - // 4 LSBs are not included asm volatile ( "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int64_desc), "l"(prob_stride[1])); @@ -339,6 +350,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" :: "l"(smem_int64_desc), "l"(prob_stride[2])); #else + // 4 LSBs are not included asm volatile ( "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); @@ -347,7 +359,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); #endif #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -366,7 +378,7 @@ tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescripto "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" :: "l"(gmem_int_desc), "r"(smem_int_desc)); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -381,7 +393,7 @@ tma_descriptor_fence_release() #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) asm volatile ("fence.proxy.tensormap::generic.release.gpu;"); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -400,13 +412,8 @@ tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr) : : "l"(gmem_int_desc) : "memory"); - asm volatile ( - "cvta.global.u64 %0, %0;" - : - : "l"(gmem_int_desc), "l"(gmem_int_desc) - : "memory"); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index 1851482119..fb33d63cad 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -32,8 +32,11 @@ #include +#include // CUTE_ARCH_TMA_SMxx_ENABLED #include #include +#include "cutlass/arch/synclog.hpp" + namespace cute { @@ -52,6 +55,7 @@ struct SM90_TMA_LOAD_1D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3}], [%2], %4;" @@ -97,6 +101,7 @@ struct SM90_TMA_LOAD_2D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4}], [%2], %5;" @@ -142,6 +147,7 @@ struct SM90_TMA_LOAD_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5}], [%2], %6;" @@ -187,6 +193,7 @@ struct SM90_TMA_LOAD_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" @@ -232,6 +239,7 @@ struct SM90_TMA_LOAD_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" @@ -355,6 +363,7 @@ struct SM90_TMA_LOAD_IM2COL_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" @@ -405,6 +414,7 @@ struct SM90_TMA_LOAD_IM2COL_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" @@ -455,6 +465,7 @@ struct SM90_TMA_LOAD_IM2COL_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" @@ -565,7 +576,7 @@ struct SM90_TMA_LOAD_IM2COL struct SM90_TMA_LOAD_MULTICAST_1D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0) { @@ -573,13 +584,14 @@ struct SM90_TMA_LOAD_MULTICAST_1D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4}], [%2], %3;" + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0) + "r"(crd0), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -590,7 +602,7 @@ struct SM90_TMA_LOAD_MULTICAST_1D struct SM90_TMA_LOAD_MULTICAST_2D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1) { @@ -598,13 +610,14 @@ struct SM90_TMA_LOAD_MULTICAST_2D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5}], [%2], %3;" + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0), "r"(crd1) + "r"(crd0), "r"(crd1), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -615,7 +628,7 @@ struct SM90_TMA_LOAD_MULTICAST_2D struct SM90_TMA_LOAD_MULTICAST_3D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { @@ -623,13 +636,14 @@ struct SM90_TMA_LOAD_MULTICAST_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5, %6}], [%2], %3;" + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0), "r"(crd1), "r"(crd2) + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -640,7 +654,7 @@ struct SM90_TMA_LOAD_MULTICAST_3D struct SM90_TMA_LOAD_MULTICAST_4D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { @@ -648,13 +662,14 @@ struct SM90_TMA_LOAD_MULTICAST_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5, %6, %7}], [%2], %3;" + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -665,7 +680,7 @@ struct SM90_TMA_LOAD_MULTICAST_4D struct SM90_TMA_LOAD_MULTICAST_5D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { @@ -673,13 +688,14 @@ struct SM90_TMA_LOAD_MULTICAST_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;" + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -690,39 +706,39 @@ struct SM90_TMA_LOAD_MULTICAST_5D struct SM90_TMA_LOAD_MULTICAST { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0) { - return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0); + return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1) { - return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1); + return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { - return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2); + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { - return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { - return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); + return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); } using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; @@ -744,6 +760,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" @@ -772,6 +789,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" @@ -800,6 +818,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" @@ -871,6 +890,7 @@ struct SM90_TMA_STORE_1D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" : @@ -893,6 +913,7 @@ struct SM90_TMA_STORE_2D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" : @@ -915,6 +936,7 @@ struct SM90_TMA_STORE_3D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" : @@ -937,6 +959,7 @@ struct SM90_TMA_STORE_4D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" : @@ -959,6 +982,7 @@ struct SM90_TMA_STORE_5D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" : @@ -1024,6 +1048,7 @@ struct SM90_TMA_STORE_IM2COL_3D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group" " [%0, {%2, %3, %4}], [%1];" @@ -1047,6 +1072,7 @@ struct SM90_TMA_STORE_IM2COL_4D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group" " [%0, {%2, %3, %4, %5}], [%1];" @@ -1070,6 +1096,7 @@ struct SM90_TMA_STORE_IM2COL_5D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group" " [%0, {%2, %3, %4, %5, %6}], [%1];" @@ -1112,6 +1139,7 @@ struct SM90_TMA_STORE_IM2COL CUTE_HOST_DEVICE static void tma_store_fence() { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); asm volatile ("fence.proxy.async.shared::cta;"); #elif defined(__CUDA_ARCH__) CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -1122,6 +1150,7 @@ tma_store_fence() { CUTE_HOST_DEVICE static void tma_store_arrive() { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_tma_store_arrive(__LINE__); asm volatile("cp.async.bulk.commit_group;"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -1138,6 +1167,7 @@ tma_store_wait() { : : "n"(Count) : "memory"); + cutlass::arch::synclog_emit_tma_store_wait(__LINE__, Count); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); #endif @@ -1157,6 +1187,7 @@ struct SM90_TMA_REDUCE_ADD_1D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group [%0, {%2}], [%1];" : @@ -1179,6 +1210,7 @@ struct SM90_TMA_REDUCE_ADD_2D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.bulk_group [%0, {%2, %3}], [%1];" : @@ -1201,6 +1233,7 @@ struct SM90_TMA_REDUCE_ADD_3D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4}], [%1];" : @@ -1223,6 +1256,7 @@ struct SM90_TMA_REDUCE_ADD_4D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5}], [%1];" : @@ -1245,6 +1279,7 @@ struct SM90_TMA_REDUCE_ADD_5D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" : diff --git a/include/cute/arch/mma.hpp b/include/cute/arch/mma.hpp index 5bfda7463c..6e06114a6c 100644 --- a/include/cute/arch/mma.hpp +++ b/include/cute/arch/mma.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::fma +#include // cute::fma namespace cute { diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp index d504bf39df..6ab29adc9d 100644 --- a/include/cute/arch/mma_sm90.hpp +++ b/include/cute/arch/mma_sm90.hpp @@ -32,7 +32,6 @@ #pragma once #include - #include // Config @@ -45,10 +44,12 @@ namespace cute { +namespace SM90 { + //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x4 TN -struct SM90_16x8x4_F64F64F64F64_TN +struct MMA_16x8x4_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[2]; @@ -73,7 +74,7 @@ struct SM90_16x8x4_F64F64F64F64_TN "d"(b0), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; @@ -81,7 +82,7 @@ struct SM90_16x8x4_F64F64F64F64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x8 TN -struct SM90_16x8x8_F64F64F64F64_TN +struct MMA_16x8x8_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[4]; @@ -106,7 +107,7 @@ struct SM90_16x8x8_F64F64F64F64_TN "d"(b0), "d"(b1), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; @@ -114,7 +115,7 @@ struct SM90_16x8x8_F64F64F64F64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x16 TN -struct SM90_16x8x16_F64F64F64F64_TN +struct MMA_16x8x16_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[8]; @@ -141,7 +142,7 @@ struct SM90_16x8x16_F64F64F64F64_TN "d"(b0), "d"(b1), "d"(b2), "d"(b3), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; @@ -149,7 +150,7 @@ struct SM90_16x8x16_F64F64F64F64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x4 TN -struct SM90_16x8x4_C64C64C64C64_TN +struct MMA_16x8x4_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[2]; @@ -175,28 +176,28 @@ struct SM90_16x8x4_C64C64C64C64_TN double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), b0.real(), c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), b0.real(), c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), b0.imag(), d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), b0.imag(), @@ -207,7 +208,7 @@ struct SM90_16x8x4_C64C64C64C64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x8 TN -struct SM90_16x8x8_C64C64C64C64_TN +struct MMA_16x8x8_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[4]; @@ -234,28 +235,28 @@ struct SM90_16x8x8_C64C64C64C64_TN double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), a2.real(), a3.real(), b0.real(), b1.real(), c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), a2.imag(), a3.imag(), b0.real(), b1.real(), c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), b0.imag(), b1.imag(), d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), a2.real(), a3.real(), b0.imag(), b1.imag(), @@ -266,7 +267,7 @@ struct SM90_16x8x8_C64C64C64C64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x16 TN -struct SM90_16x8x16_C64C64C64C64_TN +struct MMA_16x8x16_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[8]; @@ -296,7 +297,7 @@ struct SM90_16x8x16_C64C64C64C64_TN double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), a2.real(), a3.real(), a4.real(), a5.real(), a6.real(), a7.real(), @@ -304,7 +305,7 @@ struct SM90_16x8x16_C64C64C64C64_TN c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), a2.imag(), a3.imag(), a4.imag(), a5.imag(), a6.imag(), a7.imag(), @@ -312,7 +313,7 @@ struct SM90_16x8x16_C64C64C64C64_TN c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), @@ -320,7 +321,7 @@ struct SM90_16x8x16_C64C64C64C64_TN d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), a2.real(), a3.real(), a4.real(), a5.real(), a6.real(), a7.real(), @@ -331,17 +332,24 @@ struct SM90_16x8x16_C64C64C64C64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// +} + } // namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// #include #include +#include +#include // cute::size +#include // cute::is_static +#include // cute::half_t, cute::float_e4m3_t, cute::tfloat32_t, etc +#include // cute::is_same_v //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cute { -namespace GMMA { +namespace SM90::GMMA { template < class ElementA, @@ -370,73 +378,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x256x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x240x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x224x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x208x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x192x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x176x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x160x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x144x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x128x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x112x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x96x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x80x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x64x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x48x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x32x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x16x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x8x16_F16F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -450,73 +458,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -530,73 +538,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -610,73 +618,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -690,73 +698,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -776,73 +784,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x256x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x240x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x224x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x208x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x192x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x176x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x160x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x144x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x128x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x112x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x96x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x80x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x64x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x48x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x32x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x16x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x8x16_F32F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -854,73 +862,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -934,73 +942,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1014,73 +1022,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1094,73 +1102,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1174,73 +1182,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1254,73 +1262,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1342,73 +1350,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1422,73 +1430,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1502,73 +1510,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1582,73 +1590,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1678,12 +1686,11 @@ template < > CUTE_HOST_DEVICE constexpr auto -rs_op_selector() +ss_op_selector_sparse() { static_assert(is_static::value, "TileShape_MNK must be static."); static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); - static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); auto Tile_N = size<1>(TileShape_MNK{}); // F16 accumulator @@ -1691,76 +1698,76 @@ rs_op_selector() // Input A: half_t ; Input B: half_t if constexpr (is_same_v && is_same_v) { - static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1771,76 +1778,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1851,76 +1858,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1931,76 +1938,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2011,76 +2018,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2097,76 +2104,76 @@ rs_op_selector() // Input A: half_t ; Input B: half_t if constexpr (is_same_v && is_same_v) { - static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2175,76 +2182,76 @@ rs_op_selector() // Input A: bfloat16_t ; Input B: bfloat16_t else if constexpr (is_same_v && is_same_v) { - static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2255,76 +2262,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2335,76 +2342,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2415,76 +2422,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2495,76 +2502,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2575,76 +2582,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2663,76 +2670,76 @@ rs_op_selector() if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2743,76 +2750,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2823,76 +2830,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2901,78 +2908,2726 @@ rs_op_selector() // Input A: uint8_t ; Input B: uint8_t else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8U8_RS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector_sparse() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2990,7 +5645,7 @@ rs_op_selector() } } -} // end namespace GMMA +} // end namespace SM90::GMMA } // end namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp index 1d6caba89d..a53a9748b4 100644 --- a/include/cute/arch/mma_sm90_desc.hpp +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -48,8 +48,7 @@ namespace cute { // GMMA Descriptor and utilities // GMMA enums and utilities -namespace GMMA -{ +namespace SM90::GMMA { enum class LayoutType : uint8_t { INTERLEAVE = 0, @@ -81,7 +80,7 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { } #endif // !defined(__CUDACC_RTC__) -} // end namespace GMMA +} // end namespace SM90::GMMA union GmmaDescriptor { @@ -146,7 +145,7 @@ print(GmmaDescriptor const& t) printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); - printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); + printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); #endif // !defined(__CUDACC_RTC__) } diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp index aebb8fab5a..4dc01463b7 100644 --- a/include/cute/arch/mma_sm90_gmma.hpp +++ b/include/cute/arch/mma_sm90_gmma.hpp @@ -30,8 +30,10 @@ **************************************************************************************************/ #pragma once -#include -#include +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) # define CUTE_ARCH_MMA_SM90A_ENABLED @@ -47,6 +49,7 @@ void warpgroup_arrive() { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_arrive(__LINE__); asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); #else CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED"); @@ -60,6 +63,7 @@ warpgroup_wait() { static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N); asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); @@ -72,6 +76,7 @@ void warpgroup_commit_batch() { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_commit_batch(__LINE__); asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); #else CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED"); @@ -97,7 +102,7 @@ warpgroup_fence_operand(float& reg) { #endif } -namespace GMMA { +namespace SM90::GMMA { enum class Major { K = 0, @@ -114,7 +119,11 @@ enum class ScaleIn { One = 1 }; -} // namespace GMMA +enum class SparseSel { + Zero = 0, + One = 1 +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) @@ -127,7 +136,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F16F16F16_SS +struct MMA_64x8x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -141,6 +150,7 @@ struct SM90_64x8x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -156,7 +166,7 @@ struct SM90_64x8x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -170,7 +180,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F16F16F16_RS +struct MMA_64x8x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -187,6 +197,7 @@ struct SM90_64x8x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -202,7 +213,7 @@ struct SM90_64x8x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -216,7 +227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F16F16F16_SS +struct MMA_64x16x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -230,6 +241,7 @@ struct SM90_64x16x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -245,7 +257,7 @@ struct SM90_64x16x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -259,7 +271,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F16F16F16_RS +struct MMA_64x16x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -276,6 +288,7 @@ struct SM90_64x16x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -291,7 +304,7 @@ struct SM90_64x16x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -305,7 +318,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F16F16F16_SS +struct MMA_64x32x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -320,6 +333,7 @@ struct SM90_64x32x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -336,7 +350,7 @@ struct SM90_64x32x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -350,7 +364,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F16F16F16_RS +struct MMA_64x32x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -368,6 +382,7 @@ struct SM90_64x32x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -384,7 +399,7 @@ struct SM90_64x32x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -399,7 +414,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F16F16F16_SS +struct MMA_64x48x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -415,6 +430,7 @@ struct SM90_64x48x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -433,7 +449,7 @@ struct SM90_64x48x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -449,7 +465,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F16F16F16_RS +struct MMA_64x48x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -468,6 +484,7 @@ struct SM90_64x48x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -486,7 +503,7 @@ struct SM90_64x48x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -501,7 +518,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F16F16F16_SS +struct MMA_64x64x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -518,6 +535,7 @@ struct SM90_64x64x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -537,7 +555,7 @@ struct SM90_64x64x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -551,7 +569,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F16F16F16_RS +struct MMA_64x64x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -571,6 +589,7 @@ struct SM90_64x64x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -590,7 +609,7 @@ struct SM90_64x64x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -605,7 +624,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F16F16F16_SS +struct MMA_64x80x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -623,6 +642,7 @@ struct SM90_64x80x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -644,7 +664,7 @@ struct SM90_64x80x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -660,7 +680,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F16F16F16_RS +struct MMA_64x80x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -681,6 +701,7 @@ struct SM90_64x80x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -702,7 +723,7 @@ struct SM90_64x80x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -717,7 +738,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F16F16F16_SS +struct MMA_64x96x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -736,6 +757,7 @@ struct SM90_64x96x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -758,7 +780,7 @@ struct SM90_64x96x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -772,7 +794,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F16F16F16_RS +struct MMA_64x96x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -794,6 +816,7 @@ struct SM90_64x96x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -816,7 +839,7 @@ struct SM90_64x96x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -831,7 +854,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F16F16F16_SS +struct MMA_64x112x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -851,6 +874,7 @@ struct SM90_64x112x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -875,7 +899,7 @@ struct SM90_64x112x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -891,7 +915,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F16F16F16_RS +struct MMA_64x112x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -914,6 +938,7 @@ struct SM90_64x112x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -938,7 +963,7 @@ struct SM90_64x112x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -953,7 +978,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F16F16F16_SS +struct MMA_64x128x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -974,6 +999,7 @@ struct SM90_64x128x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -999,7 +1025,7 @@ struct SM90_64x128x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1013,7 +1039,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F16F16F16_RS +struct MMA_64x128x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1037,6 +1063,7 @@ struct SM90_64x128x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1062,7 +1089,7 @@ struct SM90_64x128x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1077,7 +1104,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F16F16F16_SS +struct MMA_64x144x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1099,6 +1126,7 @@ struct SM90_64x144x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1126,7 +1154,7 @@ struct SM90_64x144x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1142,7 +1170,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F16F16F16_RS +struct MMA_64x144x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1167,6 +1195,7 @@ struct SM90_64x144x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1194,7 +1223,7 @@ struct SM90_64x144x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1210,7 +1239,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F16F16F16_SS +struct MMA_64x160x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1233,6 +1262,7 @@ struct SM90_64x160x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1261,7 +1291,7 @@ struct SM90_64x160x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1277,7 +1307,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F16F16F16_RS +struct MMA_64x160x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1303,6 +1333,7 @@ struct SM90_64x160x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1331,7 +1362,7 @@ struct SM90_64x160x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1347,7 +1378,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F16F16F16_SS +struct MMA_64x176x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1371,6 +1402,7 @@ struct SM90_64x176x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1401,7 +1433,7 @@ struct SM90_64x176x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1417,7 +1449,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F16F16F16_RS +struct MMA_64x176x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1444,6 +1476,7 @@ struct SM90_64x176x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1474,7 +1507,7 @@ struct SM90_64x176x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1489,7 +1522,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F16F16F16_SS +struct MMA_64x192x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1514,6 +1547,7 @@ struct SM90_64x192x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1545,7 +1579,7 @@ struct SM90_64x192x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1559,7 +1593,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F16F16F16_RS +struct MMA_64x192x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1587,6 +1621,7 @@ struct SM90_64x192x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1618,7 +1653,7 @@ struct SM90_64x192x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1633,7 +1668,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F16F16F16_SS +struct MMA_64x208x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1659,6 +1694,7 @@ struct SM90_64x208x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1692,7 +1728,7 @@ struct SM90_64x208x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1708,7 +1744,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F16F16F16_RS +struct MMA_64x208x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1737,6 +1773,7 @@ struct SM90_64x208x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1770,7 +1807,7 @@ struct SM90_64x208x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1786,7 +1823,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F16F16F16_SS +struct MMA_64x224x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1813,6 +1850,7 @@ struct SM90_64x224x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1847,7 +1885,7 @@ struct SM90_64x224x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1863,7 +1901,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F16F16F16_RS +struct MMA_64x224x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1893,6 +1931,7 @@ struct SM90_64x224x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1927,7 +1966,7 @@ struct SM90_64x224x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1943,7 +1982,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F16F16F16_SS +struct MMA_64x240x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1971,6 +2010,7 @@ struct SM90_64x240x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2007,7 +2047,7 @@ struct SM90_64x240x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2023,7 +2063,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F16F16F16_RS +struct MMA_64x240x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2054,6 +2094,7 @@ struct SM90_64x240x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2090,7 +2131,7 @@ struct SM90_64x240x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2105,7 +2146,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F16F16F16_SS +struct MMA_64x256x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2134,6 +2175,7 @@ struct SM90_64x256x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2171,7 +2213,7 @@ struct SM90_64x256x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2185,7 +2227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F16F16F16_RS +struct MMA_64x256x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2217,6 +2259,7 @@ struct SM90_64x256x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2254,7 +2297,7 @@ struct SM90_64x256x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2268,7 +2311,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32F16F16_SS +struct MMA_64x8x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2282,6 +2325,7 @@ struct SM90_64x8x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2297,7 +2341,7 @@ struct SM90_64x8x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2311,7 +2355,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32F16F16_RS +struct MMA_64x8x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2328,6 +2372,7 @@ struct SM90_64x8x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2343,7 +2388,7 @@ struct SM90_64x8x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2357,7 +2402,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F32F16F16_SS +struct MMA_64x16x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2372,6 +2417,7 @@ struct SM90_64x16x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2388,7 +2434,7 @@ struct SM90_64x16x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2402,7 +2448,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F32F16F16_RS +struct MMA_64x16x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2420,6 +2466,7 @@ struct SM90_64x16x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2436,7 +2483,7 @@ struct SM90_64x16x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2450,7 +2497,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32F16F16_SS +struct MMA_64x32x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2467,6 +2514,7 @@ struct SM90_64x32x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2486,7 +2534,7 @@ struct SM90_64x32x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2500,7 +2548,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32F16F16_RS +struct MMA_64x32x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2520,6 +2568,7 @@ struct SM90_64x32x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2539,7 +2588,7 @@ struct SM90_64x32x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2554,7 +2603,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32F16F16_SS +struct MMA_64x48x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2573,6 +2622,7 @@ struct SM90_64x48x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2595,7 +2645,7 @@ struct SM90_64x48x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2611,7 +2661,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32F16F16_RS +struct MMA_64x48x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2633,6 +2683,7 @@ struct SM90_64x48x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2655,7 +2706,7 @@ struct SM90_64x48x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2670,7 +2721,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32F16F16_SS +struct MMA_64x64x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2691,6 +2742,7 @@ struct SM90_64x64x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2716,7 +2768,7 @@ struct SM90_64x64x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2730,7 +2782,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32F16F16_RS +struct MMA_64x64x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2754,6 +2806,7 @@ struct SM90_64x64x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2779,7 +2832,7 @@ struct SM90_64x64x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2794,7 +2847,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32F16F16_SS +struct MMA_64x80x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2817,6 +2870,7 @@ struct SM90_64x80x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2845,7 +2899,7 @@ struct SM90_64x80x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2861,7 +2915,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32F16F16_RS +struct MMA_64x80x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2887,6 +2941,7 @@ struct SM90_64x80x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2915,7 +2970,7 @@ struct SM90_64x80x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2930,7 +2985,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32F16F16_SS +struct MMA_64x96x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2955,6 +3010,7 @@ struct SM90_64x96x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2986,7 +3042,7 @@ struct SM90_64x96x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3000,7 +3056,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32F16F16_RS +struct MMA_64x96x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3028,6 +3084,7 @@ struct SM90_64x96x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3059,7 +3116,7 @@ struct SM90_64x96x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3074,7 +3131,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32F16F16_SS +struct MMA_64x112x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3101,6 +3158,7 @@ struct SM90_64x112x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3135,7 +3193,7 @@ struct SM90_64x112x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3151,7 +3209,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32F16F16_RS +struct MMA_64x112x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3181,6 +3239,7 @@ struct SM90_64x112x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3215,7 +3274,7 @@ struct SM90_64x112x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3230,7 +3289,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32F16F16_SS +struct MMA_64x128x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3259,6 +3318,7 @@ struct SM90_64x128x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3296,7 +3356,7 @@ struct SM90_64x128x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3310,7 +3370,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32F16F16_RS +struct MMA_64x128x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3342,6 +3402,7 @@ struct SM90_64x128x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3379,7 +3440,7 @@ struct SM90_64x128x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3394,7 +3455,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32F16F16_SS +struct MMA_64x144x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3425,6 +3486,7 @@ struct SM90_64x144x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3465,7 +3527,7 @@ struct SM90_64x144x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3481,7 +3543,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32F16F16_RS +struct MMA_64x144x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3515,6 +3577,7 @@ struct SM90_64x144x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3555,7 +3618,7 @@ struct SM90_64x144x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3571,7 +3634,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32F16F16_SS +struct MMA_64x160x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3604,6 +3667,7 @@ struct SM90_64x160x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3647,7 +3711,7 @@ struct SM90_64x160x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3663,7 +3727,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32F16F16_RS +struct MMA_64x160x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3699,6 +3763,7 @@ struct SM90_64x160x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3742,7 +3807,7 @@ struct SM90_64x160x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3758,7 +3823,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32F16F16_SS +struct MMA_64x176x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3793,6 +3858,7 @@ struct SM90_64x176x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3839,7 +3905,7 @@ struct SM90_64x176x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3855,7 +3921,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32F16F16_RS +struct MMA_64x176x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3893,6 +3959,7 @@ struct SM90_64x176x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3939,7 +4006,7 @@ struct SM90_64x176x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3954,7 +4021,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32F16F16_SS +struct MMA_64x192x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3991,6 +4058,7 @@ struct SM90_64x192x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4040,7 +4108,7 @@ struct SM90_64x192x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4054,7 +4122,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32F16F16_RS +struct MMA_64x192x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4094,6 +4162,7 @@ struct SM90_64x192x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4143,7 +4212,7 @@ struct SM90_64x192x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4158,7 +4227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32F16F16_SS +struct MMA_64x208x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -4197,6 +4266,7 @@ struct SM90_64x208x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4249,7 +4319,7 @@ struct SM90_64x208x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4265,7 +4335,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32F16F16_RS +struct MMA_64x208x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4307,6 +4377,7 @@ struct SM90_64x208x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4359,7 +4430,7 @@ struct SM90_64x208x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4375,7 +4446,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32F16F16_SS +struct MMA_64x224x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -4416,6 +4487,7 @@ struct SM90_64x224x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4471,7 +4543,7 @@ struct SM90_64x224x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4487,7 +4559,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32F16F16_RS +struct MMA_64x224x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4531,6 +4603,7 @@ struct SM90_64x224x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4586,7 +4659,7 @@ struct SM90_64x224x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4602,7 +4675,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32F16F16_SS +struct MMA_64x240x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -4645,6 +4718,7 @@ struct SM90_64x240x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4703,7 +4777,7 @@ struct SM90_64x240x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4719,7 +4793,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32F16F16_RS +struct MMA_64x240x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4765,6 +4839,7 @@ struct SM90_64x240x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4823,7 +4898,7 @@ struct SM90_64x240x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4838,7 +4913,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32F16F16_SS +struct MMA_64x256x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -4883,6 +4958,7 @@ struct SM90_64x256x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4944,7 +5020,7 @@ struct SM90_64x256x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4958,7 +5034,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32F16F16_RS +struct MMA_64x256x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5006,6 +5082,7 @@ struct SM90_64x256x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5067,7 +5144,7 @@ struct SM90_64x256x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5081,7 +5158,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32BF16BF16_SS +struct MMA_64x8x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5095,6 +5172,7 @@ struct SM90_64x8x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5110,7 +5188,7 @@ struct SM90_64x8x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5124,7 +5202,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32BF16BF16_RS +struct MMA_64x8x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5141,6 +5219,7 @@ struct SM90_64x8x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5156,7 +5235,7 @@ struct SM90_64x8x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5170,7 +5249,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F32BF16BF16_SS +struct MMA_64x16x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5185,6 +5264,7 @@ struct SM90_64x16x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5201,7 +5281,7 @@ struct SM90_64x16x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5215,7 +5295,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F32BF16BF16_RS +struct MMA_64x16x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5233,6 +5313,7 @@ struct SM90_64x16x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5249,7 +5330,7 @@ struct SM90_64x16x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5263,7 +5344,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32BF16BF16_SS +struct MMA_64x32x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5280,6 +5361,7 @@ struct SM90_64x32x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5299,7 +5381,7 @@ struct SM90_64x32x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5313,7 +5395,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32BF16BF16_RS +struct MMA_64x32x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5333,6 +5415,7 @@ struct SM90_64x32x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5352,7 +5435,7 @@ struct SM90_64x32x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5367,7 +5450,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32BF16BF16_SS +struct MMA_64x48x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5386,6 +5469,7 @@ struct SM90_64x48x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5408,7 +5492,7 @@ struct SM90_64x48x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5424,7 +5508,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32BF16BF16_RS +struct MMA_64x48x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5446,6 +5530,7 @@ struct SM90_64x48x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5468,7 +5553,7 @@ struct SM90_64x48x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5483,7 +5568,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32BF16BF16_SS +struct MMA_64x64x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5504,6 +5589,7 @@ struct SM90_64x64x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5529,7 +5615,7 @@ struct SM90_64x64x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5543,7 +5629,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32BF16BF16_RS +struct MMA_64x64x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5567,6 +5653,7 @@ struct SM90_64x64x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5592,7 +5679,7 @@ struct SM90_64x64x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5607,7 +5694,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32BF16BF16_SS +struct MMA_64x80x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5630,6 +5717,7 @@ struct SM90_64x80x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5658,7 +5746,7 @@ struct SM90_64x80x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5674,7 +5762,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32BF16BF16_RS +struct MMA_64x80x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5700,6 +5788,7 @@ struct SM90_64x80x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5728,7 +5817,7 @@ struct SM90_64x80x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5743,7 +5832,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32BF16BF16_SS +struct MMA_64x96x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5768,6 +5857,7 @@ struct SM90_64x96x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5799,7 +5889,7 @@ struct SM90_64x96x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5813,7 +5903,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32BF16BF16_RS +struct MMA_64x96x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5841,6 +5931,7 @@ struct SM90_64x96x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5872,7 +5963,7 @@ struct SM90_64x96x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5887,7 +5978,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32BF16BF16_SS +struct MMA_64x112x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5914,6 +6005,7 @@ struct SM90_64x112x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5948,7 +6040,7 @@ struct SM90_64x112x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5964,7 +6056,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32BF16BF16_RS +struct MMA_64x112x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5994,6 +6086,7 @@ struct SM90_64x112x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6028,7 +6121,7 @@ struct SM90_64x112x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6043,7 +6136,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32BF16BF16_SS +struct MMA_64x128x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6072,6 +6165,7 @@ struct SM90_64x128x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6109,7 +6203,7 @@ struct SM90_64x128x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6123,7 +6217,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32BF16BF16_RS +struct MMA_64x128x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6155,6 +6249,7 @@ struct SM90_64x128x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6192,7 +6287,7 @@ struct SM90_64x128x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6207,7 +6302,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32BF16BF16_SS +struct MMA_64x144x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6238,6 +6333,7 @@ struct SM90_64x144x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6278,7 +6374,7 @@ struct SM90_64x144x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6294,7 +6390,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32BF16BF16_RS +struct MMA_64x144x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6328,6 +6424,7 @@ struct SM90_64x144x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6368,7 +6465,7 @@ struct SM90_64x144x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6384,7 +6481,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32BF16BF16_SS +struct MMA_64x160x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6417,6 +6514,7 @@ struct SM90_64x160x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6460,7 +6558,7 @@ struct SM90_64x160x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6476,7 +6574,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32BF16BF16_RS +struct MMA_64x160x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6512,6 +6610,7 @@ struct SM90_64x160x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6555,7 +6654,7 @@ struct SM90_64x160x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6571,7 +6670,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32BF16BF16_SS +struct MMA_64x176x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6606,6 +6705,7 @@ struct SM90_64x176x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6652,7 +6752,7 @@ struct SM90_64x176x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6668,7 +6768,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32BF16BF16_RS +struct MMA_64x176x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6706,6 +6806,7 @@ struct SM90_64x176x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6752,7 +6853,7 @@ struct SM90_64x176x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6767,7 +6868,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32BF16BF16_SS +struct MMA_64x192x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6804,6 +6905,7 @@ struct SM90_64x192x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6853,7 +6955,7 @@ struct SM90_64x192x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6867,7 +6969,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32BF16BF16_RS +struct MMA_64x192x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6907,6 +7009,7 @@ struct SM90_64x192x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6956,7 +7059,7 @@ struct SM90_64x192x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6971,7 +7074,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32BF16BF16_SS +struct MMA_64x208x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7010,6 +7113,7 @@ struct SM90_64x208x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7062,7 +7166,7 @@ struct SM90_64x208x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7078,7 +7182,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32BF16BF16_RS +struct MMA_64x208x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7120,6 +7224,7 @@ struct SM90_64x208x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7172,7 +7277,7 @@ struct SM90_64x208x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7188,7 +7293,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32BF16BF16_SS +struct MMA_64x224x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7229,6 +7334,7 @@ struct SM90_64x224x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7284,7 +7390,7 @@ struct SM90_64x224x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7300,7 +7406,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32BF16BF16_RS +struct MMA_64x224x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7344,6 +7450,7 @@ struct SM90_64x224x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7399,7 +7506,7 @@ struct SM90_64x224x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7415,7 +7522,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32BF16BF16_SS +struct MMA_64x240x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7458,6 +7565,7 @@ struct SM90_64x240x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7516,7 +7624,7 @@ struct SM90_64x240x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7532,7 +7640,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32BF16BF16_RS +struct MMA_64x240x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7578,6 +7686,7 @@ struct SM90_64x240x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7636,7 +7745,7 @@ struct SM90_64x240x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7651,7 +7760,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32BF16BF16_SS +struct MMA_64x256x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7696,6 +7805,7 @@ struct SM90_64x256x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7757,7 +7867,7 @@ struct SM90_64x256x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7771,7 +7881,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32BF16BF16_RS +struct MMA_64x256x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7819,6 +7929,7 @@ struct SM90_64x256x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7880,7 +7991,7 @@ struct SM90_64x256x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7892,7 +8003,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x8_F32TF32TF32_SS_TN +struct MMA_64x8x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7906,6 +8017,7 @@ struct SM90_64x8x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7921,7 +8033,7 @@ struct SM90_64x8x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7933,7 +8045,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x8_F32TF32TF32_RS_TN +struct MMA_64x8x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7947,6 +8059,7 @@ struct SM90_64x8x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7962,7 +8075,7 @@ struct SM90_64x8x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7974,7 +8087,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x8_F32TF32TF32_SS_TN +struct MMA_64x16x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7989,6 +8102,7 @@ struct SM90_64x16x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8005,7 +8119,7 @@ struct SM90_64x16x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8017,7 +8131,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x8_F32TF32TF32_RS_TN +struct MMA_64x16x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8032,6 +8146,7 @@ struct SM90_64x16x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8048,7 +8163,7 @@ struct SM90_64x16x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8060,7 +8175,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x8_F32TF32TF32_SS_TN +struct MMA_64x32x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8077,6 +8192,7 @@ struct SM90_64x32x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8096,7 +8212,7 @@ struct SM90_64x32x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8108,7 +8224,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x8_F32TF32TF32_RS_TN +struct MMA_64x32x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8125,6 +8241,7 @@ struct SM90_64x32x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8144,7 +8261,7 @@ struct SM90_64x32x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8157,7 +8274,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x8_F32TF32TF32_SS_TN +struct MMA_64x48x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8176,6 +8293,7 @@ struct SM90_64x48x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8198,7 +8316,7 @@ struct SM90_64x48x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8212,7 +8330,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x8_F32TF32TF32_RS_TN +struct MMA_64x48x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8231,6 +8349,7 @@ struct SM90_64x48x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8253,7 +8372,7 @@ struct SM90_64x48x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8266,7 +8385,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x8_F32TF32TF32_SS_TN +struct MMA_64x64x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8287,6 +8406,7 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8312,7 +8432,7 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8324,7 +8444,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x8_F32TF32TF32_RS_TN +struct MMA_64x64x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8345,6 +8465,7 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8370,7 +8491,7 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8383,7 +8504,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x8_F32TF32TF32_SS_TN +struct MMA_64x80x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8406,6 +8527,7 @@ struct SM90_64x80x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8434,7 +8556,7 @@ struct SM90_64x80x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8448,7 +8570,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x8_F32TF32TF32_RS_TN +struct MMA_64x80x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8471,6 +8593,7 @@ struct SM90_64x80x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8499,7 +8622,7 @@ struct SM90_64x80x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8512,7 +8635,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x8_F32TF32TF32_SS_TN +struct MMA_64x96x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8537,6 +8660,7 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8568,7 +8692,7 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8580,7 +8704,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x8_F32TF32TF32_RS_TN +struct MMA_64x96x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8605,6 +8729,7 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8636,7 +8761,7 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8649,7 +8774,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x8_F32TF32TF32_SS_TN +struct MMA_64x112x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8676,6 +8801,7 @@ struct SM90_64x112x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8710,7 +8836,7 @@ struct SM90_64x112x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8724,7 +8850,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x8_F32TF32TF32_RS_TN +struct MMA_64x112x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8751,6 +8877,7 @@ struct SM90_64x112x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8785,7 +8912,7 @@ struct SM90_64x112x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8798,7 +8925,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x8_F32TF32TF32_SS_TN +struct MMA_64x128x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8827,6 +8954,7 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8864,7 +8992,7 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8876,7 +9004,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x8_F32TF32TF32_RS_TN +struct MMA_64x128x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8905,6 +9033,7 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8942,7 +9071,7 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8955,7 +9084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x8_F32TF32TF32_SS_TN +struct MMA_64x144x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8986,6 +9115,7 @@ struct SM90_64x144x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9026,7 +9156,7 @@ struct SM90_64x144x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9040,7 +9170,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x8_F32TF32TF32_RS_TN +struct MMA_64x144x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9071,6 +9201,7 @@ struct SM90_64x144x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9111,7 +9242,7 @@ struct SM90_64x144x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9125,7 +9256,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x8_F32TF32TF32_SS_TN +struct MMA_64x160x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9158,6 +9289,7 @@ struct SM90_64x160x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9201,7 +9333,7 @@ struct SM90_64x160x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9215,7 +9347,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x8_F32TF32TF32_RS_TN +struct MMA_64x160x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9248,6 +9380,7 @@ struct SM90_64x160x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9291,7 +9424,7 @@ struct SM90_64x160x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9305,7 +9438,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x8_F32TF32TF32_SS_TN +struct MMA_64x176x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9340,6 +9473,7 @@ struct SM90_64x176x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9386,7 +9520,7 @@ struct SM90_64x176x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9400,7 +9534,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x8_F32TF32TF32_RS_TN +struct MMA_64x176x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9435,6 +9569,7 @@ struct SM90_64x176x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9481,7 +9616,7 @@ struct SM90_64x176x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9494,7 +9629,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x8_F32TF32TF32_SS_TN +struct MMA_64x192x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9531,6 +9666,7 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9580,7 +9716,7 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9592,7 +9728,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x8_F32TF32TF32_RS_TN +struct MMA_64x192x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9629,6 +9765,7 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9678,7 +9815,7 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9691,7 +9828,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x8_F32TF32TF32_SS_TN +struct MMA_64x208x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9730,6 +9867,7 @@ struct SM90_64x208x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9782,7 +9920,7 @@ struct SM90_64x208x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9796,7 +9934,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x8_F32TF32TF32_RS_TN +struct MMA_64x208x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9835,6 +9973,7 @@ struct SM90_64x208x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9887,7 +10026,7 @@ struct SM90_64x208x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9901,7 +10040,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x8_F32TF32TF32_SS_TN +struct MMA_64x224x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9942,6 +10081,7 @@ struct SM90_64x224x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9997,7 +10137,7 @@ struct SM90_64x224x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10011,7 +10151,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x8_F32TF32TF32_RS_TN +struct MMA_64x224x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -10052,6 +10192,7 @@ struct SM90_64x224x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10107,7 +10248,7 @@ struct SM90_64x224x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10121,7 +10262,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x8_F32TF32TF32_SS_TN +struct MMA_64x240x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10164,6 +10305,7 @@ struct SM90_64x240x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10222,7 +10364,7 @@ struct SM90_64x240x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10236,7 +10378,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x8_F32TF32TF32_RS_TN +struct MMA_64x240x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -10279,6 +10421,7 @@ struct SM90_64x240x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10337,7 +10480,7 @@ struct SM90_64x240x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10350,7 +10493,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x8_F32TF32TF32_SS_TN +struct MMA_64x256x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10395,6 +10538,7 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10456,7 +10600,7 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10468,7 +10612,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x8_F32TF32TF32_RS_TN +struct MMA_64x256x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -10513,6 +10657,7 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10574,7 +10719,7 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10582,7 +10727,7 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_SS_TN +struct MMA_64x8x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10596,6 +10741,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10611,7 +10757,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10619,7 +10765,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x8x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10633,6 +10779,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10648,7 +10795,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10656,7 +10803,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_SS_TN +struct MMA_64x16x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10671,6 +10818,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10687,7 +10835,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10695,7 +10843,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x16x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10710,6 +10858,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10726,7 +10875,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10734,7 +10883,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_SS_TN +struct MMA_64x32x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10751,6 +10900,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10770,7 +10920,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10778,7 +10928,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x32x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10795,6 +10945,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10814,7 +10965,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10823,7 +10974,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_SS_TN +struct MMA_64x48x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10842,6 +10993,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10864,7 +11016,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10874,7 +11026,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x48x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10893,6 +11045,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10915,7 +11068,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10924,7 +11077,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_SS_TN +struct MMA_64x64x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10945,6 +11098,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10970,7 +11124,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10978,7 +11132,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x64x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10999,6 +11153,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11024,7 +11179,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11033,7 +11188,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_SS_TN +struct MMA_64x80x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11056,6 +11211,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11084,7 +11240,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11094,7 +11250,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x80x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11117,6 +11273,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11145,7 +11302,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11154,7 +11311,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_SS_TN +struct MMA_64x96x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11179,6 +11336,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11210,7 +11368,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11218,7 +11376,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x96x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11243,6 +11401,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11274,7 +11433,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11283,7 +11442,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_SS_TN +struct MMA_64x112x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11310,6 +11469,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11344,7 +11504,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11354,7 +11514,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x112x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11381,6 +11541,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11415,7 +11576,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11424,7 +11585,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_SS_TN +struct MMA_64x128x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11453,6 +11614,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11490,7 +11652,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11498,7 +11660,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x128x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11527,6 +11689,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11564,7 +11727,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11573,7 +11736,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_SS_TN +struct MMA_64x144x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11604,6 +11767,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11644,7 +11808,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11654,7 +11818,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x144x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11685,6 +11849,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11725,7 +11890,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11735,7 +11900,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_SS_TN +struct MMA_64x160x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11768,6 +11933,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11811,7 +11977,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11821,7 +11987,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x160x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11854,6 +12020,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11897,7 +12064,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11907,7 +12074,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_SS_TN +struct MMA_64x176x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11942,6 +12109,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11988,7 +12156,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11998,7 +12166,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x176x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12033,6 +12201,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12079,7 +12248,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12088,7 +12257,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_SS_TN +struct MMA_64x192x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12125,6 +12294,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12174,7 +12344,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12182,7 +12352,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x192x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12219,6 +12389,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12268,7 +12439,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12277,7 +12448,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_SS_TN +struct MMA_64x208x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12316,6 +12487,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12368,7 +12540,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12378,7 +12550,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x208x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12417,6 +12589,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12469,7 +12642,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12479,7 +12652,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_SS_TN +struct MMA_64x224x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12520,6 +12693,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12575,7 +12749,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12585,7 +12759,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x224x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12626,6 +12800,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12681,7 +12856,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12691,7 +12866,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_SS_TN +struct MMA_64x240x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12734,6 +12909,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12792,7 +12968,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12802,7 +12978,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x240x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12845,6 +13021,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12903,7 +13080,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12912,7 +13089,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_SS_TN +struct MMA_64x256x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12957,6 +13134,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13018,7 +13196,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13026,7 +13204,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x256x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -13071,6 +13249,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13132,7 +13311,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13140,7 +13319,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_RS_TN +struct MMA_64x8x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13154,6 +13333,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13169,7 +13349,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13177,7 +13357,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x8x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13191,6 +13371,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13206,7 +13387,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13214,7 +13395,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_RS_TN +struct MMA_64x16x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13229,6 +13410,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13245,7 +13427,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13253,7 +13435,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x16x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13268,6 +13450,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13284,7 +13467,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13292,7 +13475,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_RS_TN +struct MMA_64x32x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13309,6 +13492,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13328,7 +13512,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13336,7 +13520,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x32x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13353,6 +13537,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13372,7 +13557,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13381,7 +13566,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_RS_TN +struct MMA_64x48x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13400,6 +13585,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13422,7 +13608,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13432,7 +13618,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x48x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13451,6 +13637,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13473,7 +13660,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13482,7 +13669,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_RS_TN +struct MMA_64x64x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13503,6 +13690,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13528,7 +13716,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13536,7 +13724,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x64x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13557,6 +13745,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13582,7 +13771,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13591,7 +13780,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_RS_TN +struct MMA_64x80x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13614,6 +13803,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13642,7 +13832,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13652,7 +13842,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x80x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13675,6 +13865,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13703,7 +13894,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13712,7 +13903,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_RS_TN +struct MMA_64x96x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13737,6 +13928,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13768,7 +13960,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13776,7 +13968,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x96x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13801,6 +13993,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13832,7 +14025,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13841,7 +14034,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_RS_TN +struct MMA_64x112x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13868,6 +14061,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13902,7 +14096,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13912,7 +14106,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x112x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13939,6 +14133,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13973,7 +14168,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13982,7 +14177,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_RS_TN +struct MMA_64x128x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14011,6 +14206,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14048,7 +14244,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14056,7 +14252,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x128x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14085,6 +14281,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14122,7 +14319,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14131,7 +14328,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_RS_TN +struct MMA_64x144x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14162,6 +14359,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14202,7 +14400,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14212,7 +14410,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x144x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14243,6 +14441,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14283,7 +14482,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14293,7 +14492,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_RS_TN +struct MMA_64x160x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14326,6 +14525,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14369,7 +14569,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14379,7 +14579,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x160x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14412,6 +14612,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14455,7 +14656,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14465,7 +14666,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_RS_TN +struct MMA_64x176x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14500,6 +14701,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14546,7 +14748,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14556,7 +14758,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x176x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14591,6 +14793,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14637,7 +14840,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14646,7 +14849,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_RS_TN +struct MMA_64x192x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14683,6 +14886,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14732,7 +14936,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14740,7 +14944,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x192x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14777,6 +14981,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14826,7 +15031,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14835,7 +15040,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_RS_TN +struct MMA_64x208x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14874,6 +15079,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14926,7 +15132,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14936,7 +15142,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x208x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14975,6 +15181,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15027,7 +15234,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15037,7 +15244,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_RS_TN +struct MMA_64x224x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15078,6 +15285,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15133,7 +15341,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15143,7 +15351,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x224x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15184,6 +15392,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15239,7 +15448,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15249,7 +15458,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_RS_TN +struct MMA_64x240x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15292,6 +15501,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15350,7 +15560,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15360,7 +15570,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x240x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15403,6 +15613,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15461,7 +15672,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15470,7 +15681,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_RS_TN +struct MMA_64x256x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15515,6 +15726,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15576,7 +15788,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15584,7 +15796,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x256x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15629,6 +15841,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15690,7 +15903,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15698,7 +15911,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_SS_TN +struct MMA_64x8x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15712,6 +15925,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15727,7 +15941,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15735,7 +15949,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x8x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15749,6 +15963,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15764,7 +15979,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15772,7 +15987,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_SS_TN +struct MMA_64x16x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15787,6 +16002,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15803,7 +16019,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15811,7 +16027,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x16x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15826,6 +16042,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15842,7 +16059,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15850,7 +16067,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_SS_TN +struct MMA_64x32x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15867,6 +16084,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15886,7 +16104,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15894,7 +16112,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x32x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15911,6 +16129,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15930,7 +16149,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15939,7 +16158,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_SS_TN +struct MMA_64x48x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15958,6 +16177,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15980,7 +16200,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15990,7 +16210,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x48x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16009,6 +16229,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16031,7 +16252,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16040,7 +16261,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_SS_TN +struct MMA_64x64x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16061,6 +16282,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16086,7 +16308,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16094,7 +16316,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x64x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16115,6 +16337,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16140,7 +16363,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16149,7 +16372,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_SS_TN +struct MMA_64x80x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16172,6 +16395,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16200,7 +16424,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16210,7 +16434,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x80x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16233,6 +16457,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16261,7 +16486,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16270,7 +16495,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_SS_TN +struct MMA_64x96x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16295,6 +16520,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16326,7 +16552,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16334,7 +16560,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x96x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16359,6 +16585,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16390,7 +16617,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16399,7 +16626,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_SS_TN +struct MMA_64x112x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16426,6 +16653,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16460,7 +16688,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16470,7 +16698,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x112x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16497,6 +16725,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16531,7 +16760,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16540,7 +16769,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_SS_TN +struct MMA_64x128x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16569,6 +16798,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16606,7 +16836,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16614,7 +16844,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x128x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16643,6 +16873,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16680,7 +16911,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16689,7 +16920,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_SS_TN +struct MMA_64x144x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16720,6 +16951,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16760,7 +16992,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16770,7 +17002,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x144x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16801,6 +17033,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16841,7 +17074,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16851,7 +17084,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_SS_TN +struct MMA_64x160x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16884,6 +17117,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16927,7 +17161,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16937,7 +17171,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x160x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16970,6 +17204,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17013,7 +17248,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17023,7 +17258,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_SS_TN +struct MMA_64x176x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17058,6 +17293,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17104,7 +17340,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17114,7 +17350,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x176x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17149,6 +17385,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17195,7 +17432,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17204,7 +17441,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_SS_TN +struct MMA_64x192x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17241,6 +17478,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17290,7 +17528,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17298,7 +17536,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x192x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17335,6 +17573,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17384,7 +17623,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17393,7 +17632,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_SS_TN +struct MMA_64x208x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17432,6 +17671,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17484,7 +17724,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17494,7 +17734,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x208x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17533,6 +17773,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17585,7 +17826,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17595,7 +17836,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_SS_TN +struct MMA_64x224x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17636,6 +17877,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17691,7 +17933,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17701,7 +17943,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x224x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17742,6 +17984,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17797,7 +18040,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17807,7 +18050,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_SS_TN +struct MMA_64x240x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17850,6 +18093,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17908,7 +18152,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17918,7 +18162,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x240x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17961,6 +18205,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18019,7 +18264,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18028,7 +18273,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_SS_TN +struct MMA_64x256x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -18073,6 +18318,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18134,7 +18380,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18142,7 +18388,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x256x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -18187,6 +18433,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18248,7 +18495,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18256,7 +18503,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_RS_TN +struct MMA_64x8x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18270,6 +18517,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18285,7 +18533,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18293,7 +18541,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x8x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18307,6 +18555,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18322,7 +18571,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18330,7 +18579,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_RS_TN +struct MMA_64x16x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18345,6 +18594,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18361,7 +18611,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18369,7 +18619,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x16x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18384,6 +18634,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18400,7 +18651,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18408,7 +18659,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_RS_TN +struct MMA_64x32x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18425,6 +18676,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18444,7 +18696,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18452,7 +18704,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x32x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18469,6 +18721,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18488,7 +18741,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18497,7 +18750,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_RS_TN +struct MMA_64x48x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18516,6 +18769,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18538,7 +18792,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18548,7 +18802,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x48x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18567,6 +18821,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18589,7 +18844,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18598,7 +18853,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_RS_TN +struct MMA_64x64x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18619,6 +18874,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18644,7 +18900,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18652,7 +18908,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x64x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18673,6 +18929,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18698,7 +18955,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18707,7 +18964,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_RS_TN +struct MMA_64x80x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18730,6 +18987,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18758,7 +19016,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18768,7 +19026,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x80x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18791,6 +19049,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18819,7 +19078,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18828,7 +19087,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_RS_TN +struct MMA_64x96x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18853,6 +19112,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18884,7 +19144,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18892,7 +19152,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x96x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18917,6 +19177,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18948,7 +19209,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18957,7 +19218,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_RS_TN +struct MMA_64x112x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18984,6 +19245,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19018,7 +19280,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19028,7 +19290,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x112x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19055,6 +19317,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19089,7 +19352,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19098,7 +19361,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_RS_TN +struct MMA_64x128x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19127,6 +19390,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19164,7 +19428,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19172,7 +19436,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x128x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19201,6 +19465,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19238,7 +19503,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19247,7 +19512,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_RS_TN +struct MMA_64x144x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19278,6 +19543,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19318,7 +19584,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19328,7 +19594,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x144x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19359,6 +19625,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19399,7 +19666,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19409,7 +19676,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_RS_TN +struct MMA_64x160x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19442,6 +19709,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19485,7 +19753,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19495,7 +19763,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x160x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19528,6 +19796,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19571,7 +19840,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19581,7 +19850,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_RS_TN +struct MMA_64x176x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19616,6 +19885,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19662,7 +19932,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19672,7 +19942,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x176x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19707,6 +19977,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19753,7 +20024,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19762,7 +20033,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_RS_TN +struct MMA_64x192x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19799,6 +20070,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19848,7 +20120,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19856,7 +20128,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x192x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19893,6 +20165,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19942,7 +20215,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19951,7 +20224,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_RS_TN +struct MMA_64x208x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19990,6 +20263,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20042,7 +20316,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20052,7 +20326,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x208x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20091,6 +20365,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20143,7 +20418,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20153,7 +20428,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_RS_TN +struct MMA_64x224x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20194,6 +20469,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20249,7 +20525,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20259,7 +20535,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x224x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20300,6 +20576,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20355,7 +20632,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20365,7 +20642,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_RS_TN +struct MMA_64x240x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20408,6 +20685,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20466,7 +20744,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20476,7 +20754,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x240x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20519,6 +20797,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20577,7 +20856,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20586,7 +20865,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_RS_TN +struct MMA_64x256x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20631,6 +20910,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20692,7 +20972,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20700,7 +20980,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x256x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20745,6 +21025,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20806,7 +21087,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20814,7 +21095,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_SS_TN +struct MMA_64x8x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20828,6 +21109,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20843,7 +21125,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20851,7 +21133,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x8x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20865,6 +21147,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20880,7 +21163,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20888,7 +21171,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_SS_TN +struct MMA_64x16x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20903,6 +21186,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20919,7 +21203,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20927,7 +21211,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x16x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20942,6 +21226,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20958,7 +21243,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20966,7 +21251,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_SS_TN +struct MMA_64x32x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20983,6 +21268,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21002,7 +21288,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21010,7 +21296,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x32x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21027,6 +21313,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21046,7 +21333,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21055,7 +21342,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_SS_TN +struct MMA_64x48x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21074,6 +21361,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21096,7 +21384,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21106,7 +21394,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x48x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21125,6 +21413,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21147,7 +21436,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21156,7 +21445,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_SS_TN +struct MMA_64x64x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21177,6 +21466,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21202,7 +21492,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21210,7 +21500,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x64x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21231,6 +21521,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21256,7 +21547,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21265,7 +21556,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_SS_TN +struct MMA_64x80x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21288,6 +21579,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21316,7 +21608,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21326,7 +21618,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x80x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21349,6 +21641,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21377,7 +21670,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21386,7 +21679,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct SM90_64x96x32_S32U8S8_SS_TN +struct MMA_64x96x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21411,6 +21704,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21442,7 +21736,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21450,7 +21744,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x96x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21475,6 +21769,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21506,7 +21801,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21515,7 +21810,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_SS_TN +struct MMA_64x112x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21542,6 +21837,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21576,7 +21872,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21586,7 +21882,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x112x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21613,6 +21909,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21647,7 +21944,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21656,7 +21953,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_SS_TN +struct MMA_64x128x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21685,6 +21982,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21722,7 +22020,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21730,7 +22028,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x128x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21759,6 +22057,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21796,7 +22095,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21805,7 +22104,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_SS_TN +struct MMA_64x144x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21836,6 +22135,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21876,7 +22176,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21886,7 +22186,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x144x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21917,6 +22217,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21957,7 +22258,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21967,7 +22268,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_SS_TN +struct MMA_64x160x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22000,6 +22301,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22043,7 +22345,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22053,7 +22355,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x160x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22086,6 +22388,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22129,7 +22432,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22139,7 +22442,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_SS_TN +struct MMA_64x176x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22174,6 +22477,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22220,7 +22524,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22230,7 +22534,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x176x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22265,6 +22569,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22311,7 +22616,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22320,7 +22625,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_SS_TN +struct MMA_64x192x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22357,6 +22662,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22406,7 +22712,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22414,7 +22720,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x192x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22451,6 +22757,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22500,7 +22807,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22509,7 +22816,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_SS_TN +struct MMA_64x208x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22548,6 +22855,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22600,7 +22908,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22610,7 +22918,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x208x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22649,6 +22957,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22701,7 +23010,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22711,7 +23020,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_SS_TN +struct MMA_64x224x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22752,6 +23061,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22807,7 +23117,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22817,7 +23127,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x224x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22858,6 +23168,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22913,7 +23224,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22923,7 +23234,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_SS_TN +struct MMA_64x240x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22966,6 +23277,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23024,7 +23336,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23034,7 +23346,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x240x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -23077,6 +23389,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23135,7 +23448,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23144,7 +23457,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_SS_TN +struct MMA_64x256x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -23189,6 +23502,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23250,7 +23564,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23258,7 +23572,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x256x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -23303,6 +23617,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23364,7 +23679,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23372,7 +23687,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_RS_TN +struct MMA_64x8x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23386,6 +23701,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23401,7 +23717,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23409,7 +23725,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x8x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23423,6 +23739,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23438,7 +23755,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23446,7 +23763,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_RS_TN +struct MMA_64x16x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23461,6 +23778,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23477,7 +23795,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23485,7 +23803,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x16x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23500,6 +23818,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23516,7 +23835,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23524,7 +23843,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_RS_TN +struct MMA_64x32x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23541,6 +23860,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23560,7 +23880,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23568,7 +23888,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x32x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23585,6 +23905,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23604,7 +23925,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23613,7 +23934,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_RS_TN +struct MMA_64x48x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23632,6 +23953,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23654,7 +23976,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23664,7 +23986,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x48x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23683,6 +24005,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23705,7 +24028,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23714,7 +24037,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_RS_TN +struct MMA_64x64x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23735,6 +24058,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23760,7 +24084,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23768,7 +24092,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x64x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23789,6 +24113,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23814,7 +24139,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23823,7 +24148,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_RS_TN +struct MMA_64x80x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23846,6 +24171,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23874,7 +24200,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23884,7 +24210,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x80x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23907,6 +24233,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23935,7 +24262,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23944,7 +24271,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct SM90_64x96x32_S32U8S8_RS_TN +struct MMA_64x96x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23969,6 +24296,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24000,7 +24328,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24008,7 +24336,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x96x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24033,6 +24361,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24064,7 +24393,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24073,7 +24402,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_RS_TN +struct MMA_64x112x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24100,6 +24429,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24134,7 +24464,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24144,7 +24474,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x112x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24171,6 +24501,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24205,7 +24536,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24214,7 +24545,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_RS_TN +struct MMA_64x128x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24243,6 +24574,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24280,7 +24612,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24288,7 +24620,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x128x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24317,6 +24649,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24354,7 +24687,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24363,7 +24696,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_RS_TN +struct MMA_64x144x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24394,6 +24727,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24434,7 +24768,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24444,7 +24778,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x144x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24475,6 +24809,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24515,7 +24850,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24525,7 +24860,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_RS_TN +struct MMA_64x160x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24558,6 +24893,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24601,7 +24937,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24611,7 +24947,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x160x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24644,6 +24980,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24687,7 +25024,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24697,7 +25034,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_RS_TN +struct MMA_64x176x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24732,6 +25069,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24778,7 +25116,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24788,7 +25126,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x176x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24823,6 +25161,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24869,7 +25208,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24878,7 +25217,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_RS_TN +struct MMA_64x192x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24915,6 +25254,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24964,7 +25304,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24972,7 +25312,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x192x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25009,6 +25349,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25058,7 +25399,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25067,7 +25408,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_RS_TN +struct MMA_64x208x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25106,6 +25447,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25158,7 +25500,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25168,7 +25510,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x208x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25207,6 +25549,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25259,7 +25602,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25269,7 +25612,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_RS_TN +struct MMA_64x224x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25310,6 +25653,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25365,7 +25709,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25375,7 +25719,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x224x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25416,6 +25760,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25471,7 +25816,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25481,7 +25826,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_RS_TN +struct MMA_64x240x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25524,6 +25869,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25582,7 +25928,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25592,7 +25938,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x240x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25635,6 +25981,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25693,7 +26040,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25702,7 +26049,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_RS_TN +struct MMA_64x256x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25747,6 +26094,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25808,7 +26156,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25816,7 +26164,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x256x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25861,6 +26209,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25922,7 +26271,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25930,7 +26279,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_SS_TN +struct MMA_64x8x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -25944,6 +26293,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25959,7 +26309,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25967,7 +26317,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x8x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -25981,6 +26331,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25996,7 +26347,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26004,7 +26355,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_SS_TN +struct MMA_64x16x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26019,6 +26370,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26035,7 +26387,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26043,7 +26395,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x16x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26058,6 +26410,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26074,7 +26427,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26082,7 +26435,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_SS_TN +struct MMA_64x32x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26099,6 +26452,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26118,7 +26472,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26126,7 +26480,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x32x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26143,6 +26497,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26162,7 +26517,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26171,7 +26526,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_SS_TN +struct MMA_64x48x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26190,6 +26545,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26212,7 +26568,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26222,7 +26578,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x48x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26241,6 +26597,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26263,7 +26620,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26272,7 +26629,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_SS_TN +struct MMA_64x64x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26293,6 +26650,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26318,7 +26676,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26326,7 +26684,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x64x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26347,6 +26705,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26372,7 +26731,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26381,7 +26740,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_SS_TN +struct MMA_64x80x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26404,6 +26763,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26432,7 +26792,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26442,7 +26802,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x80x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26465,6 +26825,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26493,7 +26854,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26502,7 +26863,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_SS_TN +struct MMA_64x96x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26527,6 +26888,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26558,7 +26920,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26566,7 +26928,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x96x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26591,6 +26953,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26622,7 +26985,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26631,7 +26994,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_SS_TN +struct MMA_64x112x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26658,6 +27021,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26692,7 +27056,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26702,7 +27066,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x112x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26729,6 +27093,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26763,7 +27128,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26772,7 +27137,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_SS_TN +struct MMA_64x128x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26801,6 +27166,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26838,7 +27204,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26846,7 +27212,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x128x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26875,6 +27241,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26912,7 +27279,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26921,7 +27288,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_SS_TN +struct MMA_64x144x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26952,6 +27319,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26992,7 +27360,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27002,7 +27370,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x144x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27033,6 +27401,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27073,7 +27442,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27083,7 +27452,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_SS_TN +struct MMA_64x160x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27116,6 +27485,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27159,7 +27529,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27169,7 +27539,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x160x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27202,6 +27572,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27245,7 +27616,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27255,7 +27626,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_SS_TN +struct MMA_64x176x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27290,6 +27661,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27336,7 +27708,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27346,7 +27718,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x176x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27381,6 +27753,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27427,7 +27800,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27436,7 +27809,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_SS_TN +struct MMA_64x192x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27473,6 +27846,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27522,7 +27896,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27530,7 +27904,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x192x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27567,6 +27941,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27616,7 +27991,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27625,7 +28000,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_SS_TN +struct MMA_64x208x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27664,6 +28039,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27716,7 +28092,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27726,7 +28102,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x208x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27765,6 +28141,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27817,7 +28194,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27827,7 +28204,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_SS_TN +struct MMA_64x224x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27868,6 +28245,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27923,7 +28301,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27933,7 +28311,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x224x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27974,6 +28352,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28029,7 +28408,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28039,7 +28418,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_SS_TN +struct MMA_64x240x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28082,6 +28461,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28140,7 +28520,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28150,7 +28530,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x240x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28193,6 +28573,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28251,7 +28632,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28260,7 +28641,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_SS_TN +struct MMA_64x256x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28305,6 +28686,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28366,7 +28748,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28374,7 +28756,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x256x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28419,6 +28801,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28480,7 +28863,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28488,7 +28871,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_RS_TN +struct MMA_64x8x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28502,6 +28885,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28517,7 +28901,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28525,7 +28909,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x8x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28539,6 +28923,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28554,7 +28939,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28562,7 +28947,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_RS_TN +struct MMA_64x16x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28577,6 +28962,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28593,7 +28979,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28601,7 +28987,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x16x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28616,6 +29002,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28632,7 +29019,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28640,7 +29027,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_RS_TN +struct MMA_64x32x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28657,6 +29044,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28676,7 +29064,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28684,7 +29072,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x32x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28701,6 +29089,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28720,7 +29109,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28729,7 +29118,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_RS_TN +struct MMA_64x48x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28748,6 +29137,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28770,7 +29160,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28780,7 +29170,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x48x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28799,6 +29189,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28821,7 +29212,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28830,7 +29221,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_RS_TN +struct MMA_64x64x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28851,6 +29242,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28876,7 +29268,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28884,7 +29276,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x64x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28905,6 +29297,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28930,7 +29323,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28939,7 +29332,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_RS_TN +struct MMA_64x80x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28962,6 +29355,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28990,7 +29384,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29000,7 +29394,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x80x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29023,6 +29417,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29051,7 +29446,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29060,7 +29455,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_RS_TN +struct MMA_64x96x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29085,6 +29480,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29116,7 +29512,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29124,7 +29520,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x96x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29149,6 +29545,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29180,7 +29577,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29189,7 +29586,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_RS_TN +struct MMA_64x112x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29216,6 +29613,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29250,7 +29648,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29260,7 +29658,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x112x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29287,6 +29685,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29321,7 +29720,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29330,7 +29729,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_RS_TN +struct MMA_64x128x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29359,6 +29758,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29396,7 +29796,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29404,7 +29804,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x128x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29433,6 +29833,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29470,7 +29871,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29479,7 +29880,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_RS_TN +struct MMA_64x144x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29510,6 +29911,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29550,7 +29952,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29560,7 +29962,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x144x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29591,6 +29993,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29631,7 +30034,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29641,7 +30044,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_RS_TN +struct MMA_64x160x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29674,6 +30077,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29717,7 +30121,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29727,7 +30131,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x160x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29760,6 +30164,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29803,7 +30208,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29813,7 +30218,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_RS_TN +struct MMA_64x176x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29848,6 +30253,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29894,7 +30300,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29904,7 +30310,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x176x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29939,6 +30345,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29985,7 +30392,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29994,7 +30401,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_RS_TN +struct MMA_64x192x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30031,6 +30438,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30080,7 +30488,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30088,7 +30496,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x192x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30125,6 +30533,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30174,7 +30583,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30183,7 +30592,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_RS_TN +struct MMA_64x208x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30222,6 +30631,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30274,7 +30684,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30284,7 +30694,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x208x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30323,6 +30733,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30375,7 +30786,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30385,7 +30796,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_RS_TN +struct MMA_64x224x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30426,6 +30837,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30481,7 +30893,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30491,7 +30903,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x224x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30532,6 +30944,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30587,7 +31000,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30597,7 +31010,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_RS_TN +struct MMA_64x240x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30640,6 +31053,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30698,7 +31112,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30708,7 +31122,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x240x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30751,6 +31165,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30809,7 +31224,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30818,7 +31233,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_RS_TN +struct MMA_64x256x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30863,6 +31278,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30924,7 +31340,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30932,7 +31348,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x256x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30977,6 +31393,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31038,7 +31455,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31050,7 +31467,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E4M3_SS_TN +struct MMA_64x8x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31064,6 +31481,7 @@ struct SM90_64x8x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31079,7 +31497,7 @@ struct SM90_64x8x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31091,7 +31509,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E4M3_RS_TN +struct MMA_64x8x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31105,6 +31523,7 @@ struct SM90_64x8x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31120,7 +31539,7 @@ struct SM90_64x8x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31132,7 +31551,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E4M3_SS_TN +struct MMA_64x8x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31146,6 +31565,7 @@ struct SM90_64x8x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31161,7 +31581,7 @@ struct SM90_64x8x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31173,7 +31593,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E4M3_RS_TN +struct MMA_64x8x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31187,6 +31607,7 @@ struct SM90_64x8x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31202,7 +31623,7 @@ struct SM90_64x8x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31214,7 +31635,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E4M3_SS_TN +struct MMA_64x16x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31228,6 +31649,7 @@ struct SM90_64x16x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31243,7 +31665,7 @@ struct SM90_64x16x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31255,7 +31677,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E4M3_RS_TN +struct MMA_64x16x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31269,6 +31691,7 @@ struct SM90_64x16x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31284,7 +31707,7 @@ struct SM90_64x16x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31296,7 +31719,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E4M3_SS_TN +struct MMA_64x16x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31311,6 +31734,7 @@ struct SM90_64x16x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31327,7 +31751,7 @@ struct SM90_64x16x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31339,7 +31763,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E4M3_RS_TN +struct MMA_64x16x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31354,6 +31778,7 @@ struct SM90_64x16x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31370,7 +31795,7 @@ struct SM90_64x16x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31382,7 +31807,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E4M3_SS_TN +struct MMA_64x32x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31397,6 +31822,7 @@ struct SM90_64x32x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31413,7 +31839,7 @@ struct SM90_64x32x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31425,7 +31851,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E4M3_RS_TN +struct MMA_64x32x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31440,6 +31866,7 @@ struct SM90_64x32x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31456,7 +31883,7 @@ struct SM90_64x32x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31468,7 +31895,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E4M3_SS_TN +struct MMA_64x32x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31485,6 +31912,7 @@ struct SM90_64x32x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31504,7 +31932,7 @@ struct SM90_64x32x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31516,7 +31944,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E4M3_RS_TN +struct MMA_64x32x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31533,6 +31961,7 @@ struct SM90_64x32x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31552,7 +31981,7 @@ struct SM90_64x32x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31565,7 +31994,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E4M3_SS_TN +struct MMA_64x48x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31581,6 +32010,7 @@ struct SM90_64x48x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31599,7 +32029,7 @@ struct SM90_64x48x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31613,7 +32043,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E4M3_RS_TN +struct MMA_64x48x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31629,6 +32059,7 @@ struct SM90_64x48x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31647,7 +32078,7 @@ struct SM90_64x48x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31661,7 +32092,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E4M3_SS_TN +struct MMA_64x48x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31680,6 +32111,7 @@ struct SM90_64x48x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31702,7 +32134,7 @@ struct SM90_64x48x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31716,7 +32148,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E4M3_RS_TN +struct MMA_64x48x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31735,6 +32167,7 @@ struct SM90_64x48x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31757,7 +32190,7 @@ struct SM90_64x48x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31770,7 +32203,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E4M3_SS_TN +struct MMA_64x64x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31787,6 +32220,7 @@ struct SM90_64x64x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31806,7 +32240,7 @@ struct SM90_64x64x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31818,7 +32252,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E4M3_RS_TN +struct MMA_64x64x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31835,6 +32269,7 @@ struct SM90_64x64x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31854,7 +32289,7 @@ struct SM90_64x64x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31866,7 +32301,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E4M3_SS_TN +struct MMA_64x64x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31887,6 +32322,7 @@ struct SM90_64x64x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31912,7 +32348,7 @@ struct SM90_64x64x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31924,7 +32360,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E4M3_RS_TN +struct MMA_64x64x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31945,6 +32381,7 @@ struct SM90_64x64x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31970,7 +32407,7 @@ struct SM90_64x64x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31983,7 +32420,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E4M3_SS_TN +struct MMA_64x80x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32001,6 +32438,7 @@ struct SM90_64x80x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32022,7 +32460,7 @@ struct SM90_64x80x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32036,7 +32474,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E4M3_RS_TN +struct MMA_64x80x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32054,6 +32492,7 @@ struct SM90_64x80x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32075,7 +32514,7 @@ struct SM90_64x80x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32089,7 +32528,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E4M3_SS_TN +struct MMA_64x80x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32112,6 +32551,7 @@ struct SM90_64x80x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32140,7 +32580,7 @@ struct SM90_64x80x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32154,7 +32594,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E4M3_RS_TN +struct MMA_64x80x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32177,6 +32617,7 @@ struct SM90_64x80x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32205,7 +32646,7 @@ struct SM90_64x80x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32218,7 +32659,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E4M3_SS_TN +struct MMA_64x96x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32237,6 +32678,7 @@ struct SM90_64x96x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32259,7 +32701,7 @@ struct SM90_64x96x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32271,7 +32713,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E4M3_RS_TN +struct MMA_64x96x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32290,6 +32732,7 @@ struct SM90_64x96x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32312,7 +32755,7 @@ struct SM90_64x96x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32324,7 +32767,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E4M3_SS_TN +struct MMA_64x96x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32349,6 +32792,7 @@ struct SM90_64x96x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32380,7 +32824,7 @@ struct SM90_64x96x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32392,7 +32836,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E4M3_RS_TN +struct MMA_64x96x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32417,6 +32861,7 @@ struct SM90_64x96x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32448,7 +32893,7 @@ struct SM90_64x96x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32461,7 +32906,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E4M3_SS_TN +struct MMA_64x112x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32481,6 +32926,7 @@ struct SM90_64x112x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32505,7 +32951,7 @@ struct SM90_64x112x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32519,7 +32965,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E4M3_RS_TN +struct MMA_64x112x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32539,6 +32985,7 @@ struct SM90_64x112x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32563,7 +33010,7 @@ struct SM90_64x112x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32577,7 +33024,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E4M3_SS_TN +struct MMA_64x112x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32604,6 +33051,7 @@ struct SM90_64x112x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32638,7 +33086,7 @@ struct SM90_64x112x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32652,7 +33100,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E4M3_RS_TN +struct MMA_64x112x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32679,6 +33127,7 @@ struct SM90_64x112x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32713,7 +33162,7 @@ struct SM90_64x112x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32726,7 +33175,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E4M3_SS_TN +struct MMA_64x128x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32747,6 +33196,7 @@ struct SM90_64x128x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32772,7 +33222,7 @@ struct SM90_64x128x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32784,7 +33234,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E4M3_RS_TN +struct MMA_64x128x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32805,6 +33255,7 @@ struct SM90_64x128x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32830,7 +33281,7 @@ struct SM90_64x128x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32842,7 +33293,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E4M3_SS_TN +struct MMA_64x128x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32871,6 +33322,7 @@ struct SM90_64x128x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32908,7 +33360,7 @@ struct SM90_64x128x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32920,7 +33372,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E4M3_RS_TN +struct MMA_64x128x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32949,6 +33401,7 @@ struct SM90_64x128x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32986,7 +33439,7 @@ struct SM90_64x128x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32999,7 +33452,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E4M3_SS_TN +struct MMA_64x144x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33021,6 +33474,7 @@ struct SM90_64x144x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33048,7 +33502,7 @@ struct SM90_64x144x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33062,7 +33516,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E4M3_RS_TN +struct MMA_64x144x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33084,6 +33538,7 @@ struct SM90_64x144x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33111,7 +33566,7 @@ struct SM90_64x144x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33125,7 +33580,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E4M3_SS_TN +struct MMA_64x144x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33156,6 +33611,7 @@ struct SM90_64x144x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33196,7 +33652,7 @@ struct SM90_64x144x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33210,7 +33666,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E4M3_RS_TN +struct MMA_64x144x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33241,6 +33697,7 @@ struct SM90_64x144x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33281,7 +33738,7 @@ struct SM90_64x144x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33295,7 +33752,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E4M3_SS_TN +struct MMA_64x160x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33318,6 +33775,7 @@ struct SM90_64x160x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33346,7 +33804,7 @@ struct SM90_64x160x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33360,7 +33818,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E4M3_RS_TN +struct MMA_64x160x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33383,6 +33841,7 @@ struct SM90_64x160x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33411,7 +33870,7 @@ struct SM90_64x160x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33425,7 +33884,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E4M3_SS_TN +struct MMA_64x160x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33458,6 +33917,7 @@ struct SM90_64x160x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33501,7 +33961,7 @@ struct SM90_64x160x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33515,7 +33975,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E4M3_RS_TN +struct MMA_64x160x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33548,6 +34008,7 @@ struct SM90_64x160x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33591,7 +34052,7 @@ struct SM90_64x160x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33605,7 +34066,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E4M3_SS_TN +struct MMA_64x176x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33629,6 +34090,7 @@ struct SM90_64x176x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33659,7 +34121,7 @@ struct SM90_64x176x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33673,7 +34135,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E4M3_RS_TN +struct MMA_64x176x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33697,6 +34159,7 @@ struct SM90_64x176x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33727,7 +34190,7 @@ struct SM90_64x176x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33741,7 +34204,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E4M3_SS_TN +struct MMA_64x176x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33776,6 +34239,7 @@ struct SM90_64x176x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33822,7 +34286,7 @@ struct SM90_64x176x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33836,7 +34300,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E4M3_RS_TN +struct MMA_64x176x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33871,6 +34335,7 @@ struct SM90_64x176x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33917,7 +34382,7 @@ struct SM90_64x176x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33930,7 +34395,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E4M3_SS_TN +struct MMA_64x192x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33955,6 +34420,7 @@ struct SM90_64x192x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33986,7 +34452,7 @@ struct SM90_64x192x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33998,7 +34464,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E4M3_RS_TN +struct MMA_64x192x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34023,6 +34489,7 @@ struct SM90_64x192x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34054,7 +34521,7 @@ struct SM90_64x192x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34066,7 +34533,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E4M3_SS_TN +struct MMA_64x192x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34103,6 +34570,7 @@ struct SM90_64x192x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34152,7 +34620,7 @@ struct SM90_64x192x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34164,7 +34632,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E4M3_RS_TN +struct MMA_64x192x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34201,6 +34669,7 @@ struct SM90_64x192x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34250,7 +34719,7 @@ struct SM90_64x192x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34263,7 +34732,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E4M3_SS_TN +struct MMA_64x208x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34289,6 +34758,7 @@ struct SM90_64x208x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34322,7 +34792,7 @@ struct SM90_64x208x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34336,7 +34806,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E4M3_RS_TN +struct MMA_64x208x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34362,6 +34832,7 @@ struct SM90_64x208x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34395,7 +34866,7 @@ struct SM90_64x208x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34409,7 +34880,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E4M3_SS_TN +struct MMA_64x208x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34448,6 +34919,7 @@ struct SM90_64x208x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34500,7 +34972,7 @@ struct SM90_64x208x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34514,7 +34986,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E4M3_RS_TN +struct MMA_64x208x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34553,6 +35025,7 @@ struct SM90_64x208x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34605,7 +35078,7 @@ struct SM90_64x208x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34619,7 +35092,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E4M3_SS_TN +struct MMA_64x224x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34646,6 +35119,7 @@ struct SM90_64x224x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34680,7 +35154,7 @@ struct SM90_64x224x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34694,7 +35168,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E4M3_RS_TN +struct MMA_64x224x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34721,6 +35195,7 @@ struct SM90_64x224x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34755,7 +35230,7 @@ struct SM90_64x224x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34769,7 +35244,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E4M3_SS_TN +struct MMA_64x224x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34810,6 +35285,7 @@ struct SM90_64x224x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34865,7 +35341,7 @@ struct SM90_64x224x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34879,7 +35355,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E4M3_RS_TN +struct MMA_64x224x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34920,6 +35396,7 @@ struct SM90_64x224x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34975,7 +35452,7 @@ struct SM90_64x224x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34989,7 +35466,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E4M3_SS_TN +struct MMA_64x240x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35017,6 +35494,7 @@ struct SM90_64x240x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35053,7 +35531,7 @@ struct SM90_64x240x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35067,7 +35545,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E4M3_RS_TN +struct MMA_64x240x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35095,6 +35573,7 @@ struct SM90_64x240x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35131,7 +35610,7 @@ struct SM90_64x240x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35145,7 +35624,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E4M3_SS_TN +struct MMA_64x240x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35188,6 +35667,7 @@ struct SM90_64x240x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35246,7 +35726,7 @@ struct SM90_64x240x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35260,7 +35740,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E4M3_RS_TN +struct MMA_64x240x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35303,6 +35783,7 @@ struct SM90_64x240x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35361,7 +35842,7 @@ struct SM90_64x240x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35374,7 +35855,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E4M3_SS_TN +struct MMA_64x256x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35403,6 +35884,7 @@ struct SM90_64x256x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35440,7 +35922,7 @@ struct SM90_64x256x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35452,7 +35934,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E4M3_RS_TN +struct MMA_64x256x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35481,6 +35963,7 @@ struct SM90_64x256x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35518,7 +36001,7 @@ struct SM90_64x256x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35530,7 +36013,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E4M3_SS_TN +struct MMA_64x256x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35575,6 +36058,7 @@ struct SM90_64x256x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35636,7 +36120,7 @@ struct SM90_64x256x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35648,7 +36132,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E4M3_RS_TN +struct MMA_64x256x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35693,6 +36177,7 @@ struct SM90_64x256x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35754,7 +36239,7 @@ struct SM90_64x256x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35766,7 +36251,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E5M2_SS_TN +struct MMA_64x8x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35780,6 +36265,7 @@ struct SM90_64x8x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35795,7 +36281,7 @@ struct SM90_64x8x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35807,7 +36293,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E5M2_RS_TN +struct MMA_64x8x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35821,6 +36307,7 @@ struct SM90_64x8x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35836,7 +36323,7 @@ struct SM90_64x8x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35848,7 +36335,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E5M2_SS_TN +struct MMA_64x8x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35862,6 +36349,7 @@ struct SM90_64x8x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35877,7 +36365,7 @@ struct SM90_64x8x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35889,7 +36377,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E5M2_RS_TN +struct MMA_64x8x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35903,6 +36391,7 @@ struct SM90_64x8x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35918,7 +36407,7 @@ struct SM90_64x8x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35930,7 +36419,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E5M2_SS_TN +struct MMA_64x16x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35944,6 +36433,7 @@ struct SM90_64x16x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35959,7 +36449,7 @@ struct SM90_64x16x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35971,7 +36461,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E5M2_RS_TN +struct MMA_64x16x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35985,6 +36475,7 @@ struct SM90_64x16x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36000,7 +36491,7 @@ struct SM90_64x16x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36012,7 +36503,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E5M2_SS_TN +struct MMA_64x16x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36027,6 +36518,7 @@ struct SM90_64x16x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36043,7 +36535,7 @@ struct SM90_64x16x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36055,7 +36547,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E5M2_RS_TN +struct MMA_64x16x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36070,6 +36562,7 @@ struct SM90_64x16x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36086,7 +36579,7 @@ struct SM90_64x16x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36098,7 +36591,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E5M2_SS_TN +struct MMA_64x32x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36113,6 +36606,7 @@ struct SM90_64x32x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36129,7 +36623,7 @@ struct SM90_64x32x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36141,7 +36635,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E5M2_RS_TN +struct MMA_64x32x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36156,6 +36650,7 @@ struct SM90_64x32x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36172,7 +36667,7 @@ struct SM90_64x32x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36184,7 +36679,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E5M2_SS_TN +struct MMA_64x32x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36201,6 +36696,7 @@ struct SM90_64x32x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36220,7 +36716,7 @@ struct SM90_64x32x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36232,7 +36728,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E5M2_RS_TN +struct MMA_64x32x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36249,6 +36745,7 @@ struct SM90_64x32x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36268,7 +36765,7 @@ struct SM90_64x32x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36281,7 +36778,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E5M2_SS_TN +struct MMA_64x48x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36297,6 +36794,7 @@ struct SM90_64x48x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36315,7 +36813,7 @@ struct SM90_64x48x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36329,7 +36827,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E5M2_RS_TN +struct MMA_64x48x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36345,6 +36843,7 @@ struct SM90_64x48x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36363,7 +36862,7 @@ struct SM90_64x48x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36377,7 +36876,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E5M2_SS_TN +struct MMA_64x48x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36396,6 +36895,7 @@ struct SM90_64x48x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36418,7 +36918,7 @@ struct SM90_64x48x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36432,7 +36932,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E5M2_RS_TN +struct MMA_64x48x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36451,6 +36951,7 @@ struct SM90_64x48x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36473,7 +36974,7 @@ struct SM90_64x48x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36486,7 +36987,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E5M2_SS_TN +struct MMA_64x64x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36503,6 +37004,7 @@ struct SM90_64x64x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36522,7 +37024,7 @@ struct SM90_64x64x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36534,7 +37036,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E5M2_RS_TN +struct MMA_64x64x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36551,6 +37053,7 @@ struct SM90_64x64x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36570,7 +37073,7 @@ struct SM90_64x64x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36582,7 +37085,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E5M2_SS_TN +struct MMA_64x64x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36603,6 +37106,7 @@ struct SM90_64x64x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36628,7 +37132,7 @@ struct SM90_64x64x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36640,7 +37144,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E5M2_RS_TN +struct MMA_64x64x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36661,6 +37165,7 @@ struct SM90_64x64x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36686,7 +37191,7 @@ struct SM90_64x64x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36699,7 +37204,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E5M2_SS_TN +struct MMA_64x80x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36717,6 +37222,7 @@ struct SM90_64x80x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36738,7 +37244,7 @@ struct SM90_64x80x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36752,7 +37258,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E5M2_RS_TN +struct MMA_64x80x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36770,6 +37276,7 @@ struct SM90_64x80x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36791,7 +37298,7 @@ struct SM90_64x80x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36805,7 +37312,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E5M2_SS_TN +struct MMA_64x80x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36828,6 +37335,7 @@ struct SM90_64x80x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36856,7 +37364,7 @@ struct SM90_64x80x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36870,7 +37378,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E5M2_RS_TN +struct MMA_64x80x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36893,6 +37401,7 @@ struct SM90_64x80x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36921,7 +37430,7 @@ struct SM90_64x80x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36934,7 +37443,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E5M2_SS_TN +struct MMA_64x96x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36953,6 +37462,7 @@ struct SM90_64x96x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36975,7 +37485,7 @@ struct SM90_64x96x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36987,7 +37497,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E5M2_RS_TN +struct MMA_64x96x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37006,6 +37516,7 @@ struct SM90_64x96x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37028,7 +37539,7 @@ struct SM90_64x96x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37040,7 +37551,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E5M2_SS_TN +struct MMA_64x96x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37065,6 +37576,7 @@ struct SM90_64x96x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37096,7 +37608,7 @@ struct SM90_64x96x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37108,7 +37620,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E5M2_RS_TN +struct MMA_64x96x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37133,6 +37645,7 @@ struct SM90_64x96x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37164,7 +37677,7 @@ struct SM90_64x96x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37177,7 +37690,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E5M2_SS_TN +struct MMA_64x112x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37197,6 +37710,7 @@ struct SM90_64x112x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37221,7 +37735,7 @@ struct SM90_64x112x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37235,7 +37749,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E5M2_RS_TN +struct MMA_64x112x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37255,6 +37769,7 @@ struct SM90_64x112x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37279,7 +37794,7 @@ struct SM90_64x112x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37293,7 +37808,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E5M2_SS_TN +struct MMA_64x112x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37320,6 +37835,7 @@ struct SM90_64x112x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37354,7 +37870,7 @@ struct SM90_64x112x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37368,7 +37884,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E5M2_RS_TN +struct MMA_64x112x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37395,6 +37911,7 @@ struct SM90_64x112x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37429,7 +37946,7 @@ struct SM90_64x112x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37442,7 +37959,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E5M2_SS_TN +struct MMA_64x128x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37463,6 +37980,7 @@ struct SM90_64x128x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37488,7 +38006,7 @@ struct SM90_64x128x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37500,7 +38018,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E5M2_RS_TN +struct MMA_64x128x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37521,6 +38039,7 @@ struct SM90_64x128x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37546,7 +38065,7 @@ struct SM90_64x128x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37558,7 +38077,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E5M2_SS_TN +struct MMA_64x128x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37587,6 +38106,7 @@ struct SM90_64x128x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37624,7 +38144,7 @@ struct SM90_64x128x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37636,7 +38156,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E5M2_RS_TN +struct MMA_64x128x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37665,6 +38185,7 @@ struct SM90_64x128x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37702,7 +38223,7 @@ struct SM90_64x128x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37715,7 +38236,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E5M2_SS_TN +struct MMA_64x144x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37737,6 +38258,7 @@ struct SM90_64x144x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37764,7 +38286,7 @@ struct SM90_64x144x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37778,7 +38300,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E5M2_RS_TN +struct MMA_64x144x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37800,6 +38322,7 @@ struct SM90_64x144x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37827,7 +38350,7 @@ struct SM90_64x144x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37841,7 +38364,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E5M2_SS_TN +struct MMA_64x144x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37872,6 +38395,7 @@ struct SM90_64x144x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37912,7 +38436,7 @@ struct SM90_64x144x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37926,7 +38450,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E5M2_RS_TN +struct MMA_64x144x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37957,6 +38481,7 @@ struct SM90_64x144x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37997,7 +38522,7 @@ struct SM90_64x144x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38011,7 +38536,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E5M2_SS_TN +struct MMA_64x160x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38034,6 +38559,7 @@ struct SM90_64x160x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38062,7 +38588,7 @@ struct SM90_64x160x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38076,7 +38602,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E5M2_RS_TN +struct MMA_64x160x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38099,6 +38625,7 @@ struct SM90_64x160x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38127,7 +38654,7 @@ struct SM90_64x160x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38141,7 +38668,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E5M2_SS_TN +struct MMA_64x160x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38174,6 +38701,7 @@ struct SM90_64x160x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38217,7 +38745,7 @@ struct SM90_64x160x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38231,7 +38759,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E5M2_RS_TN +struct MMA_64x160x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38264,6 +38792,7 @@ struct SM90_64x160x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38307,7 +38836,7 @@ struct SM90_64x160x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38321,7 +38850,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E5M2_SS_TN +struct MMA_64x176x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38345,6 +38874,7 @@ struct SM90_64x176x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38375,7 +38905,7 @@ struct SM90_64x176x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38389,7 +38919,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E5M2_RS_TN +struct MMA_64x176x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38413,6 +38943,7 @@ struct SM90_64x176x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38443,7 +38974,7 @@ struct SM90_64x176x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38457,7 +38988,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E5M2_SS_TN +struct MMA_64x176x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38492,6 +39023,7 @@ struct SM90_64x176x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38538,7 +39070,7 @@ struct SM90_64x176x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38552,7 +39084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E5M2_RS_TN +struct MMA_64x176x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38587,6 +39119,7 @@ struct SM90_64x176x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38633,7 +39166,7 @@ struct SM90_64x176x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38646,7 +39179,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E5M2_SS_TN +struct MMA_64x192x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38671,6 +39204,7 @@ struct SM90_64x192x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38702,7 +39236,7 @@ struct SM90_64x192x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38714,7 +39248,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E5M2_RS_TN +struct MMA_64x192x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38739,6 +39273,7 @@ struct SM90_64x192x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38770,7 +39305,7 @@ struct SM90_64x192x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38782,7 +39317,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E5M2_SS_TN +struct MMA_64x192x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38819,6 +39354,7 @@ struct SM90_64x192x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38868,7 +39404,7 @@ struct SM90_64x192x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38880,7 +39416,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E5M2_RS_TN +struct MMA_64x192x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38917,6 +39453,7 @@ struct SM90_64x192x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38966,7 +39503,7 @@ struct SM90_64x192x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38979,7 +39516,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E5M2_SS_TN +struct MMA_64x208x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39005,6 +39542,7 @@ struct SM90_64x208x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39038,7 +39576,7 @@ struct SM90_64x208x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39052,7 +39590,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E5M2_RS_TN +struct MMA_64x208x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39078,6 +39616,7 @@ struct SM90_64x208x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39111,7 +39650,7 @@ struct SM90_64x208x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39125,7 +39664,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E5M2_SS_TN +struct MMA_64x208x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39164,6 +39703,7 @@ struct SM90_64x208x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39216,7 +39756,7 @@ struct SM90_64x208x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39230,7 +39770,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E5M2_RS_TN +struct MMA_64x208x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39269,6 +39809,7 @@ struct SM90_64x208x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39321,7 +39862,7 @@ struct SM90_64x208x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39335,7 +39876,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E5M2_SS_TN +struct MMA_64x224x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39362,6 +39903,7 @@ struct SM90_64x224x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39396,7 +39938,7 @@ struct SM90_64x224x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39410,7 +39952,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E5M2_RS_TN +struct MMA_64x224x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39437,6 +39979,7 @@ struct SM90_64x224x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39471,7 +40014,7 @@ struct SM90_64x224x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39485,7 +40028,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E5M2_SS_TN +struct MMA_64x224x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39526,6 +40069,7 @@ struct SM90_64x224x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39581,7 +40125,7 @@ struct SM90_64x224x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39595,7 +40139,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E5M2_RS_TN +struct MMA_64x224x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39636,6 +40180,7 @@ struct SM90_64x224x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39691,7 +40236,7 @@ struct SM90_64x224x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39705,7 +40250,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E5M2_SS_TN +struct MMA_64x240x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39733,6 +40278,7 @@ struct SM90_64x240x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39769,7 +40315,7 @@ struct SM90_64x240x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39783,7 +40329,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E5M2_RS_TN +struct MMA_64x240x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39811,6 +40357,7 @@ struct SM90_64x240x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39847,7 +40394,7 @@ struct SM90_64x240x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39861,7 +40408,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E5M2_SS_TN +struct MMA_64x240x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39904,6 +40451,7 @@ struct SM90_64x240x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39962,7 +40510,7 @@ struct SM90_64x240x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39976,7 +40524,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E5M2_RS_TN +struct MMA_64x240x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40019,6 +40567,7 @@ struct SM90_64x240x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40077,7 +40626,7 @@ struct SM90_64x240x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40090,7 +40639,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E5M2_SS_TN +struct MMA_64x256x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40119,6 +40668,7 @@ struct SM90_64x256x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40156,7 +40706,7 @@ struct SM90_64x256x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40168,7 +40718,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E5M2_RS_TN +struct MMA_64x256x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40197,6 +40747,7 @@ struct SM90_64x256x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40234,7 +40785,7 @@ struct SM90_64x256x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40246,7 +40797,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E5M2_SS_TN +struct MMA_64x256x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40291,6 +40842,7 @@ struct SM90_64x256x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40352,7 +40904,7 @@ struct SM90_64x256x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40364,7 +40916,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E5M2_RS_TN +struct MMA_64x256x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40409,6 +40961,7 @@ struct SM90_64x256x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40470,7 +41023,7 @@ struct SM90_64x256x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40482,7 +41035,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E4M3_SS_TN +struct MMA_64x8x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40496,6 +41049,7 @@ struct SM90_64x8x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40511,7 +41065,7 @@ struct SM90_64x8x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40523,7 +41077,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E4M3_RS_TN +struct MMA_64x8x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40537,6 +41091,7 @@ struct SM90_64x8x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40552,7 +41107,7 @@ struct SM90_64x8x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40564,7 +41119,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E4M3_SS_TN +struct MMA_64x8x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40578,6 +41133,7 @@ struct SM90_64x8x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40593,7 +41149,7 @@ struct SM90_64x8x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40605,7 +41161,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E4M3_RS_TN +struct MMA_64x8x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40619,6 +41175,7 @@ struct SM90_64x8x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40634,7 +41191,7 @@ struct SM90_64x8x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40646,7 +41203,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E4M3_SS_TN +struct MMA_64x16x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40660,6 +41217,7 @@ struct SM90_64x16x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40675,7 +41233,7 @@ struct SM90_64x16x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40687,7 +41245,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E4M3_RS_TN +struct MMA_64x16x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40701,6 +41259,7 @@ struct SM90_64x16x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40716,7 +41275,7 @@ struct SM90_64x16x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40728,7 +41287,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E4M3_SS_TN +struct MMA_64x16x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40743,6 +41302,7 @@ struct SM90_64x16x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40759,7 +41319,7 @@ struct SM90_64x16x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40771,7 +41331,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E4M3_RS_TN +struct MMA_64x16x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40786,6 +41346,7 @@ struct SM90_64x16x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40802,7 +41363,7 @@ struct SM90_64x16x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40814,7 +41375,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E4M3_SS_TN +struct MMA_64x32x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40829,6 +41390,7 @@ struct SM90_64x32x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40845,7 +41407,7 @@ struct SM90_64x32x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40857,7 +41419,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E4M3_RS_TN +struct MMA_64x32x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40872,6 +41434,7 @@ struct SM90_64x32x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40888,7 +41451,7 @@ struct SM90_64x32x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40900,7 +41463,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E4M3_SS_TN +struct MMA_64x32x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40917,6 +41480,7 @@ struct SM90_64x32x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40936,7 +41500,7 @@ struct SM90_64x32x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40948,7 +41512,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E4M3_RS_TN +struct MMA_64x32x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40965,6 +41529,7 @@ struct SM90_64x32x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40984,7 +41549,7 @@ struct SM90_64x32x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40997,7 +41562,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E5M2E4M3_SS_TN +struct MMA_64x48x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41013,6 +41578,7 @@ struct SM90_64x48x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41031,7 +41597,7 @@ struct SM90_64x48x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41045,7 +41611,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E5M2E4M3_RS_TN +struct MMA_64x48x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41061,6 +41627,7 @@ struct SM90_64x48x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41079,7 +41646,7 @@ struct SM90_64x48x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41093,7 +41660,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E4M3_SS_TN +struct MMA_64x48x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41112,6 +41679,7 @@ struct SM90_64x48x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41134,7 +41702,7 @@ struct SM90_64x48x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41148,7 +41716,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E4M3_RS_TN +struct MMA_64x48x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41167,6 +41735,7 @@ struct SM90_64x48x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41189,7 +41758,7 @@ struct SM90_64x48x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41202,7 +41771,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E4M3_SS_TN +struct MMA_64x64x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41219,6 +41788,7 @@ struct SM90_64x64x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41238,7 +41808,7 @@ struct SM90_64x64x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41250,7 +41820,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E4M3_RS_TN +struct MMA_64x64x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41267,6 +41837,7 @@ struct SM90_64x64x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41286,7 +41857,7 @@ struct SM90_64x64x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41298,7 +41869,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E4M3_SS_TN +struct MMA_64x64x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41319,6 +41890,7 @@ struct SM90_64x64x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41344,7 +41916,7 @@ struct SM90_64x64x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41356,7 +41928,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E4M3_RS_TN +struct MMA_64x64x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41377,6 +41949,7 @@ struct SM90_64x64x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41402,7 +41975,7 @@ struct SM90_64x64x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41415,7 +41988,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E4M3_SS_TN +struct MMA_64x80x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41433,6 +42006,7 @@ struct SM90_64x80x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41454,7 +42028,7 @@ struct SM90_64x80x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41468,7 +42042,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E4M3_RS_TN +struct MMA_64x80x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41486,6 +42060,7 @@ struct SM90_64x80x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41507,7 +42082,7 @@ struct SM90_64x80x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41521,7 +42096,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E4M3_SS_TN +struct MMA_64x80x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41544,6 +42119,7 @@ struct SM90_64x80x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41572,7 +42148,7 @@ struct SM90_64x80x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41586,7 +42162,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E4M3_RS_TN +struct MMA_64x80x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41609,6 +42185,7 @@ struct SM90_64x80x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41637,7 +42214,7 @@ struct SM90_64x80x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41650,7 +42227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E4M3_SS_TN +struct MMA_64x96x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41669,6 +42246,7 @@ struct SM90_64x96x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41691,7 +42269,7 @@ struct SM90_64x96x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41703,7 +42281,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E4M3_RS_TN +struct MMA_64x96x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41722,6 +42300,7 @@ struct SM90_64x96x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41744,7 +42323,7 @@ struct SM90_64x96x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41756,7 +42335,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E4M3_SS_TN +struct MMA_64x96x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41781,6 +42360,7 @@ struct SM90_64x96x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41812,7 +42392,7 @@ struct SM90_64x96x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41824,7 +42404,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E4M3_RS_TN +struct MMA_64x96x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41849,6 +42429,7 @@ struct SM90_64x96x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41880,7 +42461,7 @@ struct SM90_64x96x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41893,7 +42474,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E4M3_SS_TN +struct MMA_64x112x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41913,6 +42494,7 @@ struct SM90_64x112x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41937,7 +42519,7 @@ struct SM90_64x112x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41951,7 +42533,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E4M3_RS_TN +struct MMA_64x112x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41971,6 +42553,7 @@ struct SM90_64x112x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41995,7 +42578,7 @@ struct SM90_64x112x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42009,7 +42592,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E4M3_SS_TN +struct MMA_64x112x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42036,6 +42619,7 @@ struct SM90_64x112x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42070,7 +42654,7 @@ struct SM90_64x112x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42084,7 +42668,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E4M3_RS_TN +struct MMA_64x112x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42111,6 +42695,7 @@ struct SM90_64x112x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42145,7 +42730,7 @@ struct SM90_64x112x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42158,7 +42743,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E4M3_SS_TN +struct MMA_64x128x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42179,6 +42764,7 @@ struct SM90_64x128x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42204,7 +42790,7 @@ struct SM90_64x128x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42216,7 +42802,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E4M3_RS_TN +struct MMA_64x128x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42237,6 +42823,7 @@ struct SM90_64x128x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42262,7 +42849,7 @@ struct SM90_64x128x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42274,7 +42861,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E4M3_SS_TN +struct MMA_64x128x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42303,6 +42890,7 @@ struct SM90_64x128x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42340,7 +42928,7 @@ struct SM90_64x128x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42352,7 +42940,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E4M3_RS_TN +struct MMA_64x128x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42381,6 +42969,7 @@ struct SM90_64x128x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42418,7 +43007,7 @@ struct SM90_64x128x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42431,7 +43020,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E4M3_SS_TN +struct MMA_64x144x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42453,6 +43042,7 @@ struct SM90_64x144x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42480,7 +43070,7 @@ struct SM90_64x144x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42494,7 +43084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E4M3_RS_TN +struct MMA_64x144x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42516,6 +43106,7 @@ struct SM90_64x144x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42543,7 +43134,7 @@ struct SM90_64x144x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42557,7 +43148,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E4M3_SS_TN +struct MMA_64x144x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42588,6 +43179,7 @@ struct SM90_64x144x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42628,7 +43220,7 @@ struct SM90_64x144x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42642,7 +43234,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E4M3_RS_TN +struct MMA_64x144x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42673,6 +43265,7 @@ struct SM90_64x144x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42713,7 +43306,7 @@ struct SM90_64x144x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42727,7 +43320,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E4M3_SS_TN +struct MMA_64x160x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42750,6 +43343,7 @@ struct SM90_64x160x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42778,7 +43372,7 @@ struct SM90_64x160x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42792,7 +43386,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E4M3_RS_TN +struct MMA_64x160x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42815,6 +43409,7 @@ struct SM90_64x160x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42843,7 +43438,7 @@ struct SM90_64x160x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42857,7 +43452,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E4M3_SS_TN +struct MMA_64x160x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42890,6 +43485,7 @@ struct SM90_64x160x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42933,7 +43529,7 @@ struct SM90_64x160x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42947,7 +43543,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E4M3_RS_TN +struct MMA_64x160x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42980,6 +43576,7 @@ struct SM90_64x160x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43023,7 +43620,7 @@ struct SM90_64x160x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43037,7 +43634,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E4M3_SS_TN +struct MMA_64x176x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43061,6 +43658,7 @@ struct SM90_64x176x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43091,7 +43689,7 @@ struct SM90_64x176x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43105,7 +43703,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E4M3_RS_TN +struct MMA_64x176x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43129,6 +43727,7 @@ struct SM90_64x176x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43159,7 +43758,7 @@ struct SM90_64x176x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43173,7 +43772,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E4M3_SS_TN +struct MMA_64x176x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43208,6 +43807,7 @@ struct SM90_64x176x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43254,7 +43854,7 @@ struct SM90_64x176x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43268,7 +43868,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E4M3_RS_TN +struct MMA_64x176x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43303,6 +43903,7 @@ struct SM90_64x176x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43349,7 +43950,7 @@ struct SM90_64x176x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43362,7 +43963,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E4M3_SS_TN +struct MMA_64x192x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43387,6 +43988,7 @@ struct SM90_64x192x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43418,7 +44020,7 @@ struct SM90_64x192x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43430,7 +44032,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E4M3_RS_TN +struct MMA_64x192x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43455,6 +44057,7 @@ struct SM90_64x192x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43486,7 +44089,7 @@ struct SM90_64x192x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43498,7 +44101,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E4M3_SS_TN +struct MMA_64x192x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43535,6 +44138,7 @@ struct SM90_64x192x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43584,7 +44188,7 @@ struct SM90_64x192x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43596,7 +44200,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E4M3_RS_TN +struct MMA_64x192x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43633,6 +44237,7 @@ struct SM90_64x192x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43682,7 +44287,7 @@ struct SM90_64x192x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43695,7 +44300,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E4M3_SS_TN +struct MMA_64x208x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43721,6 +44326,7 @@ struct SM90_64x208x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43754,7 +44360,7 @@ struct SM90_64x208x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43768,7 +44374,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E4M3_RS_TN +struct MMA_64x208x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43794,6 +44400,7 @@ struct SM90_64x208x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43827,7 +44434,7 @@ struct SM90_64x208x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43841,7 +44448,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E4M3_SS_TN +struct MMA_64x208x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43880,6 +44487,7 @@ struct SM90_64x208x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43932,7 +44540,7 @@ struct SM90_64x208x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43946,7 +44554,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E4M3_RS_TN +struct MMA_64x208x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43985,6 +44593,7 @@ struct SM90_64x208x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44037,7 +44646,7 @@ struct SM90_64x208x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44051,7 +44660,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E4M3_SS_TN +struct MMA_64x224x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44078,6 +44687,7 @@ struct SM90_64x224x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44112,7 +44722,7 @@ struct SM90_64x224x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44126,7 +44736,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E4M3_RS_TN +struct MMA_64x224x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44153,6 +44763,7 @@ struct SM90_64x224x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44187,7 +44798,7 @@ struct SM90_64x224x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44201,7 +44812,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E4M3_SS_TN +struct MMA_64x224x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44242,6 +44853,7 @@ struct SM90_64x224x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44297,7 +44909,7 @@ struct SM90_64x224x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44311,7 +44923,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E4M3_RS_TN +struct MMA_64x224x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44352,6 +44964,7 @@ struct SM90_64x224x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44407,7 +45020,7 @@ struct SM90_64x224x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44421,7 +45034,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E4M3_SS_TN +struct MMA_64x240x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44449,6 +45062,7 @@ struct SM90_64x240x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44485,7 +45099,7 @@ struct SM90_64x240x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44499,7 +45113,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E4M3_RS_TN +struct MMA_64x240x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44527,6 +45141,7 @@ struct SM90_64x240x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44563,7 +45178,7 @@ struct SM90_64x240x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44577,7 +45192,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E4M3_SS_TN +struct MMA_64x240x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44620,6 +45235,7 @@ struct SM90_64x240x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44678,7 +45294,7 @@ struct SM90_64x240x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44692,7 +45308,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E4M3_RS_TN +struct MMA_64x240x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44735,6 +45351,7 @@ struct SM90_64x240x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44793,7 +45410,7 @@ struct SM90_64x240x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44806,7 +45423,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E4M3_SS_TN +struct MMA_64x256x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44835,6 +45452,7 @@ struct SM90_64x256x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44872,7 +45490,7 @@ struct SM90_64x256x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44884,7 +45502,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E4M3_RS_TN +struct MMA_64x256x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44913,6 +45531,7 @@ struct SM90_64x256x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44950,7 +45569,7 @@ struct SM90_64x256x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44962,7 +45581,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E4M3_SS_TN +struct MMA_64x256x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45007,6 +45626,7 @@ struct SM90_64x256x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45068,7 +45688,7 @@ struct SM90_64x256x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45080,7 +45700,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E4M3_RS_TN +struct MMA_64x256x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45125,6 +45745,7 @@ struct SM90_64x256x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45186,7 +45807,7 @@ struct SM90_64x256x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45198,7 +45819,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E5M2_SS_TN +struct MMA_64x8x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45212,6 +45833,7 @@ struct SM90_64x8x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45227,7 +45849,7 @@ struct SM90_64x8x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45239,7 +45861,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E5M2_RS_TN +struct MMA_64x8x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45253,6 +45875,7 @@ struct SM90_64x8x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45268,7 +45891,7 @@ struct SM90_64x8x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45280,7 +45903,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E5M2_SS_TN +struct MMA_64x8x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45294,6 +45917,7 @@ struct SM90_64x8x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45309,7 +45933,7 @@ struct SM90_64x8x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45321,7 +45945,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E5M2_RS_TN +struct MMA_64x8x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45335,6 +45959,7 @@ struct SM90_64x8x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45350,7 +45975,7 @@ struct SM90_64x8x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45362,7 +45987,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E5M2_SS_TN +struct MMA_64x16x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45376,6 +46001,7 @@ struct SM90_64x16x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45391,7 +46017,7 @@ struct SM90_64x16x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45403,7 +46029,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E5M2_RS_TN +struct MMA_64x16x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45417,6 +46043,7 @@ struct SM90_64x16x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45432,7 +46059,7 @@ struct SM90_64x16x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45444,7 +46071,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E5M2_SS_TN +struct MMA_64x16x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45459,6 +46086,7 @@ struct SM90_64x16x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45475,7 +46103,7 @@ struct SM90_64x16x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45487,7 +46115,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E5M2_RS_TN +struct MMA_64x16x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45502,6 +46130,7 @@ struct SM90_64x16x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45518,7 +46147,7 @@ struct SM90_64x16x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45530,7 +46159,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E5M2_SS_TN +struct MMA_64x32x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45545,6 +46174,7 @@ struct SM90_64x32x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45561,7 +46191,7 @@ struct SM90_64x32x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45573,7 +46203,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E5M2_RS_TN +struct MMA_64x32x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45588,6 +46218,7 @@ struct SM90_64x32x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45604,7 +46235,7 @@ struct SM90_64x32x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45616,7 +46247,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E5M2_SS_TN +struct MMA_64x32x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45633,6 +46264,7 @@ struct SM90_64x32x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45652,7 +46284,7 @@ struct SM90_64x32x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45664,7 +46296,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E5M2_RS_TN +struct MMA_64x32x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45681,6 +46313,7 @@ struct SM90_64x32x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45700,7 +46333,7 @@ struct SM90_64x32x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45713,7 +46346,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E5M2E5M2_SS_TN +struct MMA_64x48x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45729,6 +46362,7 @@ struct SM90_64x48x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45747,7 +46381,7 @@ struct SM90_64x48x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45761,7 +46395,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E5M2E5M2_RS_TN +struct MMA_64x48x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45777,6 +46411,7 @@ struct SM90_64x48x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45795,7 +46430,7 @@ struct SM90_64x48x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45809,7 +46444,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E5M2_SS_TN +struct MMA_64x48x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45828,6 +46463,7 @@ struct SM90_64x48x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45850,7 +46486,7 @@ struct SM90_64x48x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45864,7 +46500,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E5M2_RS_TN +struct MMA_64x48x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45883,6 +46519,7 @@ struct SM90_64x48x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45905,7 +46542,7 @@ struct SM90_64x48x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45918,7 +46555,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E5M2_SS_TN +struct MMA_64x64x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45935,6 +46572,7 @@ struct SM90_64x64x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45954,7 +46592,7 @@ struct SM90_64x64x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45966,7 +46604,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E5M2_RS_TN +struct MMA_64x64x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45983,6 +46621,7 @@ struct SM90_64x64x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46002,7 +46641,7 @@ struct SM90_64x64x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46014,7 +46653,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E5M2_SS_TN +struct MMA_64x64x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46035,6 +46674,7 @@ struct SM90_64x64x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46060,7 +46700,7 @@ struct SM90_64x64x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46072,7 +46712,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E5M2_RS_TN +struct MMA_64x64x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46093,6 +46733,7 @@ struct SM90_64x64x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46118,7 +46759,7 @@ struct SM90_64x64x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46131,7 +46772,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E5M2_SS_TN +struct MMA_64x80x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46149,6 +46790,7 @@ struct SM90_64x80x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46170,7 +46812,7 @@ struct SM90_64x80x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46184,7 +46826,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E5M2_RS_TN +struct MMA_64x80x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46202,6 +46844,7 @@ struct SM90_64x80x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46223,7 +46866,7 @@ struct SM90_64x80x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46237,7 +46880,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E5M2_SS_TN +struct MMA_64x80x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46260,6 +46903,7 @@ struct SM90_64x80x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46288,7 +46932,7 @@ struct SM90_64x80x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46302,7 +46946,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E5M2_RS_TN +struct MMA_64x80x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46325,6 +46969,7 @@ struct SM90_64x80x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46353,7 +46998,7 @@ struct SM90_64x80x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46366,7 +47011,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E5M2_SS_TN +struct MMA_64x96x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46385,6 +47030,7 @@ struct SM90_64x96x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46407,7 +47053,7 @@ struct SM90_64x96x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46419,7 +47065,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E5M2_RS_TN +struct MMA_64x96x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46438,6 +47084,7 @@ struct SM90_64x96x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46460,7 +47107,7 @@ struct SM90_64x96x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46472,7 +47119,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E5M2_SS_TN +struct MMA_64x96x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46497,6 +47144,7 @@ struct SM90_64x96x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46528,7 +47176,7 @@ struct SM90_64x96x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46540,7 +47188,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E5M2_RS_TN +struct MMA_64x96x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46565,6 +47213,7 @@ struct SM90_64x96x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46596,7 +47245,7 @@ struct SM90_64x96x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46609,7 +47258,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E5M2_SS_TN +struct MMA_64x112x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46629,6 +47278,7 @@ struct SM90_64x112x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46653,7 +47303,7 @@ struct SM90_64x112x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46667,7 +47317,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E5M2_RS_TN +struct MMA_64x112x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46687,6 +47337,7 @@ struct SM90_64x112x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46711,7 +47362,7 @@ struct SM90_64x112x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46725,7 +47376,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E5M2_SS_TN +struct MMA_64x112x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46752,6 +47403,7 @@ struct SM90_64x112x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46786,7 +47438,7 @@ struct SM90_64x112x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46800,7 +47452,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E5M2_RS_TN +struct MMA_64x112x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46827,6 +47479,7 @@ struct SM90_64x112x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46861,7 +47514,7 @@ struct SM90_64x112x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46874,7 +47527,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E5M2_SS_TN +struct MMA_64x128x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46895,6 +47548,7 @@ struct SM90_64x128x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46920,7 +47574,7 @@ struct SM90_64x128x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46932,7 +47586,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E5M2_RS_TN +struct MMA_64x128x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46953,6 +47607,7 @@ struct SM90_64x128x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46978,7 +47633,7 @@ struct SM90_64x128x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46990,7 +47645,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E5M2_SS_TN +struct MMA_64x128x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47019,6 +47674,7 @@ struct SM90_64x128x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47056,7 +47712,7 @@ struct SM90_64x128x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47068,7 +47724,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E5M2_RS_TN +struct MMA_64x128x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47097,6 +47753,7 @@ struct SM90_64x128x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47134,7 +47791,7 @@ struct SM90_64x128x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47147,7 +47804,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E5M2_SS_TN +struct MMA_64x144x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47169,6 +47826,7 @@ struct SM90_64x144x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47196,7 +47854,7 @@ struct SM90_64x144x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47210,7 +47868,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E5M2_RS_TN +struct MMA_64x144x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47232,6 +47890,7 @@ struct SM90_64x144x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47259,7 +47918,7 @@ struct SM90_64x144x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47273,7 +47932,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E5M2_SS_TN +struct MMA_64x144x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47304,6 +47963,7 @@ struct SM90_64x144x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47344,7 +48004,7 @@ struct SM90_64x144x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47358,7 +48018,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E5M2_RS_TN +struct MMA_64x144x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47389,6 +48049,7 @@ struct SM90_64x144x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47429,7 +48090,7 @@ struct SM90_64x144x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47443,7 +48104,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E5M2_SS_TN +struct MMA_64x160x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47466,6 +48127,7 @@ struct SM90_64x160x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47494,7 +48156,7 @@ struct SM90_64x160x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47508,7 +48170,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E5M2_RS_TN +struct MMA_64x160x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47531,6 +48193,7 @@ struct SM90_64x160x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47559,7 +48222,7 @@ struct SM90_64x160x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47573,7 +48236,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E5M2_SS_TN +struct MMA_64x160x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47606,6 +48269,7 @@ struct SM90_64x160x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47649,7 +48313,7 @@ struct SM90_64x160x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47663,7 +48327,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E5M2_RS_TN +struct MMA_64x160x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47696,6 +48360,7 @@ struct SM90_64x160x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47739,7 +48404,7 @@ struct SM90_64x160x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47753,7 +48418,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E5M2_SS_TN +struct MMA_64x176x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47777,6 +48442,7 @@ struct SM90_64x176x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47807,7 +48473,7 @@ struct SM90_64x176x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47821,7 +48487,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E5M2_RS_TN +struct MMA_64x176x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47845,6 +48511,7 @@ struct SM90_64x176x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47875,7 +48542,7 @@ struct SM90_64x176x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47889,7 +48556,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E5M2_SS_TN +struct MMA_64x176x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47924,6 +48591,7 @@ struct SM90_64x176x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47970,7 +48638,7 @@ struct SM90_64x176x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47984,7 +48652,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E5M2_RS_TN +struct MMA_64x176x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48019,6 +48687,7 @@ struct SM90_64x176x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48065,7 +48734,7 @@ struct SM90_64x176x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48078,7 +48747,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E5M2_SS_TN +struct MMA_64x192x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48103,6 +48772,7 @@ struct SM90_64x192x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48134,7 +48804,7 @@ struct SM90_64x192x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48146,7 +48816,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E5M2_RS_TN +struct MMA_64x192x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48171,6 +48841,7 @@ struct SM90_64x192x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48202,7 +48873,7 @@ struct SM90_64x192x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48214,7 +48885,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E5M2_SS_TN +struct MMA_64x192x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48251,6 +48922,7 @@ struct SM90_64x192x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48300,7 +48972,7 @@ struct SM90_64x192x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48312,7 +48984,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E5M2_RS_TN +struct MMA_64x192x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48349,6 +49021,7 @@ struct SM90_64x192x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48398,7 +49071,7 @@ struct SM90_64x192x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48411,7 +49084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E5M2_SS_TN +struct MMA_64x208x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48437,6 +49110,7 @@ struct SM90_64x208x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48470,7 +49144,7 @@ struct SM90_64x208x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48484,7 +49158,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E5M2_RS_TN +struct MMA_64x208x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48510,6 +49184,7 @@ struct SM90_64x208x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48543,7 +49218,7 @@ struct SM90_64x208x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48557,7 +49232,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E5M2_SS_TN +struct MMA_64x208x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48596,6 +49271,7 @@ struct SM90_64x208x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48648,7 +49324,7 @@ struct SM90_64x208x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48662,7 +49338,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E5M2_RS_TN +struct MMA_64x208x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48701,6 +49377,7 @@ struct SM90_64x208x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48753,7 +49430,7 @@ struct SM90_64x208x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48767,7 +49444,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E5M2_SS_TN +struct MMA_64x224x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48794,6 +49471,7 @@ struct SM90_64x224x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48828,7 +49506,7 @@ struct SM90_64x224x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48842,7 +49520,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E5M2_RS_TN +struct MMA_64x224x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48869,6 +49547,7 @@ struct SM90_64x224x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48903,7 +49582,7 @@ struct SM90_64x224x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48917,7 +49596,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E5M2_SS_TN +struct MMA_64x224x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48958,6 +49637,7 @@ struct SM90_64x224x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49013,7 +49693,7 @@ struct SM90_64x224x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49027,7 +49707,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E5M2_RS_TN +struct MMA_64x224x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49068,6 +49748,7 @@ struct SM90_64x224x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49123,7 +49804,7 @@ struct SM90_64x224x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49137,7 +49818,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E5M2_SS_TN +struct MMA_64x240x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49165,6 +49846,7 @@ struct SM90_64x240x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49201,7 +49883,7 @@ struct SM90_64x240x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49215,7 +49897,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E5M2_RS_TN +struct MMA_64x240x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49243,6 +49925,7 @@ struct SM90_64x240x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49279,7 +49962,7 @@ struct SM90_64x240x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49293,7 +49976,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E5M2_SS_TN +struct MMA_64x240x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49336,6 +50019,7 @@ struct SM90_64x240x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49394,7 +50078,7 @@ struct SM90_64x240x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49408,7 +50092,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E5M2_RS_TN +struct MMA_64x240x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49451,6 +50135,7 @@ struct SM90_64x240x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49509,7 +50194,7 @@ struct SM90_64x240x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49522,7 +50207,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E5M2_SS_TN +struct MMA_64x256x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49551,6 +50236,7 @@ struct SM90_64x256x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49588,7 +50274,7 @@ struct SM90_64x256x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49600,7 +50286,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E5M2_RS_TN +struct MMA_64x256x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49629,6 +50315,7 @@ struct SM90_64x256x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49666,7 +50353,7 @@ struct SM90_64x256x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49678,7 +50365,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E5M2_SS_TN +struct MMA_64x256x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49723,6 +50410,7 @@ struct SM90_64x256x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49784,7 +50472,7 @@ struct SM90_64x256x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49796,7 +50484,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E5M2_RS_TN +struct MMA_64x256x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49841,6 +50529,7 @@ struct SM90_64x256x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49902,11 +50591,14 @@ struct SM90_64x256x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA + } // namespace cute diff --git a/include/cute/arch/mma_sm90_gmma_sparse.hpp b/include/cute/arch/mma_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000..d05e762e1f --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma_sparse.hpp @@ -0,0 +1,53789 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // GMMA::Major, etc. + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index 61417d8360..3749a9c255 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -31,7 +31,6 @@ #pragma once #include - #include #if defined(__clang__) && defined(__CUDA__) @@ -254,6 +253,28 @@ explode(Fn fn, return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...); } +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence, + PtrG&& g, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); +} + // // Utility for exploding tuples into functions // diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 20a0627627..dd6b4e52a0 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -30,16 +30,13 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include -#include - -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_constant, cute::is_integral +#include // cute::Copy_Traits +#include // cute::TiledMMA namespace cute { @@ -651,10 +648,12 @@ print(ThrCopy const& thr_copy) print(TiledCopy{}); } -template +// TiledCopy to LaTeX TikZ +template CUTE_HOST_DEVICE auto -print_latex(TiledCopy const& copy) +print_latex(TiledCopy const& copy, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); @@ -663,13 +662,15 @@ print_latex(TiledCopy const& copy) layoutD_MN, thrID_D); } -// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread +// MNK Copy Layout to LaTeX TikZ template + class LayoutD, class ThrIDD, + class TikzColorFn = TikzColor_TV> CUTE_HOST_DEVICE void print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutD const& D, ThrIDD const& TD, // (m,n) -> (tid,vid) and tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); @@ -677,33 +678,17 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and assert(size<0>(S) == size<0>(D)); assert(size<1>(S) == size<1>(D)); - char const* latex_header = - "\\documentclass{standalone}\n" - "\\usepackage{tikz}\n" - "\\usetikzlibrary{external}\n" - "\\tikzexternalize\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}",}; - - // Header + // Commented prints printf("%% LayoutS: "); print(S); printf("\n"); printf("%% ThrIDS : "); print(TS); printf("\n"); printf("%% LayoutD: "); print(D); printf("\n"); printf("%% ThrIDD : "); print(TD); printf("\n\n"); - printf(latex_header); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // S starting at 0,0 for (int i = 0; i < size<0>(S); ++i) { @@ -712,12 +697,22 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and int val_idx = S(i,j) / size(TS); int thr_idx = TS(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), i, j, thr_idx, val_idx); } } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(size<0>(S)), int(size<1>(S))); + // S Labels + for (int i = 0, j = -1; i < size<0>(S); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int i = -1, j = 0; j < size<1>(S); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } // D starting at 0,size<1>(S)+3 for (int i = 0; i < size<0>(D); ++i) { @@ -726,30 +721,26 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and int val_idx = D(i,j) / size(TD); int thr_idx = TD(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), i, j + size<1>(S) + 3, thr_idx, val_idx); } } - - // S Labels - for (int i = 0, j = -1; i < size<0>(S); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int j = 0, i = -1; j < size<1>(S); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, int(size<1>(S)+3), int(size<0>(D)), int(size<1>(D)+size<1>(S)+3)); // D Labels - for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { + for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); } - for (int j = 0, i = -1; j < size<1>(D); ++j) { + for (int i = -1, j = 0; j < size<1>(D); ++j) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); } // Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); } } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm50.hpp b/include/cute/atom/copy_traits_sm50.hpp index 8be0ef7bba..7a693805e6 100644 --- a/include/cute/atom/copy_traits_sm50.hpp +++ b/include/cute/atom/copy_traits_sm50.hpp @@ -39,7 +39,7 @@ namespace cute { template <> -struct Copy_Traits +struct Copy_Traits { // Logical thread id to thread idx (one-thread) using ThrID = Layout<_32>; @@ -55,4 +55,21 @@ struct Copy_Traits using RefLayout = SrcLayout; }; +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_32, _2>>, + Stride,Stride< _1, _256>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index ad4f8675b5..54f76073b1 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -450,7 +450,9 @@ make_im2col_tma_copy_desc( CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; CUtensorMapL2promotion tma_l2Promotion = to_CUtensorMapL2promotion(aux_params.l2promo_); CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_); - CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle)); + TMA::SmemSwizzleBits swizzle_bits = detail::get_tma_swizzle_bits(smem_swizzle); + TMA::SmemSwizzleBase swizzle_base = detail::get_tma_swizzle_base(smem_swizzle); + CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)( &tma_desc, @@ -636,11 +638,11 @@ make_tma_atom_im2col(CopyOp, auto range_c = size<0,0>(tma_layout_vt); auto range_whdn = size<0,1>(tma_layout_vt); - Tensor gtensor_cwhdn = make_tensor(gtensor.data(), - flatten(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.layout()), - basis_get(stride<0,1>(tma_layout_vt), gtensor.layout())))); - + flatten(make_layout(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,0>(tma_layout_vt), gtensor.stride())), + make_layout(basis_get(stride<0,1>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,1>(tma_layout_vt), gtensor.stride()))))); auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( gtensor_cwhdn, range_c, diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 2238c41897..3738cc3962 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -41,6 +41,7 @@ #include #include + #include namespace cute @@ -241,15 +242,22 @@ struct Copy_Traits // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar CUTE_HOST_DEVICE constexpr Copy_Traits - with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { - return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask}}; + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; } // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) CUTE_HOST_DEVICE constexpr Copy_Traits - with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { - return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask}}; + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; } // Generate the TMA coord tensor @@ -287,7 +295,8 @@ struct Copy_Traits tuple< TmaDescriptor const*, uint64_t*, // smem mbarrier - uint16_t // multicast mask + uint16_t, // multicast mask + uint64_t // cache hint > const opargs_; }; @@ -684,8 +693,10 @@ construct_tma_gbasis(Tensor const& gtensor, // The origin // TMA parameter checking // - CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), - "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + // CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + // "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + CUTE_STATIC_ASSERT_V(size(slayout) == size(cta_v_map), + "TMA requires CTA_Tile and SLayout top-level size equivalence."); #if 0 print("gtensor : "); print(gtensor); print("\n"); @@ -983,7 +994,9 @@ make_tma_copy_desc(Tensor const& gtensor, // The origin CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; // TMA smem swizzle type - CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); + TMA::SmemSwizzleBits swizzle_bits = get_tma_swizzle_bits(swizzle); + TMA::SmemSwizzleBase swizzle_base = get_tma_swizzle_base(swizzle); + CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( &tma_desc, tma_format, diff --git a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp index bb44a8353d..3286e72b36 100644 --- a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp +++ b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -68,4 +68,26 @@ get_tma_swizzle_bits(Layout const& layout) return get_tma_swizzle_bits(get_swizzle_portion(layout)); } +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBase +get_tma_swizzle_base(Swizzle) +{ + if constexpr (M == 4) { + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_16B; + } + else { + static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); + } +} + +template +TMA::SmemSwizzleBase +get_tma_swizzle_base(Layout const& layout) +{ + return get_tma_swizzle_base(get_swizzle_portion(layout)); +} + } // namespace cute::detail diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 2358dd568f..bf40827436 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -45,11 +45,12 @@ template struct MMA_Atom : MMA_Atom> {}; -template -struct MMA_Atom> - : MMA_Traits +template +struct MMA_Atom> + : MMA_Traits { - using Traits = MMA_Traits; + using MMA_Op = MMAOperation; + using Traits = MMA_Traits; // Element value types from the MMA_Traits using ValTypeD = typename Traits::ValTypeD; @@ -331,7 +332,7 @@ struct TiledMMA : MMA_Atom make_layout(size<2>(AtomShape_MNK{}))); auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) - // Transform the Atom mode from (N,K) to (Thr,Val) + // Transform the Atom mode from (M,K) to (Thr,Val) auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) // Tile the tensor for the Thread @@ -733,18 +734,22 @@ print(ThrMMA const& thr_mma) print(static_cast(thr_mma)); } -template +// MMA Atom to LaTeX TikZ +template CUTE_HOST_DEVICE void -print_latex(MMA_Atom const& mma_atom) +print_latex(MMA_Atom const& mma_atom, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { print_latex(make_tiled_mma(mma_atom)); } -template +// TiledMMA to LaTeX TikZ +template CUTE_HOST_DEVICE void -print_latex(TiledMMA const& mma) +print_latex(TiledMMA const& mma, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { auto layout_and_thrid_C = mma.get_layoutC_MN(); auto layoutC_MN = get<0>(layout_and_thrid_C); @@ -763,71 +768,17 @@ print_latex(TiledMMA const& mma) layoutB_NK, thrID_B); } -// MNK MMA Layout to console printer -template -CUTE_HOST_DEVICE -void -print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - assert(size<0>(A) == size<0>(C)); - assert(size<0>(B) == size<1>(C)); - assert(size<1>(A) == size<1>(B)); - - int a_width = size<1>(A) * 6 + 4; - - // Print out B (white-shifted) k-by-n - for (int k = 0; k < size<1>(B); ++k) { - // Header - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n"); - // Values - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); - printf("|\n"); - } - // Footer - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n\n"); - - // Print out A m-by-k and C m-by-n - for (int m = 0; m < size<0>(A); ++m) { - // Header - for (int k = 0; k < size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); - // Values - for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); - printf("| "); - for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); - printf("|\n"); - } - // Footer - for (int k = 0; k < size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); -} - -// MNK MMA Layout to Latex TIKZ -- 8-value color coded by thread +// MNK MMA Layout to LaTeX TikZ template + class LayoutB, class ThrIDB, + class TikzColorFn = TikzColor_TV> CUTE_HOST_DEVICE void print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB, // (n,k) -> (tid,vid) and tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); @@ -837,35 +788,18 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and assert(size<0>(B) == size<1>(C)); assert(size<1>(A) == size<1>(B)); - char const* latex_header = - "\\documentclass{standalone}\n" - "\\usepackage{tikz}\n" - "\\usetikzlibrary{external}\n" - "\\tikzexternalize\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - - // Header + // Commented prints printf("%% LayoutC: "); print(C); printf("\n"); printf("%% ThrIDC : "); print(TC); printf("\n"); printf("%% LayoutA: "); print(A); printf("\n"); printf("%% ThrIDA : "); print(TA); printf("\n"); printf("%% LayoutB: "); print(B); printf("\n"); printf("%% ThrIDB : "); print(TB); printf("\n\n"); - - printf(latex_header); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // C starting at 0,0 for (int m = 0; m < size<0>(C); ++m) { @@ -874,12 +808,15 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and int val_idx = C(m,n) / size(TC); int thr_idx = TC(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), m, n, thr_idx, val_idx); } } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(size<0>(C)), int(size<1>(C))); // A starting at 0,-size<1>(A)-1 for (int m = 0; m < size<0>(A); ++m) { @@ -888,12 +825,22 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and int val_idx = A(m,k) / size(TA); int thr_idx = TA(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), m, k-1-size<1>(A), thr_idx, val_idx); } } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, int(-size<1>(A)-1), int(size<0>(A)), -1); + // A labels + for (int m = 0, k = -1; m < size<0>(A); ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); + } + for (int m = -1, k = 0; k < size<1>(A); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); + } // B starting at -size<1>(B)-1,0 for (int n = 0; n < size<0>(B); ++n) { @@ -902,30 +849,82 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and int val_idx = B(n,k) / size(TB); int thr_idx = TB(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), k-1-size<1>(B), n, thr_idx, val_idx); } } - - // A labels - for (int m = 0, k = -1; m < size<0>(A); ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); - } - for (int k = 0, m = -1; k < size<1>(A); ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); - } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + int(-size<1>(B)-1), 0, -1, int(size<0>(B))); // B labels - for (int n = 0, k = -1; n < size<0>(B); ++n) { + for (int n = 0, k = -1; n < size<0>(B); ++n) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); } - for (int k = 0, n = -1; k < size<1>(B); ++k) { + for (int n = -1, k = 0; k < size<1>(B); ++k) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); } // Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +// MNK MMA Layout to console printer +template +CUTE_HOST_DEVICE +void +print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + int a_width = size<1>(A) * 6 + 4; + + // Print out B (white-shifted) k-by-n + for (int k = 0; k < size<1>(B); ++k) { + // Header + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n"); + // Values + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); + printf("|\n"); + } + // Footer + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n\n"); + + // Print out A m-by-k and C m-by-n + for (int m = 0; m < size<0>(A); ++m) { + // Header + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); + // Values + for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); + printf("| "); + for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); + printf("|\n"); + } + // Footer + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); } // MNK MMA Layout to SVG -- 8-value color coded by thread diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 8b9ac73642..0994698a87 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -30,23 +30,14 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // cute::Tensor +#include // cute::is_rmem +#include // cute::UniversalFMA +#include // cute::detail::explode namespace cute { -namespace detail { - -template -struct supports_output_scaling { static constexpr bool value = false; }; - -template -struct supports_output_scaling().accumulate_)>> { static constexpr bool value = true; }; - -} // end namespace detail - /** * concept MMA_Traits * { @@ -99,17 +90,27 @@ struct MMA_Traits> using CLayout = Layout>; }; +// Extract an MMA_Op from an MMA_Traits +template +struct MMA_Op {}; + +template +struct MMA_Op> { + using type = MMA_Op_Arg; +}; + // // Generic mma_unpack for any MMA_Traits // -template CUTE_HOST_DEVICE constexpr void -mma_unpack(MMA_Traits const& traits, +mma_unpack(AnyMMATraits const& traits, Tensor & D, Tensor const& A, Tensor const& B, @@ -121,87 +122,47 @@ mma_unpack(MMA_Traits const& traits, static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); // Register value types from the MMA_Operation register arrays + using MMA_Op = typename MMA_Op::type; using RegTypeD = typename remove_extent::type; using RegTypeA = typename remove_extent::type; using RegTypeB = typename remove_extent::type; using RegTypeC = typename remove_extent::type; - using MMATraits = MMA_Traits; - [[maybe_unused]] constexpr int RegNumD = extent::value; + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + + constexpr int RegNumD = extent::value; constexpr int RegNumA = extent::value; constexpr int RegNumB = extent::value; constexpr int RegNumC = extent::value; - Tensor rA = recast(A); - Tensor rB = recast(B); - CUTE_STATIC_ASSERT_V(size(rA) == Int{}); CUTE_STATIC_ASSERT_V(size(rB) == Int{}); - - if constexpr (is_same::value) - { - static_assert(is_same::value, "GMMA C and D value_type must match."); - static_assert(is_same::value, "GMMA C and D layouts must match."); - // assert((void*)&C == (void*)&D); - - Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D - - //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); - - if constexpr (detail::supports_output_scaling::value) { - detail::explode(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}, - &(traits.accumulate_), seq<0>{}); - } - else { - detail::explode(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); - } - } - else { - Tensor rD = recast(D); - Tensor rC = recast(C); - - CUTE_STATIC_ASSERT_V(size(rD) == Int{}); - CUTE_STATIC_ASSERT_V(size(rC) == Int{}); - if constexpr (detail::supports_output_scaling::value) { - detail::explode(MMA_Op::fma, - rD, make_int_sequence{}, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}, - &(traits.accumulate_), seq<0>{}); - } - else { - detail::explode(MMA_Op::fma, - rD, make_int_sequence{}, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); - } - } + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); } -// // Accept mutable temporaries -// - -template CUTE_HOST_DEVICE constexpr void -mma_unpack(MMA_Traits const& traits, - Tensor && D, - Tensor const& A, - Tensor const& B, - Tensor const& C) +mma_unpack(AnyMMATraits const& traits, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) { mma_unpack(traits, D, A, B, C); } diff --git a/include/cute/atom/mma_traits_sm90.hpp b/include/cute/atom/mma_traits_sm90.hpp index 437af27b21..b2ced3f878 100644 --- a/include/cute/atom/mma_traits_sm90.hpp +++ b/include/cute/atom/mma_traits_sm90.hpp @@ -41,6 +41,8 @@ namespace cute { //////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// /////////////////////////////////////////////////////////////////////////////// +using SM90_16x8x4_F64F64F64F64_TN = SM90::MMA_16x8x4_F64F64F64F64_TN; + template <> struct MMA_Traits { @@ -59,6 +61,8 @@ struct MMA_Traits Stride,Stride<_16,_8>>>; }; +using SM90_16x8x8_F64F64F64F64_TN = SM90::MMA_16x8x8_F64F64F64F64_TN; + template <> struct MMA_Traits { @@ -77,6 +81,8 @@ struct MMA_Traits Stride,Stride<_16,_8>>>; }; +using SM90_16x8x16_F64F64F64F64_TN = SM90::MMA_16x8x16_F64F64F64F64_TN; + template <> struct MMA_Traits { @@ -99,9 +105,11 @@ struct MMA_Traits //////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// /////////////////////////////////////////////////////////////////////////////////// +using SM90_16x8x4_C64C64C64C64_TN = SM90::MMA_16x8x4_C64C64C64C64_TN; + template <> struct MMA_Traits - : MMA_Traits + : MMA_Traits { using ValTypeD = complex; using ValTypeA = complex; @@ -109,9 +117,11 @@ struct MMA_Traits using ValTypeC = complex; }; +using SM90_16x8x8_C64C64C64C64_TN = SM90::MMA_16x8x8_C64C64C64C64_TN; + template <> struct MMA_Traits - : MMA_Traits + : MMA_Traits { using ValTypeD = complex; using ValTypeA = complex; @@ -119,9 +129,11 @@ struct MMA_Traits using ValTypeC = complex; }; +using SM90_16x8x16_C64C64C64C64_TN = SM90::MMA_16x8x16_C64C64C64C64_TN; + template <> struct MMA_Traits - : MMA_Traits + : MMA_Traits { using ValTypeD = complex; using ValTypeA = complex; diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index e59bbeefc2..74f3d64601 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -30,10 +30,15 @@ **************************************************************************************************/ #pragma once -#include -#include - -#include +#include // cute::smem_ptr_flag +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90_64x8x16_F16F16F16_SS, etc +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static namespace cute { @@ -60,7 +65,7 @@ warpgroup_fence_operand(Tensor& frg) { } } -namespace GMMA { +namespace SM90::GMMA { /////////////////////////////////////////// // Common layouts for GMMA Shared Memory // @@ -99,20 +104,20 @@ template using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); // With GMMA::Major param -template -using Layout_INTER_Atom = typename conditional +using Layout_INTER_Atom = typename conditional, Layout_K_INTER_Atom>::type; -template -using Layout_SW32_Atom = typename conditional +using Layout_SW32_Atom = typename conditional, Layout_K_SW32_Atom>::type; -template -using Layout_SW64_Atom = typename conditional +using Layout_SW64_Atom = typename conditional, Layout_K_SW64_Atom>::type; -template -using Layout_SW128_Atom = typename conditional +using Layout_SW128_Atom = typename conditional, Layout_K_SW128_Atom>::type; @@ -188,7 +193,7 @@ layout_type(Tensor> const&) * auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); * is guaranteed to be accepted by make_gmma_desc for appropriate value_type. */ -template +template CUTE_HOST_DEVICE constexpr GmmaDescriptor make_gmma_desc(Tensor const& tensor) @@ -203,7 +208,7 @@ make_gmma_desc(Tensor const& tensor) GmmaDescriptor desc; // Layout type - constexpr GMMA::LayoutType LAYOUT_TYPE = GMMA::layout_type(u128_tensor); + constexpr LayoutType LAYOUT_TYPE = layout_type(u128_tensor); desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); // Start address (4LSB not included) @@ -214,12 +219,12 @@ make_gmma_desc(Tensor const& tensor) desc.bitfield.base_offset_ = base_offset; // LayoutType meta - constexpr int W = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 1 : - LAYOUT_TYPE == GMMA::LayoutType::B32 ? 2 : - LAYOUT_TYPE == GMMA::LayoutType::B64 ? 4 : - LAYOUT_TYPE == GMMA::LayoutType::B128 ? 8 : -1; + constexpr int W = LAYOUT_TYPE == LayoutType::INTERLEAVE ? 1 : + LAYOUT_TYPE == LayoutType::B32 ? 2 : + LAYOUT_TYPE == LayoutType::B64 ? 4 : + LAYOUT_TYPE == LayoutType::B128 ? 8 : -1; - if constexpr (MajorMode == GMMA::Major::MN) + if constexpr (MajorMode == Major::MN) { /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form * @@ -228,8 +233,10 @@ make_gmma_desc(Tensor const& tensor) * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) */ - static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{}, // K size - "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits."); + static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{} || // A and B in dense MMA + size<1>(u128_tensor) == Int<(128 / cute::sizeof_bits::value)>{} || // A in sparse MMA + size<1>(u128_tensor) == Int<(512 / cute::sizeof_bits::value)>{}, // B in sparse MMA + "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits for dense or (128|512)/sizeof_bits for sparse."); // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); @@ -239,7 +246,7 @@ make_gmma_desc(Tensor const& tensor) CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); // Check canonical mode strides constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); - constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); constexpr uint32_t expected_stride_10 = W; @@ -249,10 +256,10 @@ make_gmma_desc(Tensor const& tensor) constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); - desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; - desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_11 : stride_01; + desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_11 : stride_01; } - else if constexpr (MajorMode == GMMA::Major::K) + else if constexpr (MajorMode == Major::K) { /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form * @@ -263,8 +270,8 @@ make_gmma_desc(Tensor const& tensor) */ CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); - CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{}, // K size - "Not a canonical GMMA_K Layout: Expected K-size 2 (in units of uint128_t)."); + CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{} || size<1>(u128_tensor) == Int<4>{}, // K size + "Not a canonical GMMA_K Layout: Expected K-size 2 for dense or 4 for sparse (in units of uint128_t)."); // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); @@ -277,7 +284,7 @@ make_gmma_desc(Tensor const& tensor) constexpr uint32_t expected_stride_00 = W; static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); - constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) @@ -286,7 +293,7 @@ make_gmma_desc(Tensor const& tensor) desc.bitfield.stride_byte_offset_ = stride_01; desc.bitfield.leading_byte_offset_ = stride_10; } else { - static_assert(MajorMode != GMMA::Major::MN && MajorMode != GMMA::Major::K, "Unrecognized MajorMode!"); + static_assert(MajorMode != Major::MN && MajorMode != Major::K, "Unrecognized MajorMode!"); } #if 0 @@ -357,21 +364,21 @@ print(DescriptorIterator) { // The GMMA Traits below have custom fragment type flags for their smem desc tensors. // These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. -template +template struct smem_desc : DescriptorIterator {}; -} // end namespace GMMA +} // end namespace SM90::GMMA // Customization point for creating a GMMA::smem_desc Tensor -template -struct MakeTensor> +template +struct MakeTensor> { template CUTE_HOST_DEVICE constexpr auto operator()(Tensor const& smem_tensor) { static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); - return make_tensor(GMMA::DescriptorIterator{GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); } }; @@ -380,7 +387,58 @@ struct MakeTensor> //////////////////////////// MMA_TRAITS /////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -namespace GMMA { +namespace SM90::GMMA { + +// +// Specialized mma_unpack implementation for SM90 GMMA instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + // SM90 GMMA take three arguments rather than four, try to assert C and D are aliased + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} // Accumulator layouts using CLayout_64x8 = Layout,Shape < _2,_2>>, @@ -392,7 +450,7 @@ using CLayout_64x16 = Layout,Shape < _2,_2, _2>>, using CLayout_64x32 = Layout,Shape < _2,_2, _4>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x48 = Layout,Shape < _2,_2, _6>>, +using CLayout_64x48 = Layout,Shape < _2,_2, _6>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x64 = Layout,Shape < _2,_2, _8>>, @@ -404,31 +462,31 @@ using CLayout_64x80 = Layout,Shape < _2,_2, _10>>, using CLayout_64x96 = Layout,Shape < _2,_2, _12>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x112 = Layout,Shape < _2,_2, Int<14>>>, +using CLayout_64x112 = Layout,Shape < _2,_2, Int<14>>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x128 = Layout,Shape < _2,_2, _16>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x144 = Layout,Shape < _2,_2, Int<18>>>, +using CLayout_64x144 = Layout,Shape < _2,_2, Int<18>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x160 = Layout,Shape < _2,_2, Int<20>>>, +using CLayout_64x160 = Layout,Shape < _2,_2, Int<20>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x176 = Layout,Shape < _2,_2, Int<22>>>, +using CLayout_64x176 = Layout,Shape < _2,_2, Int<22>>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x208 = Layout,Shape < _2,_2, Int<26>>>, +using CLayout_64x208 = Layout,Shape < _2,_2, Int<26>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x224 = Layout,Shape < _2,_2, Int<28>>>, +using CLayout_64x224 = Layout,Shape < _2,_2, Int<28>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x240 = Layout,Shape < _2,_2, Int<30>>>, +using CLayout_64x240 = Layout,Shape < _2,_2, Int<30>>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, @@ -438,19 +496,33 @@ using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, using ALayout_64x8 = Layout,Shape < _2, _2>>, Stride,Stride< _8,_256>>>; -// Register source layout for 16-bit value types -using ALayout_64x16 = CLayout_64x16; +// Register source layout for 16-bit (sparse 32-bit) value types +using ALayout_64x16 = CLayout_64x16; -// Register source layout for 8-bit value types -using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, - Stride,Stride<_64,_8,_1024>>>; +// Register source layout for 8-bit (sparse 16-bit) value types +using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_64,_8,_1024>>>; + +// Register source layout for sparse 8-bit value types +using ALayout_64x64 = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_64,_8,_2048>>>; // Shared memory source layouts for any value type template using ABLayout = Layout,Int>>, Stride< _0,Stride< _1,Int>>>; -} // namespace GMMA +} // end namespace SM90::GMMA + +using namespace SM90; + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_SS = SM90::GMMA::MMA_64x8x16_F16F16F16_SS; template struct MMA_Traits> @@ -474,6 +546,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_RS = SM90::GMMA::MMA_64x8x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -495,6 +576,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_SS = SM90::GMMA::MMA_64x16x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -517,6 +607,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_RS = SM90::GMMA::MMA_64x16x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -538,6 +637,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_SS = SM90::GMMA::MMA_64x32x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -560,6 +668,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_RS = SM90::GMMA::MMA_64x32x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -582,6 +699,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_SS = SM90::GMMA::MMA_64x48x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -606,6 +732,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_RS = SM90::GMMA::MMA_64x48x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -628,6 +763,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_SS = SM90::GMMA::MMA_64x64x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -650,6 +794,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_RS = SM90::GMMA::MMA_64x64x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -672,6 +825,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_SS = SM90::GMMA::MMA_64x80x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -696,6 +858,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_RS = SM90::GMMA::MMA_64x80x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -718,6 +889,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_SS = SM90::GMMA::MMA_64x96x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -740,6 +920,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_RS = SM90::GMMA::MMA_64x96x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -762,6 +951,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_SS = SM90::GMMA::MMA_64x112x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -786,6 +984,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_RS = SM90::GMMA::MMA_64x112x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -808,6 +1015,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_SS = SM90::GMMA::MMA_64x128x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -830,6 +1046,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_RS = SM90::GMMA::MMA_64x128x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -852,6 +1077,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_SS = SM90::GMMA::MMA_64x144x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -876,6 +1110,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_RS = SM90::GMMA::MMA_64x144x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -899,6 +1142,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_SS = SM90::GMMA::MMA_64x160x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -923,6 +1175,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_RS = SM90::GMMA::MMA_64x160x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -946,6 +1207,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_SS = SM90::GMMA::MMA_64x176x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -970,6 +1240,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_RS = SM90::GMMA::MMA_64x176x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -992,6 +1271,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_SS = SM90::GMMA::MMA_64x192x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1014,6 +1302,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_RS = SM90::GMMA::MMA_64x192x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1036,6 +1333,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_SS = SM90::GMMA::MMA_64x208x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1060,6 +1366,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_RS = SM90::GMMA::MMA_64x208x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1083,6 +1398,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_SS = SM90::GMMA::MMA_64x224x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1107,6 +1431,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_RS = SM90::GMMA::MMA_64x224x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1130,6 +1463,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_SS = SM90::GMMA::MMA_64x240x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1154,6 +1496,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_RS = SM90::GMMA::MMA_64x240x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1176,6 +1527,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_SS = SM90::GMMA::MMA_64x256x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1198,6 +1558,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_RS = SM90::GMMA::MMA_64x256x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1219,6 +1588,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_SS = SM90::GMMA::MMA_64x8x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1241,6 +1619,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_RS = SM90::GMMA::MMA_64x8x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1262,6 +1649,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_SS = SM90::GMMA::MMA_64x16x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1284,6 +1680,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_RS = SM90::GMMA::MMA_64x16x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1305,6 +1710,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_SS = SM90::GMMA::MMA_64x32x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1327,6 +1741,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_RS = SM90::GMMA::MMA_64x32x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1349,6 +1772,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_SS = SM90::GMMA::MMA_64x48x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1373,6 +1805,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_RS = SM90::GMMA::MMA_64x48x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1395,6 +1836,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_SS = SM90::GMMA::MMA_64x64x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1417,6 +1867,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_RS = SM90::GMMA::MMA_64x64x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1439,6 +1898,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_SS = SM90::GMMA::MMA_64x80x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1463,6 +1931,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_RS = SM90::GMMA::MMA_64x80x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1485,6 +1962,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_SS = SM90::GMMA::MMA_64x96x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1507,6 +1993,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_RS = SM90::GMMA::MMA_64x96x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1529,6 +2024,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_SS = SM90::GMMA::MMA_64x112x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1553,6 +2057,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_RS = SM90::GMMA::MMA_64x112x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1575,6 +2088,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_SS = SM90::GMMA::MMA_64x128x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1597,6 +2119,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_RS = SM90::GMMA::MMA_64x128x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1619,6 +2150,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_SS = SM90::GMMA::MMA_64x144x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1643,6 +2183,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_RS = SM90::GMMA::MMA_64x144x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1666,6 +2215,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_SS = SM90::GMMA::MMA_64x160x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1690,6 +2248,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_RS = SM90::GMMA::MMA_64x160x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1713,6 +2280,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_SS = SM90::GMMA::MMA_64x176x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1737,6 +2313,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_RS = SM90::GMMA::MMA_64x176x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1759,6 +2344,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_SS = SM90::GMMA::MMA_64x192x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1781,6 +2375,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_RS = SM90::GMMA::MMA_64x192x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1803,6 +2406,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_SS = SM90::GMMA::MMA_64x208x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1827,6 +2439,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_RS = SM90::GMMA::MMA_64x208x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1850,6 +2471,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_SS = SM90::GMMA::MMA_64x224x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1874,6 +2504,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_RS = SM90::GMMA::MMA_64x224x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1897,6 +2536,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_SS = SM90::GMMA::MMA_64x240x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1921,6 +2569,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_RS = SM90::GMMA::MMA_64x240x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1943,6 +2600,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_SS = SM90::GMMA::MMA_64x256x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1965,6 +2631,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_RS = SM90::GMMA::MMA_64x256x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1986,6 +2661,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2008,6 +2692,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2029,6 +2722,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2051,6 +2753,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2072,6 +2783,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2094,6 +2814,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2116,6 +2845,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2140,6 +2878,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2162,6 +2909,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2184,6 +2940,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2206,6 +2971,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2230,6 +3004,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2252,6 +3035,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2274,6 +3066,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2296,6 +3097,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2320,6 +3130,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2342,6 +3161,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2364,6 +3192,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2386,6 +3223,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2410,6 +3256,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2433,6 +3288,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2457,6 +3321,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2480,6 +3353,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2504,6 +3386,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2526,6 +3417,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2548,6 +3448,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2570,6 +3479,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2594,6 +3512,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2617,6 +3544,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2641,6 +3577,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2664,6 +3609,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2688,6 +3642,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2710,6 +3673,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2732,6 +3704,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2753,6 +3734,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2775,6 +3763,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2796,6 +3791,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2818,6 +3820,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2839,6 +3848,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2861,6 +3877,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2883,6 +3906,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2907,6 +3937,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2929,6 +3966,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2951,6 +3995,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2973,6 +4024,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2997,6 +4055,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3019,6 +4084,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3041,6 +4113,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3063,6 +4142,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3087,6 +4173,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3109,6 +4202,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3131,6 +4231,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3153,6 +4260,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3177,6 +4291,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3200,6 +4321,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3224,7 +4352,14 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN; + +template struct MMA_Traits> { using ValTypeD = float; @@ -3247,6 +4382,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3271,6 +4413,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3293,6 +4442,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3315,6 +4471,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3337,6 +4500,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3361,6 +4531,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3384,6 +4561,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3408,6 +4592,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3431,6 +4622,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3455,6 +4653,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3477,6 +4682,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3499,6 +4711,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3520,6 +4739,10 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3542,6 +4765,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3564,6 +4791,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3586,6 +4817,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3608,6 +4843,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3630,6 +4869,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3653,6 +4896,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3677,6 +4924,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3700,6 +4951,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3722,6 +4977,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3745,6 +5004,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3769,6 +5032,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3792,6 +5059,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3814,6 +5085,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3837,6 +5112,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3861,6 +5140,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3884,6 +5167,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3906,6 +5193,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3929,6 +5220,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3953,6 +5248,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3977,6 +5276,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4001,6 +5304,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4025,6 +5332,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4049,6 +5360,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4072,6 +5387,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4094,6 +5413,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4117,6 +5440,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4141,6 +5468,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4165,6 +5496,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4189,6 +5524,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4213,6 +5552,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4237,6 +5580,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4260,6 +5607,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4282,6 +5633,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4304,6 +5659,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4325,6 +5684,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4346,6 +5709,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4367,6 +5734,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4388,6 +5759,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4409,6 +5784,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4431,6 +5810,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4454,6 +5837,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4476,6 +5863,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4497,6 +5888,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4519,6 +5914,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4542,6 +5941,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4564,6 +5967,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4585,6 +5992,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4607,6 +6018,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4630,6 +6045,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4652,6 +6071,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4673,6 +6096,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4695,6 +6122,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4718,6 +6149,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4741,6 +6176,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4764,6 +6203,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4787,6 +6230,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4810,6 +6257,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4832,6 +6283,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4853,6 +6308,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4875,6 +6334,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4898,6 +6361,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4921,6 +6388,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4944,6 +6415,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4967,6 +6442,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4990,6 +6469,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5012,6 +6495,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -5033,6 +6520,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5054,6 +6545,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5076,6 +6571,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5098,6 +6597,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5120,6 +6623,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5142,6 +6649,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5164,6 +6675,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5187,6 +6702,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5211,6 +6730,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5234,7 +6757,11 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> + + +using SM90_64x64x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN; + +template <> struct MMA_Traits { using ValTypeD = int32_t; @@ -5256,6 +6783,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5279,6 +6810,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5303,6 +6838,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5326,6 +6865,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5348,6 +6891,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5371,6 +6918,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5395,6 +6946,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5418,6 +6973,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5440,6 +6999,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5463,6 +7026,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5487,6 +7054,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5511,6 +7082,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5535,6 +7110,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5559,6 +7138,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5583,6 +7166,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5606,6 +7193,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5628,6 +7219,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5651,6 +7246,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5675,6 +7274,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5699,6 +7302,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5723,6 +7330,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5747,6 +7358,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5771,6 +7386,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5794,6 +7413,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5816,6 +7439,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5838,6 +7465,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5859,6 +7490,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5880,6 +7515,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5901,6 +7540,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5922,6 +7565,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5943,6 +7590,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5965,6 +7616,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5988,6 +7643,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6010,6 +7669,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6031,6 +7694,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6053,6 +7720,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6076,6 +7747,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6098,6 +7773,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6119,6 +7798,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6141,6 +7824,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6164,6 +7851,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6186,6 +7877,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6207,6 +7902,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6229,6 +7928,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6252,6 +7955,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6275,6 +7982,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6298,6 +8009,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6321,6 +8036,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6344,6 +8063,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6366,6 +8089,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6387,6 +8114,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6409,6 +8140,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6432,6 +8167,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6455,6 +8194,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6478,6 +8221,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6501,6 +8248,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6524,6 +8275,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6546,6 +8301,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6567,6 +8326,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6588,6 +8351,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6610,6 +8377,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6632,6 +8403,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6654,6 +8429,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6676,6 +8455,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6698,6 +8481,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6721,6 +8508,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6745,6 +8536,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6768,6 +8563,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6790,6 +8589,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6813,6 +8616,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6837,6 +8644,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6860,6 +8671,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6882,6 +8697,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6905,6 +8724,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6929,6 +8752,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6952,6 +8779,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6974,6 +8805,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6997,6 +8832,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7021,6 +8860,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7045,6 +8888,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7069,6 +8916,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7093,6 +8944,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7117,6 +8972,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7140,6 +8999,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7162,6 +9025,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7185,6 +9052,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7209,6 +9080,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7233,6 +9108,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7257,6 +9136,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7281,6 +9164,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7305,6 +9192,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7328,6 +9219,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7350,6 +9245,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7372,6 +9271,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7393,6 +9296,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7414,6 +9321,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7435,6 +9346,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7456,6 +9371,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7477,6 +9396,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7499,6 +9422,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7522,6 +9449,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7544,6 +9475,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7565,6 +9500,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7587,6 +9526,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7610,6 +9553,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7632,6 +9579,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7653,6 +9604,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7675,6 +9630,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7698,6 +9657,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7720,6 +9683,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7741,6 +9708,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7763,6 +9734,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7786,6 +9761,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7809,6 +9788,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7832,6 +9815,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7855,6 +9842,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7878,6 +9869,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7900,6 +9895,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7921,6 +9920,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7943,6 +9946,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7966,6 +9973,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7989,6 +10000,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -8012,6 +10027,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8035,6 +10054,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -8058,6 +10081,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8080,6 +10107,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -8101,6 +10132,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8122,6 +10157,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8144,6 +10183,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8166,6 +10209,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8188,6 +10235,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8210,6 +10261,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8232,6 +10287,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8255,6 +10314,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8279,6 +10342,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8302,6 +10369,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8324,6 +10395,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8347,6 +10422,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8371,6 +10450,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8394,6 +10477,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8416,6 +10503,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8439,6 +10530,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8463,6 +10558,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8486,6 +10585,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8508,6 +10611,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8531,6 +10638,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8555,6 +10666,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8579,6 +10694,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8603,6 +10722,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8627,6 +10750,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8651,6 +10778,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8674,6 +10805,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8696,6 +10831,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8719,6 +10858,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8743,6 +10886,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8767,6 +10914,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8791,6 +10942,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8815,6 +10970,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8839,6 +10998,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8862,6 +11025,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8884,6 +11051,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8906,6 +11077,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -8927,6 +11102,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8948,6 +11127,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -8969,6 +11152,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8990,6 +11177,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9011,6 +11202,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9033,6 +11228,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9056,6 +11255,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9078,6 +11281,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9099,6 +11306,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9121,6 +11332,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9144,6 +11359,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9166,6 +11385,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9187,6 +11410,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9209,6 +11436,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9232,6 +11463,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9254,6 +11489,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9275,6 +11514,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9297,6 +11540,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9320,6 +11567,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9343,6 +11594,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9366,6 +11621,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9389,6 +11648,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9412,6 +11675,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9434,6 +11701,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9455,6 +11726,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9477,6 +11752,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9500,6 +11779,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9523,6 +11806,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9546,6 +11833,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9569,6 +11860,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9592,6 +11887,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9614,6 +11913,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9635,6 +11938,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9656,6 +11963,13 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9678,6 +11992,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9699,6 +12020,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9721,6 +12049,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9742,6 +12077,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9764,6 +12106,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9785,6 +12134,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9807,6 +12163,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9828,6 +12191,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9850,6 +12220,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9871,6 +12248,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9893,6 +12277,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9915,6 +12306,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9939,6 +12337,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9962,6 +12367,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9986,6 +12398,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10008,6 +12427,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10028,7 +12454,14 @@ struct MMA_Traits> GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN; template struct MMA_Traits> @@ -10051,6 +12484,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10073,6 +12513,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10095,6 +12542,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10119,6 +12573,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10142,6 +12603,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10166,6 +12634,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10188,6 +12663,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10210,6 +12692,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10231,6 +12720,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10253,6 +12749,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10275,6 +12778,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10299,6 +12809,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10322,6 +12839,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10346,6 +12870,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10368,6 +12899,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10390,6 +12928,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10411,6 +12956,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10433,6 +12985,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10455,6 +13014,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10479,6 +13045,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10502,6 +13075,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10526,6 +13106,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10549,6 +13136,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10573,6 +13167,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10596,6 +13197,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10620,6 +13228,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10643,6 +13258,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10667,6 +13289,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10690,6 +13319,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10714,6 +13350,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10736,6 +13379,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10758,6 +13408,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10779,6 +13436,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10801,6 +13465,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10823,6 +13494,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10847,6 +13525,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10870,6 +13555,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10894,6 +13586,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10917,6 +13616,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10941,6 +13647,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10964,6 +13677,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10988,6 +13708,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11011,6 +13738,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11035,6 +13769,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11058,6 +13799,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11082,6 +13830,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11104,6 +13859,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11126,6 +13888,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11147,6 +13916,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11169,6 +13945,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11190,6 +13973,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11212,6 +14002,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11233,6 +14030,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11255,6 +14059,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11276,6 +14087,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11298,6 +14116,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11319,6 +14144,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11341,6 +14173,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11362,6 +14201,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11384,6 +14230,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11405,6 +14258,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11427,6 +14287,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11449,6 +14316,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11473,6 +14347,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11496,6 +14377,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11520,6 +14408,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11542,6 +14437,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11564,6 +14466,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11585,6 +14494,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11607,6 +14523,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11629,6 +14552,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11653,6 +14583,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11676,6 +14613,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11700,6 +14644,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11722,6 +14673,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11744,6 +14702,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11765,6 +14730,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11787,6 +14759,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11809,6 +14788,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11833,6 +14819,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11856,6 +14849,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11880,6 +14880,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11902,6 +14909,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11924,6 +14938,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11945,6 +14966,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11967,6 +14995,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11989,6 +15024,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12013,6 +15055,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12036,6 +15085,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12060,6 +15116,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12083,6 +15146,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12107,6 +15177,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12130,6 +15207,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12154,6 +15238,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12177,6 +15268,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12201,6 +15299,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12224,6 +15329,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12248,6 +15360,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12270,6 +15389,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12292,6 +15418,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12313,6 +15446,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12335,6 +15475,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12357,6 +15504,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12381,6 +15535,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12404,6 +15565,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12428,6 +15596,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12451,6 +15626,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12475,6 +15657,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12498,6 +15687,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12522,6 +15718,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12545,6 +15748,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12569,6 +15779,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12592,6 +15809,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12616,6 +15840,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12638,6 +15869,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12660,6 +15898,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12681,6 +15926,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12703,6 +15955,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12724,6 +15983,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12746,6 +16012,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12767,6 +16040,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12789,6 +16069,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12810,6 +16097,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12832,6 +16126,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12853,6 +16154,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12875,6 +16183,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12894,7 +16209,14 @@ struct MMA_Traits> GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN; template struct MMA_Traits> @@ -12918,6 +16240,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12939,6 +16268,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12961,6 +16297,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12983,6 +16326,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13007,6 +16357,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13030,6 +16387,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13054,6 +16418,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13076,6 +16447,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13098,6 +16476,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13119,6 +16504,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13141,6 +16533,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13163,6 +16562,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13187,6 +16593,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13210,6 +16623,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13234,6 +16654,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13256,6 +16683,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13278,6 +16712,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13299,6 +16740,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13321,6 +16769,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13343,6 +16798,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13367,6 +16829,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13390,6 +16859,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13414,6 +16890,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13436,6 +16919,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13458,6 +16948,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13479,6 +16976,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13501,6 +17005,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13523,6 +17034,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13547,6 +17065,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13570,6 +17095,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13594,6 +17126,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13617,6 +17156,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13641,6 +17187,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13664,6 +17217,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13688,6 +17248,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13711,6 +17278,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13735,6 +17309,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13758,6 +17339,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13782,6 +17370,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13804,6 +17399,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13826,6 +17428,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13847,6 +17456,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13869,6 +17485,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13891,6 +17514,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13915,6 +17545,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13938,6 +17575,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13962,6 +17606,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13985,6 +17636,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14009,6 +17667,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14032,6 +17697,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14056,6 +17728,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14079,6 +17758,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14103,6 +17789,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14126,6 +17819,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14150,6 +17850,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14172,6 +17879,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14194,6 +17908,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14215,6 +17936,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14237,6 +17965,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14258,6 +17993,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14280,6 +18022,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14301,6 +18050,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14323,6 +18079,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14344,6 +18107,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14366,6 +18136,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14387,6 +18164,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14409,6 +18193,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14430,6 +18221,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14452,6 +18250,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14473,6 +18278,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14495,6 +18307,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14517,6 +18336,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14541,6 +18367,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14564,6 +18397,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14588,6 +18428,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14610,6 +18457,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14632,6 +18486,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14653,6 +18514,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14675,6 +18543,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14697,6 +18572,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14721,6 +18603,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14744,6 +18633,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14768,6 +18664,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14790,6 +18693,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14812,6 +18722,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14833,6 +18750,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14855,6 +18779,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14877,6 +18808,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14901,6 +18839,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14924,6 +18869,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14948,6 +18900,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14970,6 +18929,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14992,6 +18958,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15013,6 +18986,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15035,6 +19015,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15057,6 +19044,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15081,6 +19075,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15104,6 +19105,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15128,6 +19136,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15151,6 +19166,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15175,6 +19197,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15198,6 +19227,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15222,6 +19258,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15245,6 +19288,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15269,6 +19319,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15292,6 +19349,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15316,6 +19380,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15338,6 +19409,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15360,6 +19438,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15381,6 +19466,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15403,6 +19495,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15425,6 +19524,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15449,6 +19555,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15472,6 +19585,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15496,6 +19616,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15519,6 +19646,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15543,6 +19677,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15566,6 +19707,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15590,6 +19738,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15613,6 +19768,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15637,6 +19799,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15660,6 +19829,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15684,6 +19860,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15706,6 +19889,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15728,6 +19918,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15749,6 +19946,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15771,6 +19975,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15790,6 +20001,7 @@ struct MMA_Traits> GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000..7252a0ef58 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -0,0 +1,16915 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90::SPARSE::GMMA_64x8x32_F16F16F16_SS, etc +#include // cute::GMMA::Layout_* +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static + +namespace cute { + +namespace SM90::GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major layouts in units of Type and sparsity factor S +template +using Layout_MN_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_INTER_Atom{}.layout_b()))>; +template +using Layout_MN_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW32_Atom{}.layout_b()))>; +template +using Layout_MN_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW64_Atom{}.layout_b()))>; +template +using Layout_MN_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW128_Atom{}.layout_b()))>; + +// K-major layouts in units of Type and sparsity factor S +template +using Layout_K_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_INTER_Atom{}.layout_b()))>; +template +using Layout_K_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW32_Atom{}.layout_b()))>; +template +using Layout_K_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW64_Atom{}.layout_b()))>; +template +using Layout_K_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW128_Atom{}.layout_b()))>; + +// With GMMA::Major param +template +using Layout_INTER_SpAtom = typename conditional, + Layout_K_INTER_SpAtom>::type; +template +using Layout_SW32_SpAtom = typename conditional, + Layout_K_SW32_SpAtom>::type; +template +using Layout_SW64_SpAtom = typename conditional, + Layout_K_SW64_SpAtom>::type; +template +using Layout_SW128_SpAtom = typename conditional, + Layout_K_SW128_SpAtom>::type; + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +template +struct sparse_smem_desc : DescriptorIterator {}; + +} // end namespace SM90::GMMA + +// Customization point for creating a cute::GMMAsparse_smem_desc Tensor +template +struct MakeTensor> +{ + // Note that this is the exact same as cute::GMMAsmem_desc above, plus additional static checks. + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + static_assert(is_sparse::value, "Expected sparse value_type."); + static_assert(is_sparse_ptr::value, "Expected sparse iter."); + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA { + +// Metadata layouts +using ELayout_64x64 = Layout, Shape <_32>>, + Stride, Stride<_64>>>; + +using ELayout_64x32 = Layout, Shape <_16,_2>>, + Stride, Stride<_64,_8>>>; + +using ELayout_64x16 = Layout, Shape < _8,_2>>, + Stride, Stride<_64,_8>>>; + +} // namespace SM90::GMMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A_zipped, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + auto [A, E] = unzip_tensor(A_zipped); + Tensor rA = recast(A); + Tensor rE = recast(E); + Tensor rB = recast(B); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + static_assert(is_same::value, "GMMA DRegisters must have void type."); + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMAOp::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::SPARSE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 35d4f8fdf0..b5cfcf47d3 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -142,21 +142,8 @@ # include #endif -// -// Support -// - -#include - -// -// Basic types -// - -#include - // // Debugging utilities // -#include #include diff --git a/include/cute/container/alignment.hpp b/include/cute/container/alignment.hpp index 4cf60d899f..52e4cbadd9 100644 --- a/include/cute/container/alignment.hpp +++ b/include/cute/container/alignment.hpp @@ -54,17 +54,17 @@ is_byte_aligned(void const* const ptr) # define CUTE_ALIGNAS(n) alignas(n) #endif -template +template struct aligned_struct {}; -template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {}; -template <> struct CUTE_ALIGNAS( 2) aligned_struct< 2> {}; -template <> struct CUTE_ALIGNAS( 4) aligned_struct< 4> {}; -template <> struct CUTE_ALIGNAS( 8) aligned_struct< 8> {}; -template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {}; -template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {}; -template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {}; -template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {}; -template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {}; +template struct CUTE_ALIGNAS( 1) aligned_struct< 1, Child> {}; +template struct CUTE_ALIGNAS( 2) aligned_struct< 2, Child> {}; +template struct CUTE_ALIGNAS( 4) aligned_struct< 4, Child> {}; +template struct CUTE_ALIGNAS( 8) aligned_struct< 8, Child> {}; +template struct CUTE_ALIGNAS( 16) aligned_struct< 16, Child> {}; +template struct CUTE_ALIGNAS( 32) aligned_struct< 32, Child> {}; +template struct CUTE_ALIGNAS( 64) aligned_struct< 64, Child> {}; +template struct CUTE_ALIGNAS(128) aligned_struct<128, Child> {}; +template struct CUTE_ALIGNAS(256) aligned_struct<256, Child> {}; } // end namespace cute diff --git a/include/cute/container/array_aligned.hpp b/include/cute/container/array_aligned.hpp index 9895a8da77..a9d14a1a25 100644 --- a/include/cute/container/array_aligned.hpp +++ b/include/cute/container/array_aligned.hpp @@ -30,8 +30,8 @@ **************************************************************************************************/ #pragma once -#include -#include +#include // CUTE_ALIGNAS +#include // cute::array namespace cute { diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index 1963d8ce7b..6aa26bc9f0 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -181,6 +181,20 @@ struct subbyte_reference } }; +template +CUTE_HOST_DEVICE +void +print(subbyte_reference ref) { + cute::print(ref.get()); +} + +template +CUTE_HOST_DEVICE +void +pretty_print(subbyte_reference ref) { + cute::pretty_print(ref.get()); +} + // // subbyte_iterator // Random-access iterator over subbyte references diff --git a/include/cute/container/bit_field.hpp b/include/cute/container/bit_field.hpp index c5748d84c3..d7fac42a54 100644 --- a/include/cute/container/bit_field.hpp +++ b/include/cute/container/bit_field.hpp @@ -35,9 +35,9 @@ #pragma once -#include - +#include // CUTE_HOST_DEVICE #include // uint_bit_t +#include // cute::is_same namespace cute { diff --git a/include/cute/container/cuda_types.hpp b/include/cute/container/cuda_types.hpp index 8034cb271d..fbc314e543 100644 --- a/include/cute/container/cuda_types.hpp +++ b/include/cute/container/cuda_types.hpp @@ -30,12 +30,8 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include -#include +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::integral_constant namespace cute { diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index 54d282419e..3123a68d83 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -634,14 +634,23 @@ template CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence, char s = '(', char e = ')') { using cute::print; - print(s); ((void(print(Is == 0 ? '\0' : ',')), void(print(get(t)))), ...); print(e); + if (sizeof...(Is) == 0) { + print(s); + } else { + ((void(print(Is == 0 ? s : ',')), void(print(get(t)))), ...); + } + print(e); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence, char s = '(', char e = ')') { - os << s; (void(os << (Is == 0 ? '\0' : ',') << get(t)), ...); + if (sizeof...(Is) == 0) { + os << s; + } else { + (void(os << (Is == 0 ? s : ',') << get(t)), ...); + } return os << e; } #endif // !defined(__CUDACC_RTC__) diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index 2db934356b..a15f2c1c15 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -30,8 +30,7 @@ **************************************************************************************************/ #pragma once -#include -#include +#include // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE namespace cute { diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index ceafba0d80..95d06bbdd7 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -30,12 +30,11 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::array +#include // cute::is_tuple +#include // cute::Int +#include // cute::transform /** IntTuple is an integer or a tuple of IntTuples. * This file holds utilities for working with IntTuples, @@ -92,7 +91,7 @@ template using rank_t = decltype(rank(declval())); template -static constexpr int rank_v = rank_t::value; +static constexpr auto rank_v = rank_t::value; // // shape @@ -212,7 +211,7 @@ template using depth_t = decltype(depth(declval())); template -static constexpr int depth_v = depth_t::value; +static constexpr auto depth_v = depth_t::value; // // product @@ -276,7 +275,7 @@ size(IntTuple const& a) } template -static constexpr int size_v = decltype(size(declval()))::value; +static constexpr auto size_v = decltype(size(declval()))::value; // // sum @@ -522,68 +521,31 @@ compatible(IntTupleA const& a, IntTupleB const& b) template using is_compatible = decltype(compatible(declval(), declval())); -/** Test if Shape A is weakly compatible with Shape B: - * there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B) - * Equivalently, the size of Shape B is a multiple of Shape A at each terminal of Shape A. - * weakly_compatible is a partial order on A and B: A <= B +/** Test if Shape A is evenly divided by Tiler B + * @returns Static or dynamic boolean + * @post if result is true_type, then + * size(a) == logical_divide(make_layout(shape(a)),b) will always compile + * and result in true_type. */ -template +template CUTE_HOST_DEVICE constexpr auto -weakly_compatible(IntTupleA const& a, IntTupleB const& b) +evenly_divides(Shape const& a, Tiler const& b) { - if constexpr (is_tuple::value && is_tuple::value) { - if constexpr (tuple_size::value != tuple_size::value) { + if constexpr (is_tuple::value) { + if constexpr (rank_v > rank_v) { return false_type{}; } else { - return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_compatible(x,y); }, + return transform_apply(b, a, [](auto const& x, auto const& y) { return evenly_divides(y,x); }, [](auto const&... z) { return (true_type{} && ... && z); }); } - } else if constexpr (is_integral::value) { - return size(b) % a == Int<0>{}; - } else if constexpr (is_integral::value) { - return false_type{}; } else { - return weakly_compatible(shape(a), shape(b)); + return size(a) == size(b) * size(ceil_div(shape(a), b)); } CUTE_GCC_UNREACHABLE; } -template -using is_weakly_compatible = decltype(weakly_compatible(declval(), declval())); - -/** Test if Shape A is softly compatible with Shape B: - * there exists a Shape C congruent to A such that compatible(shape_div(A,C), B) - * Equivalently, the size of Shape B divides Shape A at each terminal of Shape A. - * softly_compatible is a partial order on A and B: A <= B - */ -template -CUTE_HOST_DEVICE constexpr -auto -softly_compatible(IntTupleA const& a, IntTupleB const& b) -{ - if constexpr (is_tuple::value && is_tuple::value) { - if constexpr (tuple_size::value != tuple_size::value) { - return false_type{}; - } else { - return transform_apply(a, b, [](auto const& x, auto const& y) { return softly_compatible(x,y); }, - [](auto const&... z) { return (true_type{} && ... && z); }); - } - } else if constexpr (is_integral::value) { - return a % size(b) == Int<0>{}; - } else if constexpr (is_integral::value) { - return false_type{}; - } else { - return softly_compatible(shape(a), shape(b)); - } - - CUTE_GCC_UNREACHABLE; -} - -template -using is_softly_compatible = decltype(softly_compatible(declval(), declval())); - /** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> */ template @@ -594,7 +556,7 @@ filter_zeros(IntTupleA const& a, IntTupleB const& b) if constexpr (is_tuple::value) { return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); } else if constexpr (is_constant<0, IntTupleA>::value) { - return Int<1>{}; + return repeat_like(b, Int<1>{}); } else { return b; } @@ -899,92 +861,4 @@ elem_geq(T const& t, U const& u) { return !elem_less(t, u); } -namespace detail { - -/** Increment a (dynamic) coord lexicographically within a shape - * @pre is_congruent::value - * \code - * auto shape = make_shape(1,2,make_shape(2,3),3); - * - * int i = 0; - * for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) { - * std::cout << i++ << ": " << coord << std::endl; - * } - * assert(i == size(shape)); - * \endcode - */ -template -CUTE_HOST_DEVICE constexpr -void -increment(Coord& coord, Shape const& shape) -{ - if constexpr (is_integral::value) { - ++coord; - } else { - increment(get(coord), get(shape)); - if constexpr (I+1 < tuple_size::value) { - if (back(get(coord)) == back(get(shape))) { - back(get(coord)) = 0; - increment(coord, shape); - } - } - } -} - -} // end namespace detail - -struct ForwardCoordIteratorSentinal -{}; - -// A forward iterator for a starting coordinate in a shape's domain, and a shape. -// The starting coordinate may be zero but need not necessarily be. -template -struct ForwardCoordIterator -{ - static_assert(is_congruent::value); - - CUTE_HOST_DEVICE constexpr - Coord const& operator*() const { return coord; } - - CUTE_HOST_DEVICE constexpr - ForwardCoordIterator& operator++() { detail::increment(coord, shape); return *this; } - - // Sentinel for the end of the implied range - CUTE_HOST_DEVICE constexpr - bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } - CUTE_HOST_DEVICE constexpr - bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); } - CUTE_HOST_DEVICE constexpr - bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); } - // NOTE: These are expensive, avoid use - CUTE_HOST_DEVICE constexpr - bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); } - CUTE_HOST_DEVICE constexpr - bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } - CUTE_HOST_DEVICE constexpr - bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } - - Coord coord; - Shape const& shape; -}; - -// A forward iterator for a coordinate that starts from a provided coordinate -template -CUTE_HOST_DEVICE constexpr -auto -make_coord_iterator(Coord const& coord, Shape const& shape) -{ - return ForwardCoordIterator{coord,shape}; -} - -// A forward iterator for a coordinate that starts from zero -template -CUTE_HOST_DEVICE constexpr -auto -make_coord_iterator(Shape const& shape) -{ - auto coord = repeat_like(shape, int(0)); - return make_coord_iterator(coord, shape); -} - } // end namespace cute diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 60581192b0..bc1b54efbc 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -31,13 +31,13 @@ #pragma once #include - -#include #include #include +#include #include -#include #include +#include +#include // cute::sizeof_bits namespace cute { @@ -660,7 +660,7 @@ template using cosize_t = decltype(cosize(declval())); template -static constexpr int cosize_v = cosize_t::value; +static constexpr auto cosize_v = cosize_t::value; // With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well template @@ -905,6 +905,15 @@ filter_zeros(Layout const& layout) return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); } +// Replace the modes in layout that correspond to a 0 at the terminals of trg_profile with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout, IntTuple const& trg_profile) +{ + return make_layout(filter_zeros(trg_profile, layout.shape()), layout.stride()); +} + // Remove all of the 0-strides and 1-sizes // Return 1-shape if empty template @@ -1350,7 +1359,8 @@ max_common_vector(Layout const& a, /* Return a layout that distributes ShapeB over ShapeA. * * @returns Layout result - * @post softly_compatible(@a b, @a result) + * @post evenly_divides(@a b, size(@a result)) + * @post evenly_divides(@a a, @a result) * @post For all i,j in [0,size(@a result)) with i < j, @a result(i) < @a result(j). Surjective and Ordered. * @post composition(make_layout(shape(@a a)), @a result) is admissible * \code @@ -1726,8 +1736,8 @@ tile_to_shape(Layout const& block, // Assert proper division if constexpr (is_static::value) { - CUTE_STATIC_ASSERT_V(weakly_compatible(block_shape, target_shape), - "tile_to_shape: block shape does not divide the target shape."); + CUTE_STATIC_ASSERT_V(evenly_divides(target_shape, block_shape), + "tile_to_shape: block shape does not divide the target shape."); } auto product_shape = ceil_div(target_shape, block_shape); @@ -1924,92 +1934,97 @@ print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) a printf("+\n"); } -// Generic 2D Layout to Latex printer -- B&W 8-value color coding -template +struct TikzColor_White { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + return "white"; + } +}; + +struct TikzColor_BWx8 { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", + "black!10", "black!50", "black!30", "black!70"}; + return color_map[idx % 8]; + } +}; + +struct TikzColor_TV { + CUTE_HOST_DEVICE char const* + operator()(int tid, int vid) const { + static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + return color_map[tid % 8]; + } +}; + +// Generic 2D Layout to LaTeX printer +template CUTE_HOST_DEVICE void -print_latex(LayoutA const& layout_a) +print_latex(LayoutA const& layout_a, // (m,n) -> idx + TikzColorFn color = {}) // lambda(idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); auto layout = append<2>(layout_a, Layout<_1,_0>{}); - char const* latex_header = - "\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center,font=\\Large}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"black!00", - "black!40", - "black!20", - "black!60", - "black!10", - "black!50", - "black!30", - "black!70"}; - - // Header + // Commented print(layout) printf("%% Layout: "); print(layout); printf("\n"); - - printf(latex_header); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // Layout for (int i = 0; i < size<0>(layout); ++i) { for (int j = 0; j < size<1>(layout); ++j) { int idx = layout(i,j); - printf("\\node[box,fill=%s] at (%d,%d) {%d};\n", - color_map[idx % 8], - i, j, - idx); + printf("\\node[fill=%s] at (%d,%d) {%d};\n", + color(idx), i, j, idx); } } - + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(size<0>(layout)), int(size<1>(layout))); // Labels - for (int i = 0, j = -1; i < size<0>(layout); ++i) { + for (int i = 0, j = -1; i < size<0>(layout); ++i) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); } - for (int j = 0, i = -1; j < size<1>(layout); ++j) { + for (int i = -1, j = 0; j < size<1>(layout); ++j) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); } // Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); } -// Generic ThrVal 2D Layout to Latex TIKZ -- 8-value color coded by thread -template +// Generic ThrVal 2D Layout to LaTeX TikZ +template CUTE_HOST_DEVICE void -print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and tid -> thr_idx +print_latex(Layout const& layout, // (m,n) -> (tid,vid) + ThrID const& thr, // tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - char const* latex_header = - "\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - + // Commented prints + printf("%% Layout: "); print(layout); printf("\n"); + printf("%% ThrID : "); print(thr); printf("\n"); // Header - printf("%% layout: "); print(layout); printf("\n"); - printf("%% thrid: "); print(thr); printf("\n\n"); - - printf(latex_header); + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // Layout for (int i = 0; i < size<0>(layout); ++i) { @@ -2018,13 +2033,15 @@ print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and int val_idx = layout(i,j) / size(thr); int thr_idx = thr(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), i, j, thr_idx, val_idx); } } - + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(size<0>(layout)), int(size<1>(layout))); // Labels for (int i = 0, j = -1; i < size<0>(layout); ++i) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); @@ -2034,13 +2051,8 @@ print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and } // Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); } } // end namespace cute - -// -// Extended Layouts -// - -#include diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index fb62541cb4..3e5f836279 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::tuple +#include // cute::true_type, cute::false_type, cute::Int /* This implements a ComposedLayout of the form * LayoutA o Offset o LayoutB diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 651ff8e887..2e46905719 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -197,7 +197,7 @@ struct ArithmeticTupleIterator ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} CUTE_HOST_DEVICE constexpr - ArithTuple const& operator*() const { return coord_; } + ArithTuple operator*() const { return coord_; } template CUTE_HOST_DEVICE constexpr @@ -206,7 +206,7 @@ struct ArithmeticTupleIterator template CUTE_HOST_DEVICE constexpr auto operator+(Coord const& c) const { - return ArithmeticTupleIterator(coord_ + c); + return ArithmeticTupleIterator>(coord_ + c); } }; @@ -268,13 +268,13 @@ basis_value(SB const& e) // Apply the N... pack to another Tuple template -CUTE_HOST_DEVICE constexpr auto -basis_get(SB const& e, Tuple const& t) +CUTE_HOST_DEVICE decltype(auto) +basis_get(SB const& e, Tuple&& t) { if constexpr (is_scaled_basis::value) { - return basis_get(e.value(), get(t)); + return basis_get(e.value(), get(static_cast(t))); } else { - return t; + return static_cast(t); } CUTE_GCC_UNREACHABLE; } diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp index 5aa6664a89..7dd9ea5bf0 100644 --- a/include/cute/numeric/complex.hpp +++ b/include/cute/numeric/complex.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include -#include -#include +#include // CUTE_HOST_DEVICE + +#include // cutlass::complexm, cutlass::real, cutlass::imag, cutlass::is_complex namespace cute { diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index 169e3a0e67..571b3e3ed0 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -36,7 +36,9 @@ #include #endif -#include +#include // CUTE_STL_NAMESPACE + +#include // cutlass::int2b_t, cutlass::int4b_t namespace cute { @@ -53,8 +55,8 @@ using CUTE_STL_NAMESPACE::int32_t; using CUTE_STL_NAMESPACE::int64_t; template struct int_bit; -template <> struct int_bit< 2> { using type = cutlass::int2b_t; }; -template <> struct int_bit< 4> { using type = cutlass::int4b_t; }; +template <> struct int_bit< 2> { using type = int2_t; }; +template <> struct int_bit< 4> { using type = int4_t; }; template <> struct int_bit< 8> { using type = int8_t; }; template <> struct int_bit< 16> { using type = int16_t; }; template <> struct int_bit< 32> { using type = int32_t; }; @@ -83,9 +85,9 @@ using CUTE_STL_NAMESPACE::uint64_t; using cutlass::uint128_t; template struct uint_bit; -template <> struct uint_bit< 1> { using type = cutlass::uint1b_t; }; -template <> struct uint_bit< 2> { using type = cutlass::uint2b_t; }; -template <> struct uint_bit< 4> { using type = cutlass::uint4b_t; }; +template <> struct uint_bit< 1> { using type = uint1_t; }; +template <> struct uint_bit< 2> { using type = uint2_t; }; +template <> struct uint_bit< 4> { using type = uint4_t; }; template <> struct uint_bit< 8> { using type = uint8_t; }; template <> struct uint_bit< 16> { using type = uint16_t; }; template <> struct uint_bit< 32> { using type = uint32_t; }; diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index 46863ac286..e447103b99 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -30,10 +30,9 @@ **************************************************************************************************/ #pragma once -#include "cute/util/print.hpp" -#include "cute/util/type_traits.hpp" -#include "cute/numeric/math.hpp" -#include "cutlass/fast_math.h" +#include // cute::max, etc +#include // cute::print +#include // __CUTE_REQUIRES, cute::is_std_integral namespace cute { @@ -65,7 +64,7 @@ struct integral_constant : C { static constexpr T value = v; using value_type = T; // Disambiguate C::operator value_type() - //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } }; @@ -406,6 +405,20 @@ conditional_return(false_type, TrueType&&, FalseType&& f) { return static_cast(f); } +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, C const&, C const&) { + return C{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, C const&, C const&) { + return b ? v : u; +} + // TrueType and FalseType must have a common type template CUTE_HOST_DEVICE constexpr @@ -435,7 +448,7 @@ static_value() return Int{}; } else { return Trait::value; - } + } CUTE_GCC_UNREACHABLE; } diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp index 943b004982..1b1432533a 100644 --- a/include/cute/numeric/integral_ratio.hpp +++ b/include/cute/numeric/integral_ratio.hpp @@ -30,11 +30,10 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::false_type, cute::true_type +#include // cute::signum +#include // __CUTE_REQUIRES namespace cute { diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index 6d95165de2..e493a3a953 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include +#include // CUTE_HOST_DEVICE +#include // __CUTE_REQUIRES -#include #include namespace cute @@ -143,7 +143,7 @@ has_single_bit(T x) { // bit_width( 0b0111 ) = 3 template CUTE_HOST_DEVICE constexpr -T +int bit_width(T x) { static_assert(is_unsigned::value, "Only to be used for unsigned types."); constexpr int N = (numeric_limits::digits == 64 ? 6 : @@ -224,7 +224,7 @@ rotr(T x, int s) { // countl_zero( 0b00011100 ) = 3 template CUTE_HOST_DEVICE constexpr -T +int countl_zero(T x) { return numeric_limits::digits - bit_width(x); } @@ -235,7 +235,7 @@ countl_zero(T x) { // countl_one( 0b11100011 ) = 3 template CUTE_HOST_DEVICE constexpr -T +int countl_one(T x) { return countl_zero(~x); } @@ -246,7 +246,7 @@ countl_one(T x) { // countr_zero( 0b00011100 ) = 2 template CUTE_HOST_DEVICE constexpr -T +int countr_zero(T x) { return x == 0 ? numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB } @@ -257,7 +257,7 @@ countr_zero(T x) { // countr_one( 0b11100011 ) = 2 template CUTE_HOST_DEVICE constexpr -T +int countr_one(T x) { return countr_zero(~x); } @@ -285,7 +285,7 @@ popcount(T x) { // Computes the result of bitwise left-shift template CUTE_HOST_DEVICE constexpr -T +auto shiftl(T x, int s) { return s >= 0 ? (x << s) : (x >> -s); } @@ -293,7 +293,7 @@ shiftl(T x, int s) { // Computes the result of bitwise right-shift template CUTE_HOST_DEVICE constexpr -T +auto shiftr(T x, int s) { return s >= 0 ? (x >> s) : (x << -s); } diff --git a/include/cute/numeric/numeric_types.hpp b/include/cute/numeric/numeric_types.hpp index 02c700254b..07444331ff 100644 --- a/include/cute/numeric/numeric_types.hpp +++ b/include/cute/numeric/numeric_types.hpp @@ -30,12 +30,11 @@ **************************************************************************************************/ #pragma once -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::int2_t, cute::int4_t, etc -#include -#include +#include // cutlass::sizeof_bits +#include // cutlass::float_e4m3_t, cutlass::float_e5m2_t, etc namespace cute { @@ -72,4 +71,65 @@ using cutlass::int4b_t; using cutlass::uint4b_t; using cutlass::bin1_t; -} // end namespace cute + +// +// Print utility +// + +CUTE_HOST_DEVICE +void +print(half_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(bfloat16_t a) { + printf("%f", static_cast(a)); +} + + +CUTE_HOST_DEVICE +void +print(tfloat32_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(float_e4m3_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(float_e5m2_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(bfloat16_t v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(half_t v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(tfloat32_t v) { + printf("%*.2e", 10, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(float_e4m3_t t) { + printf("%*.2f", 8, static_cast(t)); +} + +CUTE_HOST_DEVICE void +pretty_print(float_e5m2_t t) { + printf("%*.2f", 8, static_cast(t)); +} + +} // namespace cute diff --git a/include/cute/numeric/real.hpp b/include/cute/numeric/real.hpp index f797bc13a1..4ce58dfa18 100644 --- a/include/cute/numeric/real.hpp +++ b/include/cute/numeric/real.hpp @@ -35,6 +35,24 @@ namespace cute { +/// Generic add +template +CUTE_HOST_DEVICE constexpr +void +add(C& c, A const& a, B const& b) +{ + c = a + b; +} + +/// Generic multiply +template +CUTE_HOST_DEVICE constexpr +void +mul(C& c, A const& a, B const& b) +{ + c = a * b; +} + /// Generic fused multiply-add template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 604477a0d3..4cfa129cce 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -30,17 +30,13 @@ **************************************************************************************************/ #pragma once -#include +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include +#include // cute::subbyte_iterator +#include // cute::true_type, cute::false_type +#include // sizeof_bits -#include -#include // sizeof_bits -#include -#include - -#include - -#include -#include namespace cute { @@ -50,6 +46,9 @@ namespace cute // Subbyte Types: uint2_t, uint4_t, etc // Requires construction of a subbyte_iterator in order to properly // resolve each element in byte-addressed memory. +// Sparse Types: sparse_elem +// A type that holds one physical element meant to represent S number of logical elements. +// Requires construction of a sparse_ptr that emulates access to the S logical elements. // template @@ -57,6 +56,11 @@ CUTE_HOST_DEVICE constexpr auto recast_ptr(void* ptr) { + if constexpr (is_sparse::value) { + constexpr int sparsity = NewT::sparsity; + NewT* p = reinterpret_cast(ptr); + return make_sparse_ptr(p); + } else if constexpr (cute::is_subbyte_v) { return subbyte_iterator(ptr); } else { @@ -70,6 +74,11 @@ CUTE_HOST_DEVICE constexpr auto recast_ptr(void const* ptr) { + if constexpr (is_sparse::value) { + constexpr int sparsity = NewT::sparsity; + NewT const* p = reinterpret_cast(ptr); + return make_sparse_ptr(p); + } else if constexpr (cute::is_subbyte_v) { return subbyte_iterator(ptr); } else { diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp index db5d3dcfc4..90ca0ceb6e 100644 --- a/include/cute/pointer_base.hpp +++ b/include/cute/pointer_base.hpp @@ -30,10 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include // sizeof_bits +#include // CUTE_HOST_DEVICE +#include // cute::sizeof_bits +#include // cute::declval, cute::void_t, etc namespace cute { diff --git a/include/cute/pointer_flagged.hpp b/include/cute/pointer_flagged.hpp index 08751eb169..eb8d7e452e 100644 --- a/include/cute/pointer_flagged.hpp +++ b/include/cute/pointer_flagged.hpp @@ -30,15 +30,13 @@ **************************************************************************************************/ #pragma once -#include - -#include // cast_smem_ptr_to_uint - -#include -#include -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::ComposedLayout +#include // cute::make_smem_ptr +#include // cute::is_sparse +#include // cute::make_swizzle_ptr +#include // cute::cast_smem_ptr_to_uint +#include // cute::Int namespace cute { @@ -124,6 +122,47 @@ as_position_independent_swizzle_tensor(Tensor&& tensor) CUTE_GCC_UNREACHABLE; } +// A model of a nullptr sparse_ptr> with B == sizeof_bits::value +// That represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr +template +struct smem_sparse_ptr_flag_bits : Int<0> {}; + +template +using smem_sparse_ptr_flag = smem_sparse_ptr_flag_bits; + +// A flagged construction method to transform ComposedLayout +// Make a swizzle pointer tensor and check that the intended type size matches +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& ptr, + ComposedLayout,Layout> const& layout) +{ + static_assert(is_smem::value, "Expected smem."); + static_assert(is_sparse_ptr::value, "Expected sparse iter"); + static_assert(is_sparse>::value, "Expected sparse elem"); + static_assert(S == iter_value_t::sparsity, "Expected sparsity S"); + static_assert(B == sizeof_bits::raw_type>::value, "Expected B-bit pointer type"); + return make_tensor(make_swizzle_ptr(ptr, layout.layout_a()), layout.layout_b()); +} + +// NOTE: To preserve smem_ptr_flag_bits under recast ops +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Layout> const& layout) +{ + static_assert(dependent_false, "Not implemented for safety"); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout,Layout> const& layout) +{ + static_assert(dependent_false, "Not implemented for safety"); +} + // // Display utilities // @@ -151,4 +190,10 @@ CUTE_HOST_DEVICE void print(smem_ptr_flag_bits ptr) printf("smem_ptr[%db](unset)", B); } +template +CUTE_HOST_DEVICE void print(smem_sparse_ptr_flag_bits) +{ + printf("smem_sparse<%d>_ptr[%db](unset)", S, B); +} + } // end namespace cute diff --git a/include/cute/pointer_sparse.hpp b/include/cute/pointer_sparse.hpp new file mode 100644 index 0000000000..ccae458650 --- /dev/null +++ b/include/cute/pointer_sparse.hpp @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include // cute::false_type, cute::true_type +#include // cute::ratio + +namespace cute +{ + +// A data type that holds one physical element meant to represent Sparsity number of logical elements +// This class is purposely not compatible with anything -- know what you're doing if you attempt to use it +template +struct sparse_elem +{ + static constexpr int sparsity = Sparsity; + using raw_type = T; + T elem_; + + CUTE_HOST_DEVICE constexpr + explicit sparse_elem(T const& elem = {}) : elem_(elem) {} + + CUTE_HOST_DEVICE constexpr friend bool operator==(sparse_elem const& a, sparse_elem const& b) { return a.elem_ == b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator!=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ != b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator< (sparse_elem const& a, sparse_elem const& b) { return a.elem_ < b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator<=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ <= b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator> (sparse_elem const& a, sparse_elem const& b) { return a.elem_ > b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator>=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ >= b.elem_; } +}; + +template +struct is_sparse : false_type {}; +template +struct is_sparse : is_sparse {}; +template +struct is_sparse> : true_type {}; +template +static constexpr auto is_sparse_v = is_sparse::value; + +// Overload sizeof_bits for sparse_elem. +// Much like subbyte element types, this is the effective number of bits in a sparse_elem +// rather than actual physical bits that may be used in storing one. Also like subbyte element +// types, modified iterators are required to properly index and access sparse_elems. +// +// Defining sizeof_bits like this makes reasonable expressions like N * sizeof_bits_v meaningful +// even when E is subbyte or sparse. However, this also means that sparse_elem can rather easily be +// confused with subbyte elements and special care should be taken with each. +template +struct sizeof_bits> { + // Simple implementation that conforms to sizeof_bits + //static constexpr auto value = sizeof_bits::value / S; + //static_assert(value != 0, "sizeof_bits=0 detected. Sparsity is larger than width."); + //static_assert((sizeof_bits::value % S) == 0, "Width needs to be a multiple of sparsity.") + + // Interesting experiment that allows any sparsity level to be used by potentially presenting + // an integral_ratio rather than size_t. This is valid in most integer expressions as well. + static constexpr auto value = cute::ratio(cute::Int>{}, cute::Int{}); +}; + +// +// sparse_ptr +// + +template +struct is_sparse_ptr : false_type {}; +template +struct is_sparse_ptr> : is_sparse_ptr {}; + +template +struct sparse_ptr : iter_adaptor> +{ + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + // Sanity, for now + static_assert(is_sparse::value, "Enforce sparse value-type"); + static_assert(Sparsity == iter_value_t::sparsity, "Enforce sparsity S"); + static_assert(not is_sparse_ptr::value, "Enforce sparse singleton"); + + template + CUTE_HOST_DEVICE constexpr + sparse_ptr operator+(Index const& i) const { + // Only allow offset by multiples of the sparsity factor, + // else the misalignments become a bug. E.g. (sparse_ptr<8,I>{} + 7) + 7 + // Motivation for subsparse_iterator or generalization of subbyte_iterator? + assert(i % Sparsity == 0); + return {this->get() + i / Sparsity}; + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { + // Allow offset by any value and dereference. + // Not implemented in terms of sparse_ptr::op+() + return *(this->get() + i / Sparsity); + } +}; + +template +struct is_sparse_ptr> : true_type {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_sparse_ptr(Iter const& iter) { + if constexpr (Sparsity == 1) { + return iter; + } else { + return sparse_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(sparse_ptr const& ptr) { + static_assert(not is_sparse::value); + return recast_ptr(ptr.get()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(sparse_ptr ptr) +{ + printf("sparse<%d>_", S); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, sparse_ptr ptr) +{ + return os << "sparse<" << S << ">_" << ptr.get(); +} +#endif + +} // end namespace cute diff --git a/include/cute/pointer_swizzle.hpp b/include/cute/pointer_swizzle.hpp index a83b485c8e..720b9b1246 100644 --- a/include/cute/pointer_swizzle.hpp +++ b/include/cute/pointer_swizzle.hpp @@ -30,13 +30,11 @@ **************************************************************************************************/ #pragma once -#include - -#include // iterator_traits -#include - -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include // cute::Swizzle, cute::get_swizzle primary template +#include // cute::iterator_traits +#include // cute::subbyte_iterator /* This implements a swizzle pointer of the form * InvolutionFn o PtrAdd @@ -107,16 +105,14 @@ struct swizzle_ptr : iter_adaptor> } }; -template // Default No-Swizzle -struct get_swizzle { using type = Swizzle<0,4,3>; }; +// +// Helper Function +// template // Found the SwizzleFn struct get_swizzle> { using type = SwizzleFn; }; template // Recurse into anything with a ::iterator struct get_swizzle> : get_swizzle {}; -template -using get_swizzle_t = typename get_swizzle::type; - template CUTE_HOST_DEVICE constexpr swizzle_ptr diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp index 09a02a00e7..f2d31f4e34 100644 --- a/include/cute/stride.hpp +++ b/include/cute/stride.hpp @@ -30,10 +30,16 @@ **************************************************************************************************/ #pragma once -#include -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_integral +#include // cute::seq +#include // cute::divmod +#include // cute::basis_get +#include // cute::identity +#include // cute::fold +#include // cute::is_congruent namespace cute { @@ -433,7 +439,7 @@ compact_order(Shape const& shape, Order const& order) auto flat_order = flatten_to_tuple(order); // Find the largest static element of order auto max_order = cute::fold(flat_order, Int<0>{}, [](auto v, auto order) { - if constexpr (is_constant::value) { + if constexpr (is_constant::value) { return order; } else { return v; @@ -474,4 +480,119 @@ compact_order(Shape const& shape, GenRowMajor const& major) return compact_major(shape); } +// +// Coordinate iterator +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape, Order const& order) +{ + ++basis_get(get<0>(order), coord); + cute::for_each(make_range<1, tuple_size::value>{}, [&](auto i){ + if (basis_get(get(order), coord) == basis_get(get(order), shape)) { + basis_get(get(order), coord) = 0; + ++basis_get(get(order), coord); + } + }); +} + +/** Increment a (dynamic) coord colexicographically within a shape + * @pre is_congruent::value + * \code + * auto shape = make_shape(1,2,make_shape(2,3),3); + * auto coord = repeat_like(shape, 0); + * + * for (int i = 0; i < size(shape); ++i) { + * std::cout << i << ": " << coord << std::endl; + * increment(coord, shape); + * } + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape) +{ + increment(coord, shape, flatten_to_tuple(make_basis_like(shape))); +} + +} // end namespace detail + +struct ForwardCoordIteratorSentinel +{}; + +// A forward iterator for a starting coordinate in a shape's domain, and a shape. +// The starting coordinate may be zero but need not necessarily be. +template +struct ForwardCoordIterator +{ + static_assert(is_congruent::value); + + CUTE_HOST_DEVICE constexpr + Coord const& operator*() const { return coord; } + CUTE_HOST_DEVICE constexpr + ForwardCoordIterator& operator++() { detail::increment(coord, shape, Order{}); return *this; } + // Sentinel for the end of the implied range + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIteratorSentinel const&) const { return basis_get(back(Order{}), coord) == basis_get(back(Order{}), shape); } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIteratorSentinel const&) const { return basis_get(back(Order{}), coord) != basis_get(back(Order{}), shape); } + // NOTE: These are expensive, avoid use + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } + + Coord coord; + Shape const& shape; +}; + +// A forward iterator for a coordinate that starts from a provided coordinate and increments in a prescribed order +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + static_assert(is_congruent::value); + static_assert(is_congruent::value); + static_assert(is_congruent::value); + auto flat_order = flatten_to_tuple(Order{}); + auto inv_order = transform(make_seq{}, [&](auto i){ return find(flat_order, i); }); + auto basis_order = transform_leaf(inv_order, [&](auto i) { return get(flatten_to_tuple(make_basis_like(shape))); }); + return ForwardCoordIterator{coord,shape}; +} + +// A forward iterator for a coordinate that starts from a provided coordinate and increments colex +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + static_assert(is_congruent::value); + auto basis_order = flatten_to_tuple(make_basis_like(shape)); + return ForwardCoordIterator{coord,shape}; +} + +// A forward iterator for a coordinate that starts from zero and increments in a prescribed order +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + return make_coord_iterator(repeat_like(shape, int(0)), shape); +} + +// A forward iterator for a coordinate that starts from zero and increments colex +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + return make_coord_iterator(repeat_like(shape, int(0)), shape); +} + } // end namespace cute diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp index 9ceb0d32b0..52abf856dd 100644 --- a/include/cute/swizzle.hpp +++ b/include/cute/swizzle.hpp @@ -30,13 +30,11 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::is_tuple +#include // cute::constant +#include // cute::max, cute::min +#include // cute::transform_apply namespace cute { @@ -488,4 +486,13 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) } #endif // !defined(__CUDACC_RTC__) +// +// Helper Function +// +template // Default No-Swizzle +struct get_swizzle { using type = Swizzle<0,4,3>; }; + +template +using get_swizzle_t = typename get_swizzle::type; + } // end namespace cute diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 82e51c79c6..1324360eba 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -30,13 +30,10 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include - -#include -#include // get_swizzle +#include // CUTE_HOST_DEVICE +#include // cute::Layout +#include // cute::ComposedLayout +#include // cute::Swizzle, cute::get_swizzle primary template /* Specialized functionality for a ComposedLayout of the form * InvolutionFn o Offset o LayoutB @@ -57,6 +54,9 @@ namespace cute { +// +// Helper Function +// template struct get_swizzle,Offset,LayoutB>> { using type = Swizzle; }; @@ -193,7 +193,7 @@ make_swizzle_strides(true_type, // 0 Z DC // 1 -Z DC - return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z << Int{}, -(Z << Int{}))...); + return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z * Int<(1 << I)>{}, -Z * Int<(1 << I)>{})...); } template @@ -214,7 +214,7 @@ make_swizzle_strides(false_type, // 0 Y+Z Y-Z // 1 DC DC - return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) << Int{}, (Y-Z) << Int{})...); + return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) * Int<(1 << I)>{}, (Y-Z) * Int<(1 << I)>{})...); } } // end namespace detail @@ -240,16 +240,6 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout // The portion of the layout that is not yet consumed auto sliced_layout = slice(coord, layout.layout_b()); - // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay - - // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] - // (this also tests that shape/stride of layout compose with swizzle) - auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); - // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) - auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); - // Determine if any active bits collide under the swizzle - auto hit_ZandY = !(swizzle_active_bits & ~layout.layout_a()(swizzle_active_bits)); - // The portion of the layout that we are consuming now auto diced_layout = dice(coord, layout.layout_b()); auto diced_coord = dice(coord, coord); @@ -269,8 +259,16 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal // If Layout's codomain hits on neither Y NOR Z, then it's static-normal - // Test the sliced layout for hit_X & hit_Y for potential decay - if constexpr (is_constant::value) + // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay + + // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] + // (this also tests that shape/stride of layout compose with swizzle) + auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + [[maybe_unused]] auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); + + // Determine if any active bits collide under the swizzle for potential decay + if constexpr (is_constant<0, decltype(not (swizzle_active_bits & ~swizzle(swizzle_active_bits)))>::value) { // Hits on Y AND Z, so it's not reducible return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); } else @@ -459,7 +457,7 @@ CUTE_HOST_DEVICE constexpr auto max_alignment(Swizzle const&) { - return Int{}; + return Int<1 << M>{}; } template diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index a45cbd0132..3f3335b63d 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -37,7 +37,10 @@ // #include +#include #include +#include + // // Tensor Algorithms // diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index da0e245636..61eefc5060 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -41,18 +41,16 @@ #pragma once -#include - -#include -#include -#include - -#include -#include -#include - -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::Shape +#include // cute::is_composed_layout +#include // cute::recast_ptr +#include // cute::iterator_traits +#include // cute::array_aligned +#include // cute::array_subbyte +#include // cute::tuple +#include // cute::is_integral +#include // __CUTE_REQUIRES namespace cute { @@ -69,7 +67,7 @@ namespace cute // iterator begin(); // }; -template +template struct ArrayEngine { using Storage = typename conditional<(sizeof_bits::value % 8 == 0), @@ -85,6 +83,24 @@ struct ArrayEngine CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); } }; +// Specialization for sparse_elem tensor allocation/iteration +template +struct ArrayEngine, N> +{ + static_assert(N % S == 0, "Expected a multiple of the sparsity."); + using value_type = sparse_elem; + using Storage = typename conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + using iterator = sparse_ptr*>; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + Storage storage_; + + CUTE_HOST_DEVICE constexpr auto begin() const { return recast_ptr(storage_.begin()); } + CUTE_HOST_DEVICE constexpr auto begin() { return recast_ptr(storage_.begin()); } +}; + template struct ViewEngine { @@ -622,6 +638,30 @@ filter_zeros(Tensor&& tensor) { return make_tensor(tensor.data(), filter_zeros(tensor.layout())); } +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor const& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor&& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + // Remove all of the 0-strides and 1-sizes template CUTE_HOST_DEVICE constexpr @@ -755,10 +795,10 @@ auto max_common_vector(Tensor const& a, Tensor const& b) { - using SrcType = typename Tensor::value_type; - using DstType = typename Tensor::value_type; - using SrcRef = typename Tensor::reference; - using DstRef = typename Tensor::reference; + using SrcType = typename SrcEngine::value_type; + using SrcRef = typename SrcEngine::reference; + using DstType = typename DstEngine::value_type; + using DstRef = typename DstEngine::reference; // Determine if vectorization candidates at all if constexpr (// Should be the same value_types, else the copy is also performing a cast @@ -795,10 +835,10 @@ auto max_common_layout(Tensor const& a, Tensor const& b) { - using SrcType = typename Tensor::value_type; - using DstType = typename Tensor::value_type; - using SrcRef = typename Tensor::reference; - using DstRef = typename Tensor::reference; + using SrcType = typename SrcEngine::value_type; + using SrcRef = typename SrcEngine::reference; + using DstType = typename DstEngine::value_type; + using DstRef = typename DstEngine::reference; // Determine if vectorization candidates at all if constexpr (// Should be the same value_types, else the copy is also performing a cast diff --git a/include/cute/tensor_predicate.hpp b/include/cute/tensor_predicate.hpp index 6814647071..9c8a2ba614 100644 --- a/include/cute/tensor_predicate.hpp +++ b/include/cute/tensor_predicate.hpp @@ -30,9 +30,8 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::true_type namespace cute { diff --git a/include/cute/tensor_zip.hpp b/include/cute/tensor_zip.hpp new file mode 100644 index 0000000000..6d70ffc847 --- /dev/null +++ b/include/cute/tensor_zip.hpp @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::tuple + +namespace cute +{ + +// A tuple of Iterators that can be offset asymmetrically +// Note that this only accepts op+(tuple) and op[tuple] +// where each iterator will be offset by its respective index only. +// READ-ONLY for now until cute::tuple can be constructed with references. +template +struct ZipIterator +{ + using value_type = cute::tuple...>; + using element_type = cute::tuple...>; + // NOTE: cute::tuple does not support constructions with references at the moment. + // Consider fixes and/or an implementation of std::forward_as_tuple. + // For now, use a cute::tuple of value_types instead, which makes this Iterator READ-ONLY. + //using reference = cute::tuple...>; + using reference = value_type; + + ZipIterator() = delete; + + CUTE_HOST_DEVICE constexpr + ZipIterator(Iters... iters) + : iters_(iters...) + {} + + CUTE_HOST_DEVICE constexpr + ZipIterator(cute::tuple const& iters) + : iters_(iters) + {} + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return cute::apply(iters_, [](auto&&... args) { return reference(*args...); }); + } + + template + CUTE_HOST_DEVICE constexpr + ZipIterator operator+(cute::tuple const& idxs) const { + static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators."); + return cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; }); + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](cute::tuple const& idxs) const { + return *(*this + idxs); + } + + cute::tuple iters_; +}; + +//------------------------------------------------------------------------------ +// type traits + +template +struct is_rmem> : conjunction...> {}; +template +struct is_smem> : conjunction...> {}; +template +struct is_gmem> : conjunction...> {}; +// A tuple of Layouts that operates on each Layout symmetrically +// The Layouts need to have compatible shapes and ranks. +// The ZipLayout presents the intersection of the domain of its component Layouts. +// E.g. all Layouts accept 1D coords and ZipLayout does as well. +// The ZipLayout returns the union of the codomain of its component Layouts. +// E.g. all Layouts return an integer so ZipLayout returns a tuple of integers. +template +struct ZipLayout +{ + static constexpr int rank = (int(0) | ... | Layouts::rank); + + static_assert((is_layout::value && ...), "All template parameters must be layouts"); + static_assert(((Layouts::rank == rank) && ...), "All layouts must have the same rank"); + + CUTE_HOST_DEVICE constexpr + ZipLayout(Layouts const&... layouts) + : layouts_(layouts...) + {} + + CUTE_HOST_DEVICE constexpr + ZipLayout(cute::tuple const& layouts) + : layouts_(layouts) + {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return ZipLayout(cute::transform(layouts_, [&] (auto layout) { return layout(coord); })); + } else { + return cute::transform(layouts_, [&] (auto layout) { return layout(coord); }); + } + + CUTE_GCC_UNREACHABLE; + } + + // op() convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + cute::tuple layouts_; +}; + +template +struct is_layout> : true_type {}; + +// +// make_zip_tensor and unzip_tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_zip_tensor(Tensor const&... tensors) +{ + return make_tensor(ZipIterator(tensors.data()...), + ZipLayout(tensors.layout()...)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +unzip_tensor(Tensor const& tensor) +{ + return cute::transform(tensor.data().iters_, tensor.layout().layouts_, + [](auto iter, auto layout) { return make_tensor(iter, layout); }); +} + +// +// Utilities +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(ZipLayout const& layouts) +{ + return rank(get<0>(layouts.layouts_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +size(ZipLayout const& layouts) +{ + return size(get<0>(layouts.layouts_)); +} + +// +// Manipulation +// + +// Extend each component layout to rank-N by appending Layout @a x. +template +CUTE_HOST_DEVICE constexpr +auto +append(ZipLayout const& layouts, + Layout const& x = {}) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return append(t, x); })); +} + +// Extend each component layout to rank-N by prepending Layout @a x. +template +CUTE_HOST_DEVICE constexpr +auto +prepend(ZipLayout const& layouts, + Layout const& x = {}) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return prepend(t, x); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ZipLayout const& layouts, + Tiler const& tiler) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return logical_divide(t, tiler); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ZipLayout const& layouts, + Tiler const& tiler) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return zipped_divide(t, tiler); })); +} + +// Return by calling slice_and_offset and all component layouts. +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, ZipLayout const& layouts) +{ + auto result = cute::zip(cute::transform(layouts.layouts_, [&c](auto const& layout) { return slice_and_offset(c, layout); })); + return cute::make_tuple(ZipLayout(get<0>(result)), get<1>(result)); +} + +} // end namespace cute diff --git a/include/cute/underscore.hpp b/include/cute/underscore.hpp index 212f42d7fa..e9d80fe5b5 100644 --- a/include/cute/underscore.hpp +++ b/include/cute/underscore.hpp @@ -30,12 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include -#include +#include // CUTE_INLINE_CONSTANT, CUTE_HOST_DEVICE +#include // cute::is_tuple +#include // cute::false_type, cute::true_type namespace cute { diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index 6463e8684f..6bfe6c0a1e 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -30,9 +30,8 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::is_valid // // CUDA compatible print and printf @@ -156,50 +155,45 @@ print(char const* format) { // pretty printing // -template -CUTE_HOST_DEVICE void -pretty_print(T const& v) { - printf(" "); print(v); -} - CUTE_HOST_DEVICE void -pretty_print(bool const& v) { +pretty_print(bool v) { printf("%*d", 3, int(v)); } CUTE_HOST_DEVICE void -pretty_print(int32_t const& v) { +pretty_print(int32_t v) { printf("%*d", 5, v); } CUTE_HOST_DEVICE void -pretty_print(uint32_t const& v) { +pretty_print(uint32_t v) { printf("%*d", 5, v); } CUTE_HOST_DEVICE void -pretty_print(int64_t const& v) { +pretty_print(int64_t v) { printf("%*lld", 5, static_cast(v)); } CUTE_HOST_DEVICE void -pretty_print(uint64_t const& v) { +pretty_print(uint64_t v) { printf("%*llu", 5, static_cast(v)); } CUTE_HOST_DEVICE void -pretty_print(half_t const& v) { - printf("%*.2f", 8, float(v)); +pretty_print(float v) { + printf("%*.2e", 10, v); } CUTE_HOST_DEVICE void -pretty_print(float const& v) { - printf("%*.2e", 10, v); +pretty_print(double v) { + printf("%*.3e", 11, v); } +template CUTE_HOST_DEVICE void -pretty_print(double const& v) { - printf("%*.3e", 11, v); +pretty_print(T t) { + printf(" "); print(t); } } // end namespace cute diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index f0eb55116d..e663b569c6 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -44,7 +44,7 @@ #include // numeric_limits #endif -#include +#include // CUTE_STL_NAMESPACE namespace cute { @@ -79,6 +79,7 @@ using CUTE_STL_NAMESPACE::is_const_v; using CUTE_STL_NAMESPACE::is_volatile; using CUTE_STL_NAMESPACE::is_volatile_v; +// Defined in cute/numeric/integral_constant.hpp // using CUTE_STL_NAMESPACE::true_type; // using CUTE_STL_NAMESPACE::false_type; @@ -278,14 +279,14 @@ struct conditional_template { // is_any_of // -/// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us -template +// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us +template struct is_any_of { constexpr static bool value = (... || CUTE_STL_NAMESPACE::is_same_v); }; -/// Is true if and only if T is same as (is_same_v) at least one of the types in Us -template +// Is true if and only if T is same as (is_same_v) at least one of the types in Us +template inline constexpr bool is_any_of_v = is_any_of::value; } // end namespace cute diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index cd2d7be3cb..c96897324a 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -93,12 +93,24 @@ class NamedBarrier { NamedBarrier::arrive_and_wait_internal(num_threads_, id_); } + CUTLASS_DEVICE + void arrive_and_wait_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). + NamedBarrier::arrive_and_wait_internal_unaligned(num_threads_, id_); + } + CUTLASS_DEVICE void arrive() const { // Note: The value of id_ is already the final barrier id (set correctly in the constructor). NamedBarrier::arrive_internal(num_threads_, id_); } + CUTLASS_DEVICE + void arrive_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). + NamedBarrier::arrive_internal_unaligned(num_threads_, id_); + } + CUTLASS_DEVICE void sync() const { NamedBarrier::arrive_and_wait(); @@ -148,11 +160,23 @@ class NamedBarrier { sync_internal(num_threads, static_cast(reserved_named_barriers)); } + private: CUTLASS_DEVICE static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_and_wait_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -161,12 +185,23 @@ class NamedBarrier { CUTLASS_DEVICE static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } + CUTLASS_DEVICE + static void arrive_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("barrier.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + CUTLASS_DEVICE static void sync_internal(uint32_t num_threads, uint32_t barrier_id) { NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id); @@ -243,6 +278,7 @@ struct ClusterBarrier { "}" : : "r"(arrive_count), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_init(__LINE__, smem_addr, arrive_count); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -253,6 +289,7 @@ struct ClusterBarrier { static void wait(ValueType const* smem_ptr, uint32_t phase) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_wait(__LINE__, smem_addr, phase); // Arbitrarily large timer value after which try-wait expires and re-tries. uint32_t ticks = 0x989680; asm volatile( @@ -276,6 +313,7 @@ struct ClusterBarrier { static bool test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_test_wait(__LINE__, smem_addr, phase, pred); uint32_t waitComplete; asm volatile( @@ -300,6 +338,7 @@ struct ClusterBarrier { static bool try_wait(ValueType const* smem_ptr, uint32_t phase) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_try_wait(__LINE__, smem_addr, phase); uint32_t waitComplete; asm volatile( @@ -334,6 +373,7 @@ struct ClusterBarrier { : "r"(smem_addr), "r"(cta_id)); } + cutlass::arch::synclog_emit_cluster_barrier_arrive_cluster(__LINE__, smem_addr, cta_id, pred); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -350,6 +390,7 @@ struct ClusterBarrier { "}" : : "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_arrive(__LINE__, smem_addr); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -426,6 +467,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(__LINE__, smem_addr, transaction_bytes); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -463,6 +505,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_expect_transaction(__LINE__, smem_addr, transaction_bytes); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -483,6 +526,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_complete_transaction(__LINE__, smem_addr, dst_cta_id, transaction_bytes, pred); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -536,6 +580,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { CUTLASS_DEVICE void fence_barrier_init() { #if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_barrier_init(__LINE__); asm volatile( "{\n\t" "fence.mbarrier_init.release.cluster; \n" @@ -550,6 +595,7 @@ void fence_barrier_init() { CUTLASS_DEVICE void fence_view_async_shared() { #if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); asm volatile ( "{\n\t" "fence.proxy.async.shared::cta; \n" @@ -571,6 +617,7 @@ void cpasync_barrier_arrive(uint64_t const* smem_ptr) { "}" : : "r"(smem_addr)); + cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h new file mode 100644 index 0000000000..b0f750063c --- /dev/null +++ b/include/cutlass/arch/config.h @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Definitions for architecture macros +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 0)) + #define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) + #define CUTLASS_ARCH_MMA_SM90_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_ARCH_MMA_SM90A_ENABLED 1 + #endif + #endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) + #define CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 Modifiable +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 3)) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED 1 + #endif + #endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 F64 +#if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED 1 + #endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/arch/grid_dependency_control.h b/include/cutlass/arch/grid_dependency_control.h new file mode 100644 index 0000000000..14ef197497 --- /dev/null +++ b/include/cutlass/arch/grid_dependency_control.h @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grid dependent control (GDC) helpers for programmatic dependent launches (PDL). +*/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#ifndef CUTLASS_GDC_ENABLED + #if (defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \ + __CUDACC_VER_MAJOR__ >= 12 && \ + defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_GDC_ENABLED + #endif +#endif + +namespace cutlass { +namespace arch { + +// Issuing the launch_dependents instruction hints a dependent kernel to launch earlier +// launch_dependents doesn't impact the functionality but the performance: +// Launching a dependent kernel too early can compete with current kernels, +// while launching too late can lead to a long latency. +CUTLASS_DEVICE +void launch_dependent_grids() { +#if (defined(CUTLASS_GDC_ENABLED)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +// Issuing the griddepcontrol.wait instruction enforces no global memory access +// prior to this istruction. This ensures the correctness of global memory access +// when launching a dependent kernel earlier. +CUTLASS_DEVICE +void wait_on_dependent_grids() { +#if (defined(CUTLASS_GDC_ENABLED)) + asm volatile("griddepcontrol.wait;"); +#endif +} + +// Enable kernel-level query regarding whether the GDC feature is turned on +#if (defined(CUTLASS_GDC_ENABLED)) +static constexpr bool IsGdcGloballyEnabled = true; +#else +static constexpr bool IsGdcGloballyEnabled = false; +#endif + + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index acaa819567..cb0ba4b54b 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -326,6 +326,8 @@ struct cp_async { "cp.async only supports CacheOperation::Global when access size is 16B."); unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + cutlass::arch::synclog_emit_cp_async(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); + asm volatile( "{\n" " .reg .pred p;\n" @@ -364,6 +366,8 @@ struct cp_async_zfill { unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); int src_in_bytes = (pred_guard ? SizeInBytes : 0); + cutlass::arch::synclog_emit_cp_async_zfill(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); + asm volatile( #if CUTLASS_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), @@ -401,6 +405,8 @@ struct cp_async_nan<16, CacheOperation::Global> { OOB_NAN_F16x2, OOB_NAN_F16x2}; unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + cutlass::arch::synclog_emit_cp_async_nan(__LINE__, smem_int_ptr, global_ptr, pred_guard); + asm volatile( "{\n" " .reg .pred p;\n" @@ -434,6 +440,7 @@ CUTLASS_DEVICE void cp_async_fence() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.commit_group;\n" ::); + cutlass::arch::synclog_emit_cp_async_fence(__LINE__); #endif } @@ -444,6 +451,7 @@ template CUTLASS_DEVICE void cp_async_wait() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + cutlass::arch::synclog_emit_cp_async_wait(__LINE__, N); #endif } @@ -452,6 +460,7 @@ template <> CUTLASS_DEVICE void cp_async_wait<0>() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.wait_all;\n" ::); + cutlass::arch::synclog_emit_cp_async_wait_all(__LINE__); #endif } diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h index d2b167a7ce..1183ee5e05 100644 --- a/include/cutlass/arch/mma_sm90.h +++ b/include/cutlass/arch/mma_sm90.h @@ -43,30 +43,7 @@ #include "mma.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) - #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED - #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)) - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED - #endif - #endif -#endif - -#if (__CUDACC_VER_MAJOR__ >= 12) - #define CUTLASS_ARCH_MMA_SM90_SUPPORTED - #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED)) - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - #define CUTLASS_ARCH_MMA_SM90_ENABLED - #endif - #endif -#endif - -#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) - #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED -#endif +#include "cutlass/arch/config.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h index c1ffbeeb57..d2b434453e 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -37,9 +37,11 @@ #include "cutlass/cutlass.h" -#if (defined(__CUDA_ARCH__) &&\ - (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +#ifndef CUDA_CTA_RECONFIG_ACTIVATED + #if (__CUDACC_VER_MAJOR__ >= 12 && \ + defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif #endif namespace cutlass { diff --git a/include/cutlass/arch/synclog.hpp b/include/cutlass/arch/synclog.hpp new file mode 100644 index 0000000000..ea683859a3 --- /dev/null +++ b/include/cutlass/arch/synclog.hpp @@ -0,0 +1,1324 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Synchronization event logging for race condition debugging. +*/ + +#pragma once + +#include "cutlass/detail/helper_macros.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#if !defined(__CUDACC_RTC__) +#include +#include +#endif + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ENABLE_SYNCLOG) + +constexpr uint32_t synclog_cap = 1 << 26; + +inline std::mutex synclog_mutex; +inline std::vector synclog_buf_list; +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +inline __device__ uint32_t* synclog_buf; +#endif + +CUTLASS_DEVICE +uint32_t* synclog_alloc(uint32_t n) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t* buf = synclog_buf; + if (buf == nullptr) return nullptr; + uint32_t last = atomicAdd(&buf[0], n); + if (last + n < synclog_cap) return buf + last + 1; + if (last >= synclog_cap) atomicAdd(&buf[0], -n); + #endif + return nullptr; +} + +CUTLASS_DEVICE +void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint64_t time64; + asm volatile ( + "mov.u64 %0, %%globaltimer;\n" + : "=l"(time64) : + ); + to[0] = header; + to[1] = line; + to[2] = time64; + to[3] = time64 >> 32; + to[4] = threadIdx.x; + to[5] = threadIdx.y; + to[6] = threadIdx.z; + to[7] = blockIdx.x; + to[8] = blockIdx.y; + to[9] = blockIdx.z; + #endif +} + +constexpr uint32_t synclog_header_none = 0; +constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3; + +constexpr bool synclog_enable_syncthreads = true; +constexpr uint32_t synclog_header_syncthreads = 1; +constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0; + +constexpr bool synclog_enable_syncwarp = true; +constexpr uint32_t synclog_header_syncwarp = 2; +constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0; + +constexpr bool synclog_enable_named_barrier_arrive_and_wait = true; +constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3; +constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2; + +constexpr bool synclog_enable_named_barrier_arrive = true; +constexpr uint32_t synclog_header_named_barrier_arrive = 4; +constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2; + +constexpr bool synclog_enable_cluster_barrier_init = true; +constexpr uint32_t synclog_header_cluster_barrier_init = 5; +constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2; + +constexpr bool synclog_enable_cluster_barrier_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_wait = 6; +constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_barrier_test_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7; +constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cluster_barrier_try_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8; +constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9; +constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cluster_barrier_arrive = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive = 10; +constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_barrier_invalidate = true; +constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11; +constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 6; + +constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; +constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; +constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 6; + +constexpr bool synclog_enable_fence_barrier_init = true; +constexpr uint32_t synclog_header_fence_barrier_init = 16; +constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; + +constexpr bool synclog_enable_fence_view_async_shared = true; +constexpr uint32_t synclog_header_fence_view_async_shared = 17; +constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_wait = true; +constexpr uint32_t synclog_header_cp_async_wait = 18; +constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_cp_async_wait_all = true; +constexpr uint32_t synclog_header_cp_async_wait_all = 19; +constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_fence = true; +constexpr uint32_t synclog_header_cp_async_fence = 20; +constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_nan = true; +constexpr uint32_t synclog_header_cp_async_nan = 21; +constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cp_async_zfill = true; +constexpr uint32_t synclog_header_cp_async_zfill = 22; +constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cp_async = true; +constexpr uint32_t synclog_header_cp_async = 23; +constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5; + +constexpr bool synclog_enable_tma_load = true; +constexpr uint32_t synclog_header_tma_load = 24; +constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4; + +constexpr bool synclog_enable_tma_store = true; +constexpr uint32_t synclog_header_tma_store = 25; +constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3; + +constexpr bool synclog_enable_tma_store_arrive = true; +constexpr uint32_t synclog_header_tma_store_arrive = 26; +constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_tma_store_wait = true; +constexpr uint32_t synclog_header_tma_store_wait = 27; +constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_arrive = true; +constexpr uint32_t synclog_header_warpgroup_arrive = 28; +constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_warpgroup_wait = true; +constexpr uint32_t synclog_header_warpgroup_wait = 29; +constexpr uint32_t synclog_length_warpgroup_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_commit_batch = true; +constexpr uint32_t synclog_header_warpgroup_commit_batch = 30; +constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0; + +constexpr bool synclog_enable_wgmma_reg_smem = true; +constexpr uint32_t synclog_header_wgmma_reg_smem = 31; +constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2; + +constexpr bool synclog_enable_wgmma_smem_smem = true; +constexpr uint32_t synclog_header_wgmma_smem_smem = 32; +constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cpasync_barrier_arrive = true; +constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; +constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 3; + +CUTLASS_DEVICE +bool synclog_condition_emit() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x%NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return 0; + #endif +} + +CUTLASS_DEVICE +bool synclog_condition_print() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return false; + #endif +} + +CUTLASS_DEVICE +void synclog_print_prefix(char const* header, uint32_t at) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t line = synclog_buf[at + 1]; + uint32_t timeLo = synclog_buf[at + 2]; + uint32_t timeHi = synclog_buf[at + 3]; + uint32_t threadIdxX = synclog_buf[at + 4]; + uint32_t threadIdxY = synclog_buf[at + 5]; + uint32_t threadIdxZ = synclog_buf[at + 6]; + uint32_t blockIdxX = synclog_buf[at + 7]; + uint32_t blockIdxY = synclog_buf[at + 8]; + uint32_t blockIdxZ = synclog_buf[at + 9]; + printf( + "%s line=%u time=%lu thread=%u,%u,%u block=%u,%u,%u ", + header, line, + (uint64_t)timeHi << 32 | timeLo, + threadIdxX, threadIdxY, threadIdxZ, + blockIdxX, blockIdxY, blockIdxZ + ); + #endif +} + +CUTLASS_DEVICE +uint64_t synclog_mbarrier_bits(uint32_t smem_addr) { + uint64_t bits = 0; + asm volatile ( + "mbarrier.inval.shared::cta.b64 [%1];\n" + "ld.shared::cta.b64 %0, [%1];\n" + : "=l"(bits) : "r"(smem_addr) + ); + return bits; +} + +CUTLASS_DEVICE +void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { + CUTLASS_UNUSED(hi); + uint32_t smem_int_ptr = (lo & ((1 << 14) - 1)) << 4; + printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep); +} + +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void synclog_setup() { + #if defined(CUTLASS_ENABLE_SYNCLOG) + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + std::scoped_lock lock(synclog_mutex); + auto fail = [] () { + fprintf(stderr, "synclog_setup() failed\n"); + std::terminate(); + }; + int orig_device = 0; + if (cudaGetDevice(&orig_device) != cudaSuccess) { + fail(); + } + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + fail(); + } + if (synclog_buf_list.size() == 0) { + for (int device = 0; device < device_count; device++) { + uint32_t* buf = 0; + if (cudaSetDevice(device) != cudaSuccess || + cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) { + fail(); + } + synclog_buf_list.push_back(buf); + } + } + for (int device = 0; device < device_count; device++) { + uint32_t* buf = synclog_buf_list.at(device); + if (cudaSetDevice(device) != cudaSuccess || + cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess || + cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) { + fail(); + } + } + if (cudaSetDevice(orig_device) != cudaSuccess) { + fail(); + } + #endif + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncthreads(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncthreads) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncthreads); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncthreads, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncwarp(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncwarp) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncwarp); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncwarp, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive_and_wait( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive_and_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive_and_wait, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_init( + uint32_t line, + uint32_t smem_addr, + uint32_t arrive_count) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_init) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_init); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_init, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = arrive_count; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(arrive_count); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_test_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_test_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_try_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_try_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive_cluster( + uint32_t line, + uint32_t smem_addr, + uint32_t cta_id, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = cta_id; + to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(cta_id); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_invalidate( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_invalidate) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes, + uint32_t cta_id, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = cta_id; + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = bits; + to[synclog_length_prefix + 5] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + CUTLASS_UNUSED(cta_id); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_expect_transaction( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_complete_transaction( + uint32_t line, + uint32_t smem_addr, + uint32_t dst_cta_id, + uint32_t transaction_bytes, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = dst_cta_id; + to[synclog_length_prefix + 2] = transaction_bytes; + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = bits; + to[synclog_length_prefix + 5] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(dst_cta_id); + CUTLASS_UNUSED(transaction_bytes); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_fence_barrier_init(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_fence_barrier_init) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_fence_barrier_init); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_fence_barrier_init, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_fence_view_async_shared(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_fence_view_async_shared) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_fence_view_async_shared); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_fence_view_async_shared, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait( + uint32_t line, + uint32_t n) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_wait, line); + to[synclog_length_prefix + 0] = n; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait_all(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait_all) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait_all); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_wait_all, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_fence(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_fence) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_fence); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_fence, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_nan( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_nan) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_nan); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_nan, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_zfill( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred, + uint32_t size) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_zfill) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_zfill); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_zfill, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + CUTLASS_UNUSED(size); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred, + uint32_t size) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + CUTLASS_UNUSED(size); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_load( + uint32_t line, + uint64_t gmem_int_desc, + uint32_t smem_int_mbar, + uint32_t smem_int_ptr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_load) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_load); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_load, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_mbar; + to[synclog_length_prefix + 3] = smem_int_ptr; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_mbar); + CUTLASS_UNUSED(smem_int_ptr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store( + uint32_t line, + uint64_t gmem_int_desc, + uint32_t smem_int_ptr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_ptr; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_ptr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_arrive(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_arrive, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_wait( + uint32_t line, + uint32_t count) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_wait, line); + to[synclog_length_prefix + 0] = count; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(count); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_arrive( + uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_arrive, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_wait( + uint32_t line, + uint32_t n) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_wait, line); + to[synclog_length_prefix + 0] = n; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_commit_batch( + uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_commit_batch) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_commit_batch); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_commit_batch, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_reg_smem( + uint32_t line, + uint64_t desc_b) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_reg_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line); + to[synclog_length_prefix + 0] = desc_b; + to[synclog_length_prefix + 1] = desc_b >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_b); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_smem_smem( + uint32_t line, + uint64_t desc_a, + uint64_t desc_b) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_smem_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line); + to[synclog_length_prefix + 0] = desc_a; + to[synclog_length_prefix + 1] = desc_a >> 32; + to[synclog_length_prefix + 2] = desc_b; + to[synclog_length_prefix + 3] = desc_b >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_a); + CUTLASS_UNUSED(desc_b); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cpasync_barrier_arrive( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cpasync_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +#if !defined(CUTLASS_ENABLE_SYNCLOG) +CUTLASS_DEVICE +#elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +static __attribute__((__noinline__)) __device__ +#else +static __attribute__((__noinline__)) +#endif +void synclog_print() { + #if defined(CUTLASS_ENABLE_SYNCLOG) + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + if (synclog_buf == nullptr || !synclog_condition_print()) { + return; + } + printf("synclog start\n"); + for (uint32_t at = 1; at < synclog_cap; ) { + uint32_t header = synclog_buf[at]; + if (header == synclog_header_none) { + break; + } + printf("synclog at %u: ", at); + if constexpr (synclog_enable_syncthreads) { + if (header == synclog_header_syncthreads) { + synclog_print_prefix("syncthreads", at); + at += synclog_length_syncthreads; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_syncwarp) { + if (header == synclog_header_syncwarp) { + synclog_print_prefix("syncwarp", at); + at += synclog_length_syncwarp; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_named_barrier_arrive_and_wait) { + if (header == synclog_header_named_barrier_arrive_and_wait) { + synclog_print_prefix("named_barrier_arrive_and_wait", at); + at += synclog_length_named_barrier_arrive_and_wait; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_named_barrier_arrive) { + if (header == synclog_header_named_barrier_arrive) { + synclog_print_prefix("named_barrier_arrive", at); + at += synclog_length_named_barrier_arrive; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_init) { + if (header == synclog_header_cluster_barrier_init) { + synclog_print_prefix("cluster_barrier_init", at); + at += synclog_length_cluster_barrier_init; + printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_wait) { + if (header == synclog_header_cluster_barrier_wait) { + synclog_print_prefix("cluster_barrier_wait", at); + at += synclog_length_cluster_barrier_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_test_wait) { + if (header == synclog_header_cluster_barrier_test_wait) { + synclog_print_prefix("cluster_barrier_test_wait", at); + at += synclog_length_cluster_barrier_test_wait; + printf("smem_addr=%u phase=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_try_wait) { + if (header == synclog_header_cluster_barrier_try_wait) { + synclog_print_prefix("cluster_barrier_try_wait", at); + at += synclog_length_cluster_barrier_try_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive_cluster) { + if (header == synclog_header_cluster_barrier_arrive_cluster) { + synclog_print_prefix("cluster_barrier_arrive_cluster", at); + at += synclog_length_cluster_barrier_arrive_cluster; + printf("smem_addr=%u cta_id=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive) { + if (header == synclog_header_cluster_barrier_arrive) { + synclog_print_prefix("cluster_barrier_arrive", at); + at += synclog_length_cluster_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_invalidate) { + if (header == synclog_header_cluster_barrier_invalidate) { + synclog_print_prefix("cluster_barrier_invalidate", at); + at += synclog_length_cluster_barrier_invalidate; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) { + synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx; + printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster; + printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) { + if (header == synclog_header_cluster_transaction_barrier_expect_transaction) { + synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at); + at += synclog_length_cluster_transaction_barrier_expect_transaction; + printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) { + if (header == synclog_header_cluster_transaction_barrier_complete_transaction) { + synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at); + at += synclog_length_cluster_transaction_barrier_complete_transaction; + printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_fence_barrier_init) { + if (header == synclog_header_fence_barrier_init) { + synclog_print_prefix("fence_barrier_init", at); + at += synclog_length_fence_barrier_init; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_fence_view_async_shared) { + if (header == synclog_header_fence_view_async_shared) { + synclog_print_prefix("fence_view_async_shared", at); + at += synclog_length_fence_view_async_shared; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait) { + if (header == synclog_header_cp_async_wait) { + synclog_print_prefix("cp_async_wait", at); + at += synclog_length_cp_async_wait; + printf("n=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait_all) { + if (header == synclog_header_cp_async_wait_all) { + synclog_print_prefix("cp_async_wait_all", at); + at += synclog_length_cp_async_wait_all; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_fence) { + if (header == synclog_header_cp_async_fence) { + synclog_print_prefix("cp_async_fence", at); + at += synclog_length_cp_async_fence; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_nan) { + if (header == synclog_header_cp_async_nan) { + synclog_print_prefix("cp_async_nan", at); + at += synclog_length_cp_async_nan; + uint64_t gmem_addr = synclog_buf[at-3]; + gmem_addr += (uint64_t)synclog_buf[at-2] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at-4], gmem_addr, synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_zfill) { + if (header == synclog_header_cp_async_zfill) { + synclog_print_prefix("cp_async_zfill", at); + at += synclog_length_cp_async_zfill; + uint64_t gmem_addr = synclog_buf[at-4]; + gmem_addr += (uint64_t)synclog_buf[at-3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async) { + if (header == synclog_header_cp_async) { + synclog_print_prefix("cp_async", at); + at += synclog_length_cp_async; + uint64_t gmem_addr = synclog_buf[at-4]; + gmem_addr += (uint64_t)synclog_buf[at-3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_load) { + if (header == synclog_header_tma_load) { + synclog_print_prefix("tma_load", at); + at += synclog_length_tma_load; + uint64_t gmem_int_desc = synclog_buf[at-4]; + gmem_int_desc += (uint64_t)synclog_buf[at-3] << 32; + printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_store) { + if (header == synclog_header_tma_store) { + synclog_print_prefix("tma_store", at); + at += synclog_length_tma_store; + uint64_t gmem_int_desc = synclog_buf[at-3]; + gmem_int_desc += (uint64_t)synclog_buf[at-2] << 32; + printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_store_arrive) { + if (header == synclog_header_tma_store_arrive) { + synclog_print_prefix("tma_store_arrive", at); + at += synclog_length_tma_store_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_tma_store_wait) { + if (header == synclog_header_tma_store_wait) { + synclog_print_prefix("tma_store_wait", at); + at += synclog_length_tma_store_wait; + printf("count=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_arrive) { + if (header == synclog_header_warpgroup_arrive) { + synclog_print_prefix("warpgroup_arrive", at); + at += synclog_length_warpgroup_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_warpgroup_wait) { + if (header == synclog_header_warpgroup_wait) { + synclog_print_prefix("warpgroup_wait", at); + at += synclog_length_warpgroup_wait; + printf("n=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_commit_batch) { + if (header == synclog_header_warpgroup_commit_batch) { + synclog_print_prefix("warpgroup_commit_batch", at); + at += synclog_length_warpgroup_commit_batch; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_reg_smem) { + if (header == synclog_header_wgmma_reg_smem) { + synclog_print_prefix("wgmma_reg_smem", at); + at += synclog_length_wgmma_reg_smem; + synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_smem_smem) { + if (header == synclog_header_wgmma_smem_smem) { + synclog_print_prefix("wgmma_smem_smem", at); + at += synclog_length_wgmma_smem_smem; + synclog_print_wgmma_desc("desc_a", synclog_buf[at-4], synclog_buf[at-3], " "); + synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cpasync_barrier_arrive) { + if (header == synclog_header_cpasync_barrier_arrive) { + synclog_print_prefix("cpasync_barrier_arrive", at); + at += synclog_length_cpasync_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + asm volatile ("brkpt;\n" ::); + } + if (synclog_buf[0] >= synclog_cap) { + printf( + "synclog was truncated (exceeded capacity of %lu bytes)\n", + (synclog_cap - 1) * sizeof(uint32_t) + ); + } + printf("synclog end\n"); + #endif + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ENABLE_SYNCLOG) +#undef __syncthreads +#define __syncthreads() do {\ + cutlass::arch::synclog_emit_syncthreads(__LINE__);\ + __syncthreads();\ +} while (0) +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +#if defined(CUTLASS_ENABLE_SYNCLOG) +#undef __syncwarp +#define __syncwarp(...) do {\ + cutlass::arch::synclog_emit_syncwarp(__LINE__);\ + __syncwarp(__VA_ARGS__);\ +} while (0) +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 499d45c724..62e9469497 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/functional.h" #include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -49,6 +50,23 @@ template < > struct Array; +namespace detail { + +template +struct is_Array : platform::false_type {}; + +template < + typename T, + int N, + bool RegisterSized +> +struct is_Array > : platform::true_type {}; + +template +constexpr bool is_Array_v = is_Array::value; + +} // namespace detail + //////////////////////////////////////////////////////////////////////////////////////////////////// /// Defines the size of an Array<> in bits @@ -803,111 +821,14 @@ struct reciprocal_approximate_ftz> { } }; -template -struct maximum, false> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct maximum, true> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct minimum, false> { - - CUTLASS_HOST_DEVICE - static T scalar_op(T const &lhs, T const &rhs) { - return (rhs < lhs ? rhs : lhs); - } +template +struct maximum, PropagateNaN> { CUTLASS_HOST_DEVICE Array operator()(Array const &lhs, Array const &rhs) const { Array result; - minimum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -921,7 +842,7 @@ struct minimum, false> { Array operator()(Array const &lhs, T const &scalar) const { Array result; - minimum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -935,7 +856,7 @@ struct minimum, false> { Array operator()(T const &scalar, Array const &rhs) const { Array result; - minimum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -946,8 +867,8 @@ struct minimum, false> { } }; -template -struct minimum, true> { +template +struct minimum, PropagateNaN> { CUTLASS_HOST_DEVICE static T scalar_op(T const &lhs, T const &rhs) { @@ -958,7 +879,7 @@ struct minimum, true> { Array operator()(Array const &lhs, Array const &rhs) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -972,7 +893,7 @@ struct minimum, true> { Array operator()(Array const &lhs, T const &scalar) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -986,7 +907,7 @@ struct minimum, true> { Array operator()(T const &scalar, Array const &rhs) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -2030,8 +1951,8 @@ struct multiply_add_relu0, Array, Array> } }; -template -struct minimum, false> { +template +struct minimum, PropagateNaN> { CUTLASS_HOST_DEVICE Array operator()(Array const & lhs, Array const &rhs) const { Array result; @@ -2043,25 +1964,27 @@ struct minimum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmin2(lhs_ptr[i], rhs_ptr[i]); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmin( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) + : __hmin(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + minimum mn; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]); + result[i] = mn(lhs[i],rhs[i]); } #endif @@ -2079,24 +2002,26 @@ struct minimum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_pair, rhs_ptr[i]) + : __hmin2(lhs_pair, rhs_ptr[i]); } if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmin( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmin_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) + : __hmin(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + minimum mn; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (rhs[i] < lhs ? rhs[i] : lhs); + result[i] = mn(lhs, rhs[i]); } #endif @@ -2114,24 +2039,26 @@ struct minimum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_pair) + : __hmin2(lhs_ptr[i], rhs_pair); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hmin( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); + __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmin(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); result[N - 1] = reinterpret_cast(d_residual); } #else + minimum mn; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (rhs < lhs[i] ? rhs : lhs[i]); + result[i] = mn(lhs[i], rhs); } #endif @@ -2139,8 +2066,8 @@ struct minimum, false> { } }; -template -struct maximum, false> { +template +struct maximum, PropagateNaN> { CUTLASS_HOST_DEVICE Array operator()(Array const & lhs, Array const &rhs) const { Array result; @@ -2152,25 +2079,27 @@ struct maximum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmax2(lhs_ptr[i], rhs_ptr[i]); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmax( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmax(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) + : __hmax_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + maximum mx; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]); + result[i] = mx(lhs[i], rhs[i]); } #endif @@ -2188,24 +2117,26 @@ struct maximum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_pair, rhs_ptr[i]) + : __hmax2(lhs_pair, rhs_ptr[i]); } if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmax( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmax_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) + : __hmax(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + maximum mx; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (lhs < rhs[i] ? rhs[i] : lhs); + result[i] = mx(lhs, rhs[i]); } #endif @@ -2223,24 +2154,26 @@ struct maximum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_pair) + : __hmax2(lhs_ptr[i], rhs_pair); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hmax( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); + __half d_residual = PropagateNaN ? __hmax_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmax(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); result[N - 1] = reinterpret_cast(d_residual); } #else + maximum mx; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (lhs[i] < rhs ? rhs : lhs[i]); + result[i] = mx(lhs[i], rhs); } #endif diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index 50506c73be..5af6d3ab80 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -190,6 +190,12 @@ struct alignas(2) bfloat16_t { return (float(*this) != 0.0f); } + /// Bitcasts to CUDA's bf16 type + CUTLASS_DEVICE + __nv_bfloat16 to_nv_bfloat16() const { + return reinterpret_cast<__nv_bfloat16 const &>(storage); + } + /// Obtains raw bits CUTLASS_HOST_DEVICE uint16_t raw() const { @@ -321,9 +327,9 @@ bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { // /////////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__CUDACC_RTC__) namespace std { -#if !defined(__CUDACC_RTC__) /// Numeric limits template <> struct numeric_limits { @@ -378,9 +384,78 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } }; -#endif } // namespace std +#endif + +namespace cutlass { +namespace platform { + +/// Forward Declaration +template +struct numeric_limits; + +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; +#if !defined(__CUDACC_RTC__) + static std::float_denorm_style const has_denorm = std::denorm_present; +#endif + static bool const has_denorm_loss = true; +#if !defined(__CUDACC_RTC__) + static std::float_round_style const round_style = std::round_to_nearest; +#endif + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 7; + + /// Least positive value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } + + /// Minimum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } +}; + +} // namespace platform +} // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -394,114 +469,190 @@ namespace cutlass { CUTLASS_HOST_DEVICE bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __heq(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) == float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hne(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) != float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hlt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) < float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hle(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) <= float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hgt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) > float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hge(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) >= float(rhs); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) + float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator-(bfloat16_t const& lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hneg(lhs.to_nv_bfloat16())); +#else return bfloat16_t(-float(lhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) - float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) * float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) / float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) + float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) - float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) * float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) / float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator++(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); ++tmp; lhs = bfloat16_t(tmp); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator--(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); --tmp; lhs = bfloat16_t(tmp); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t operator++(bfloat16_t & lhs, int) { bfloat16_t ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); tmp++; lhs = bfloat16_t(tmp); +#endif return ret; } CUTLASS_HOST_DEVICE bfloat16_t operator--(bfloat16_t & lhs, int) { bfloat16_t ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); tmp--; lhs = bfloat16_t(tmp); +#endif return ret; } diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 3d140eaa84..a0fa22b6bb 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -172,6 +172,7 @@ struct ClusterLauncher { "And ClusterDims = " "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + cutlass::arch::synclog_setup(); cudaError_t status = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); Return_Status(status); #else diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 84fd37b47e..78862b0a09 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -41,8 +41,8 @@ #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" +#include "cutlass/conv/detail.hpp" #include "cutlass/conv/convolution.h" -#include "cutlass/conv/convnd_problem_shape.hpp" #include "cutlass/conv/dispatch_policy.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/util/packed_stride.hpp" @@ -103,6 +103,8 @@ struct CollectiveConv< using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename cutlass::PipelineState; + + using ProblemShape = ConvProblemShape; // TODO: move pipeline mode tiling into the collective setup phase instead static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); @@ -143,7 +145,7 @@ struct CollectiveConv< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; @@ -162,8 +164,6 @@ struct CollectiveConv< // Host side kernel arguments struct Arguments { - using ProblemShape = ConvProblemShape; - ProblemShape problem_shape{}; ElementA const* ptr_A{nullptr}; ElementB const* ptr_B{nullptr}; }; @@ -175,7 +175,7 @@ struct CollectiveConv< // Get tma_load_a instantce. template static constexpr auto - get_tma_load_a_instance(TensorA const& tensor_a, typename Arguments::ProblemShape const& problem_shape) { + get_tma_load_a_instance(TensorA const& tensor_a, ProblemShape const& problem_shape) { if constexpr (is_im2col_A) { // compute the upper and lower corners based on the conv padding auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); @@ -218,7 +218,7 @@ struct CollectiveConv< // Get tma_load_b instantce. template static constexpr auto - get_tma_load_b_instance(TensorB const& tensor_b, typename Arguments::ProblemShape const& problem_shape) { + get_tma_load_b_instance(TensorB const& tensor_b, ProblemShape const& problem_shape) { // TMA im2col mode for tensor B in wgrad kernel. if constexpr (is_im2col_B) { // compute the upper and lower corners based on the conv padding @@ -250,24 +250,25 @@ struct CollectiveConv< } } +public: + + // Performs im2col transformations on the input of type ConvProblemShape static constexpr auto - get_problem_shape_MNKL(typename Arguments::ProblemShape const& problem_shape) { + get_problem_shape_MNKL(ProblemShape const& problem_shape) { + if constexpr (is_im2col_A || is_im2col_B) { // transformation + im2col linearization - return problem_shape.get_linearized_problem_shape_MNKL(); + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); } else { // transformation - return problem_shape.get_transformed_problem_shape_MNKL(); + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); } } -public: - // Device side kernel params struct Params { - using _Submode = decltype(take<0,NumTensorDimensions-1>(typename Arguments::ProblemShape::TensorExtent{})); - using ProblemShape = decltype(get_problem_shape_MNKL(typename Arguments::ProblemShape{})); + using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); // Assumption: StrideA is congruent with Problem_MK // Select TMA load type according to convolution operator. @@ -294,7 +295,6 @@ struct CollectiveConv< // Members TMA_A tma_load_a; TMA_B tma_load_b; - ProblemShape problem_shape; uint32_t tma_transaction_bytes = TmaTransactionBytes; }; @@ -304,19 +304,19 @@ struct CollectiveConv< // Lowers the host side user facing arguments to the kernel facing lauch params static constexpr Params - to_underlying_arguments(Arguments const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple // tma desc creation depends on the original untransformed domain. // A extents. - auto shape_A_orig = args.problem_shape.get_shape_A(); + auto shape_A_orig = problem_shape.get_shape_A(); // B extents. - auto shape_B_orig = args.problem_shape.get_shape_B(); + auto shape_B_orig = problem_shape.get_shape_B(); // Fill inferred cute strides from flat stride arrays - auto dA = make_cute_packed_stride(StrideA{}, args.problem_shape.stride_A, ConvOp); - auto dB = make_cute_packed_stride(StrideB{}, args.problem_shape.stride_B, ConvOp); + auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); + auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); auto ptr_A = reinterpret_cast(args.ptr_A); auto ptr_B = reinterpret_cast(args.ptr_B); @@ -324,20 +324,17 @@ struct CollectiveConv< Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); - auto tma_load_a = get_tma_load_a_instance(tensor_a, args.problem_shape); - auto tma_load_b = get_tma_load_b_instance(tensor_b, args.problem_shape); - - auto problem_shape_mnkl = get_problem_shape_MNKL(args.problem_shape); + auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape); + auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape); return { tma_load_a, tma_load_b, - problem_shape_mnkl, TmaTransactionBytes }; } - - template + + template static bool can_implement( ProblemShape const& problem_shape, @@ -345,14 +342,14 @@ struct CollectiveConv< // Activation and Filter channel mode extents much match bool implementable = true; // channel mode is major - implementable &= args.problem_shape.stride_A[NumTensorDimensions-1] == 1; - implementable &= args.problem_shape.stride_B[NumTensorDimensions-1] == 1; + implementable &= problem_shape.stride_A[NumTensorDimensions-1] == 1; + implementable &= problem_shape.stride_B[NumTensorDimensions-1] == 1; constexpr int tma_alignment_bits = 128; // A extents. - auto shape_A_orig = args.problem_shape.get_shape_A(); + auto shape_A_orig = problem_shape.get_shape_A(); // B extents. - auto shape_B_orig = args.problem_shape.get_shape_B(); + auto shape_B_orig = problem_shape.get_shape_B(); constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(shape_A_orig, StrideA{}); constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; @@ -375,61 +372,6 @@ struct CollectiveConv< return false; } - if (is_im2col_A || is_im2col_B) { - // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] - constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); - } - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); - return false; - } - } - - // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) - if constexpr (ConvOp == conv::Operator::kWgrad) { - - const auto & input_shape = problem_shape.shape_A; - const auto & input_stride = problem_shape.stride_A; - - implementable &= input_stride[ProblemShape::RankT - 1] == 1; - int input_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - input_shape_size *= input_shape[i + 1]; - implementable &= input_stride[i] == input_shape_size; - } - - const auto & output_shape = problem_shape.shape_C; - const auto & output_stride = problem_shape.stride_C; - - implementable &= output_stride[ProblemShape::RankT - 1] == 1; - int output_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - output_shape_size *= output_shape[i + 1]; - implementable &= output_stride[i] == output_shape_size; - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); - return false; - } - } - - // Conv kernels only support cross correlation mode currently. - implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); - return false; - } - if (problem_shape.groups > 1) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); return false; @@ -445,24 +387,53 @@ struct CollectiveConv< cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k) + /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k) + /// The rest of the tensors can be specified as needed by this collective. + /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with + /// StrideA and StrideB set up for TMA + template + CUTLASS_DEVICE auto + load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){ + //load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mk = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K)); // (m,k) + Tensor mB_nk = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K)); // (n,k) + + // Make tiled views, defer the slice + Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + return cute::make_tuple(gA_mk, gB_nk); + } + /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < - class TensorA, class TMA_LOAD_A, - class TensorB, class TMA_LOAD_B, - class KTileIterator + class TensorA, class TensorB, + class KTileIterator, class BlockCoord > CUTLASS_DEVICE void - load(MainloopPipeline pipeline, - PipelineState smem_pipe_producer_state, - TensorA const& gA, TMA_LOAD_A& tma_load_a, - TensorB const& gB, TMA_LOAD_B& tma_load_b, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, - uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { - int lane_predicate = cute::elect_one_sync(); + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_producer_state, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -470,11 +441,19 @@ struct CollectiveConv< // // Prepare the TMA loads for A and B // - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + auto [gA_mk, gB_nk] = load_inputs; + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -518,8 +497,9 @@ struct CollectiveConv< BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state); int write_stage = smem_pipe_producer_state.index(); - copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; // Advance smem_pipe_producer_state diff --git a/include/cutlass/conv/convnd_problem_shape.hpp b/include/cutlass/conv/convnd_problem_shape.hpp index 0172120538..ffcc547fbd 100644 --- a/include/cutlass/conv/convnd_problem_shape.hpp +++ b/include/cutlass/conv/convnd_problem_shape.hpp @@ -43,6 +43,7 @@ #include #endif + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::conv { @@ -54,15 +55,17 @@ namespace cutlass::conv { // Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types. template < conv::Operator ConvOp_, - int NumSpatialDimensions + int NumSpatialDimensions_ > struct ConvProblemShape { // // Alias types for members // - static constexpr int RankS = NumSpatialDimensions; - static constexpr int RankT = NumSpatialDimensions + 2; + + static constexpr int RankS = NumSpatialDimensions_; + static constexpr int RankT = NumSpatialDimensions_ + 2; static constexpr conv::Operator ConvOp = ConvOp_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; using SpatialExtent = cute::array; using TensorExtent = cute::array; using TensorStride = cute::array; @@ -352,71 +355,6 @@ struct ConvProblemShape { } } - // Get problem shape MNKL according to following table: - // | | Fprop | Dgrad | Wgrad | - // | ---- | --------- | -------- | -------- | - // | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | - // | Shape_N | (K) | (C) | (C,S,R,T) | - // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | - // | Shape_L | _1 | (V,U,O) | _1 | - CUTLASS_HOST_DEVICE - constexpr auto - get_transformed_problem_shape_MNKL() const { - using cute::insert; - using cute::make_shape; - using cute::reverse; - using cute::take; - - if constexpr (ConvOp == conv::Operator::kWgrad) { - auto M_xformed = shape_C[0]; - auto N_xformed = reverse(take<1, RankT>(shape_C)); - auto K_xformed = reverse(take<0, RankT - 1>(shape_A)); - auto L_xformed = cute::Int<1>{}; - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - else if constexpr (ConvOp == conv::Operator::kFprop){ - auto M_xformed = reverse(take<0, RankT - 1>(shape_C)); - auto N_xformed = shape_C[RankT - 1]; - auto K_xformed = reverse(take<1, RankT>(shape_B)); - auto L_xformed = cute::Int<1>{}; - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - else if constexpr (ConvOp == conv::Operator::kDgrad) { - auto L_xformed = reverse(traversal_stride); // (V,U,O) - auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(shape_C)), L_xformed); - auto N_xformed = shape_C[RankT - 1]; - // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] - auto K_xformed = insert<0>( - (reverse(take<1,RankT - 1>(shape_B))), - shape_B[0]); - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - } - - // Assuming im2col linearization - // Get problem shape MNKL according to following table: - // | | Fprop | Dgrad | Wgrad | - // | ---- | --------- | -------- | -------- | - // | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | - // | Shape_N | (K) | (C) | (C,S,R,T) | - // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | - // | Shape_L | _1 | (V*U*O) | _1 | - CUTLASS_HOST_DEVICE - constexpr auto - get_linearized_problem_shape_MNKL() const { - auto [M, N, K, L] = get_transformed_problem_shape_MNKL(); - - if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { - return cute::make_shape(cute::product(M), N, K, cute::product(L)); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - return cute::make_shape(M, N, cute::product(K), L); - } - } - // Get A extents. // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) @@ -578,9 +516,7 @@ struct ConvProblemShape { // calculate n,z,p,q,k. // a helper lambda to compute a single spatial extent of the nzpqk tensor auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { - auto tmp = act_ext + pad_total - ((filter_ext -1) * dilation + 1); - CUTLASS_ASSERT(tmp % tstride == 0); - return 1 + tmp / tstride; + return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; }; shape_xformed_act[0] = shape_act[0]; // Activation N extent diff --git a/include/cutlass/conv/detail.hpp b/include/cutlass/conv/detail.hpp new file mode 100644 index 0000000000..3e4173569c --- /dev/null +++ b/include/cutlass/conv/detail.hpp @@ -0,0 +1,137 @@ + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + // Helper function to get the problem shape +template +auto get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::true_type) { + return T::get_problem_shape_MNKL(problem_shape); +} + +template +ProblemShape get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::false_type) { + return problem_shape; +} + +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | +// | Shape_L | _1 | (V,U,O) | _1 | + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ProblemShape const& problem_shape) { + return problem_shape; +} + + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + using cute::insert; + using cute::make_shape; + using cute::reverse; + using cute::take; + + constexpr int RankT = SpatialDim + 2; + + if constexpr (ConvOp == conv::Operator::kWgrad) { + auto M_xformed = problem_shape.shape_C[0]; + auto N_xformed = reverse(take<1, RankT>(problem_shape.shape_C)); + auto K_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_A)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kFprop){ + auto M_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_C)); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + auto K_xformed = reverse(take<1, RankT>(problem_shape.shape_B)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + auto L_xformed = reverse(problem_shape.traversal_stride); // (V,U,O) + auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(problem_shape.shape_C)), L_xformed); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] + auto K_xformed = insert<0>( + (reverse(take<1,RankT - 1>(problem_shape.shape_B))), + problem_shape.shape_B[0]); + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } +} + +// Assuming im2col linearization +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | +// | Shape_L | _1 | (V*U*O) | _1 | +template +CUTLASS_HOST_DEVICE +constexpr auto +get_linearized_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + + auto [M, N, K, L] = get_transformed_problem_shape_MNKL(problem_shape); + + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + return cute::make_shape(cute::product(M), N, K, cute::product(L)); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return cute::make_shape(M, N, cute::product(K), L); + } + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 0472b898c2..193f8d8854 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -61,7 +61,7 @@ template class ConvUniversalAdapter { public: - using ConvKernel = ConvKernel_; + using ConvKernel = GetUnderlyingKernel_t; using TileShape = typename ConvKernel::TileShape; using ElementA = typename ConvKernel::ElementA; using ElementB = typename ConvKernel::ElementB; @@ -76,7 +76,7 @@ class ConvUniversalAdapter // Tease out meta-information about the conv algorithm static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp; - static constexpr int NumSpatialDimensions = ConvKernel::NumSpatialDimensions; + static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! using OperatorClass = cute::conditional_t< @@ -121,13 +121,13 @@ class ConvUniversalAdapter static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; // Inspect TiledCopy for A and B to compute the alignment size - static int constexpr kAlignmentA = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); - static int constexpr kAlignmentB = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); - static int constexpr kAlignmentC = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); - static int constexpr kAlignmentD = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; @@ -297,8 +297,9 @@ class ConvUniversalAdapter Status launch_result; // Use extended launch API only for mainloops that use it if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) { - constexpr bool is_static_1x1x1 = cute::is_static_v and - cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); diff --git a/include/cutlass/conv/device/direct_convolution.h b/include/cutlass/conv/device/direct_convolution.h index 84953d8036..43ab94b5fc 100644 --- a/include/cutlass/conv/device/direct_convolution.h +++ b/include/cutlass/conv/device/direct_convolution.h @@ -211,6 +211,7 @@ class DirectConvolution { dim3 grid = ReorderKernel::get_grid_shape(params_); dim3 block = ReorderKernel::get_block_shape(); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); } @@ -229,6 +230,7 @@ class DirectConvolution { if (status != cudaSuccess) return Status::kErrorInternal; + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 62c7e8715d..a1cb06e98f 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -53,7 +53,7 @@ template class ImplicitGemmConvolution { public: - using UnderlyingKernel = ImplicitGemmKernel_; + using UnderlyingKernel = GetUnderlyingKernel_t; using ElementA = typename UnderlyingKernel::ElementA; using LayoutA = typename UnderlyingKernel::LayoutA; @@ -103,7 +103,6 @@ class ImplicitGemmConvolution { /// Determines whether the Implicit GEMM can execute the given problem. static Status can_implement(Arguments const &args) { - // dispatch to iterators Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); if (Status::kSuccess != status) { @@ -164,9 +163,8 @@ class ImplicitGemmConvolution { // check for unsupported problem sizes for strided dgrad / deconv implementation if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) && kStrideSupport == conv::StrideSupport::kStrided) { - // split-k (serial or parallel) is not supported for strided dgrad / deconv - if(args.problem_size.split_k_slices > 1) { + if(args.problem_size.split_k_slices > 1 && (args.problem_size.stride().at(args.problem_size.stride().max_dim_index()) > 1)) { return Status::kErrorNotSupported; } @@ -291,7 +289,7 @@ class ImplicitGemmConvolution { } /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { ThreadblockSwizzle threadblock_swizzle; @@ -311,7 +309,7 @@ class ImplicitGemmConvolution { void* kernel_params[] = {¶ms_}; launch_result = cuda_adapter->launch( - grid, dim3(1,1,1), block, smem_size, stream, kernel_params, 0 + grid, dim3(1,1,1), block, smem_size, stream, kernel_params, kernel_index ); } else { @@ -319,6 +317,7 @@ class ImplicitGemmConvolution { } } else { + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); } @@ -333,20 +332,20 @@ class ImplicitGemmConvolution { } /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { - return run(stream, cuda_adapter); + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + return run(stream, cuda_adapter, kernel_index); } /// Runs the kernel using initialized state. Status operator()( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { Status status = initialize(args, workspace, stream, cuda_adapter); if (status == Status::kSuccess) { - status = run(stream, cuda_adapter); + status = run(stream, cuda_adapter, kernel_index); } return status; diff --git a/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h index 1eb0d5600e..265156cc5b 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h @@ -231,6 +231,7 @@ class ImplicitGemmConvolutionFusion { int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp index 039f4539c4..b8b5eb2bff 100644 --- a/include/cutlass/conv/dispatch_policy.hpp +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -37,6 +37,8 @@ #include "cute/layout.hpp" #include "cute/numeric/integral_constant.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + ////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// @@ -48,7 +50,7 @@ namespace cutlass::conv { // // Policies for categorical dispatch of mainloop against kernel grid schedules // -struct KernelImplicitTmaWarpSpecializedSm90 { }; +struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { }; struct KernelImplicitTmaWarpSpecializedSm90Cooperative { }; struct KernelImplicitTmaWarpSpecializedSm90Pingpong { }; @@ -84,3 +86,5 @@ struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv + +////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/conv_universal.hpp b/include/cutlass/conv/kernel/conv_universal.hpp index 9d98dc9d96..23ccea2f8f 100644 --- a/include/cutlass/conv/kernel/conv_universal.hpp +++ b/include/cutlass/conv/kernel/conv_universal.hpp @@ -30,6 +30,7 @@ **************************************************************************************************/ #pragma once +#include "cutlass/conv/convnd_problem_shape.hpp" #include "cutlass/detail/dependent_false.hpp" //////////////////////////////////////////////////////////////////////////////// @@ -43,6 +44,7 @@ namespace cutlass::conv::kernel { * a composition of a collective mainloop and a collective epilogue. **/ template < + class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileSchedulerTag_ = void, diff --git a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp index 95780bf84e..657ac6b3ec 100644 --- a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -37,9 +37,12 @@ #include "cute/tensor.hpp" #include "cute/arch/cluster_sm90.hpp" +#include "cutlass/conv/detail.hpp" #include "cutlass/conv/convolution.h" #include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/kernel/tile_scheduler.hpp" /////////////////////////////////////////////////////////////////////////////// @@ -49,365 +52,25 @@ namespace cutlass::conv::kernel { /////////////////////////////////////////////////////////////////////////////// template < + class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class TileSchedulerTag_ + class TileScheduler_ > class ConvUniversal< + ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - TileSchedulerTag_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; - static_assert(ArchTag::kMinComputeCapability >= 90); - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - using TileSchedulerTag = TileSchedulerTag_; - static_assert(cute::is_void_v, - "TMA warp-specialized kernel does not support specializing the tile scheduler."); - using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< - TileSchedulerTag, ArchTag, TileShape, ClusterShape>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - - // Kernel level shared memory storage - struct SharedStorage { - union TensorStorage { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 1; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Host facing host arguments - struct Arguments { - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel device entry point API - struct Params { - MainloopParams mainloop; - EpilogueParams epilogue; - }; - - // - // Methods - // - - // Map user facing arguments to device facing params - static Params - to_underlying_arguments(Arguments const& args, void* workspace) { - (void) workspace; - auto mainloop_params = CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace); - auto problem_shape_MNKL = args.mainloop.problem_shape.get_transformed_problem_shape_MNKL(); - - return { - mainloop_params, - CollectiveEpilogue::to_underlying_arguments(problem_shape_MNKL, args.epilogue, workspace) - }; - } - - // Given arguemnts, returns true if the kernel can successfully compute upon them. False otherwise. - static bool - can_implement(Arguments const& args) { - bool implementable = true; - implementable &= CollectiveMainloop::can_implement(args.mainloop.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.mainloop.problem_shape.get_transformed_problem_shape_MNKL(), args.epilogue); - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - return 0; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - return cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( - params.mainloop.problem_shape, TileShape{}, ClusterShape{}); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif - - enum class WarpGroupRole { - Producer = 0, - Consumer = 1, - }; - - enum class ProducerWarpRole { - MainloopEpilogue = 0, - Warp1 = 1, - Warp2 = 2, - Warp3 = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [&] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Separate out problem shape for convenience - auto problem_shape_MNKL = append<4>(params.mainloop.problem_shape, _1{}); - auto [M, N, K, L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mk = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M, K)); - Tensor mB_nk = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N, K)); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - auto cta_tile_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - TiledMma tiled_mma; - - // Make tiled views, defer the slice - Tensor gA_mk = local_tile(mA_mk, cta_tile_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) - Tensor gB_nk = local_tile(mB_nk, cta_tile_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) - - // Compute m_coord, n_coord, and l_coord with their post-tiled shapes - auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mk)); - auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nk), compact_col_major(shape<2>(gB_nk))); - - // The output shape M is linearized so the output coord M here should also be linearized. - auto output_tile_coord = make_coord(int(blockIdx.x), n_coord, _, Int<0>{}); - - // Slice with m_coord and n_coord - Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) - Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) - - // Get pipeline iterators and increments from tensor shapes - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - auto k_tile_count = size<2>(gA); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; - - // Wait for all thread blocks in Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) { - if (producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, - k_tile_iter, k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting mainloop pipeline state for the pipeline drain - mainloop_pipe_producer_state.advance(k_tile_count); - // Make sure mainloop consumer has been waited upon before issuing epilogue load - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if (collective_epilogue.is_producer_load_needed()) { - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - cta_tile_shape, - output_tile_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue - ); - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } - } - } - else if (warp_group_role == WarpGroupRole::Consumer) { - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(cta_tile_shape)); // (MMA,MMA_M,MMA_N) - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - k_tile_count, - thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - k_tile_count - ); - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - cta_tile_shape, - output_tile_coord, - accumulators, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state_next, - epi_store_pipeline, - epi_store_pipe_producer_state_next - ); - } - } -}; - + TileScheduler_, + cute::enable_if_t> +> : public cutlass::gemm::kernel::GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_ +> +{}; /////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv::kernel + diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index f9ff723ce1..1c8f56a652 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -82,6 +82,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// + #if !defined(__CUDACC_RTC__) #include @@ -152,6 +153,7 @@ CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000); #endif // !defined(__CUDACC_RTC__) + ///////////////////////////////////////////////////////////////////////////////////////////////// /// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index f396528307..e12616a201 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -35,6 +35,7 @@ #pragma once +#include "cutlass/arch/synclog.hpp" #include "cutlass/detail/helper_macros.hpp" //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp index d3c4c04b74..a4b288e7c9 100644 --- a/include/cutlass/detail/collective.hpp +++ b/include/cutlass/detail/collective.hpp @@ -31,7 +31,6 @@ #pragma once #include "cute/container/tuple.hpp" - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 429e5c2f06..216ba40285 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -30,13 +30,17 @@ **************************************************************************************************/ #pragma once +#include "cute/layout.hpp" +#include "cute/pointer_sparse.hpp" // cute::is_sparse +#include "cute/swizzle.hpp" // cute::Swizzle +#include "cute/swizzle_layout.hpp" // cute::detail::get_swizzle_portion +#include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90_tma.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/layout/tensor.h" #include "cutlass/numeric_types.h" +#include "cutlass/detail/collective.hpp" -#include "cute/layout.hpp" -#include "cute/util/type_traits.hpp" -#include "cute/arch/copy_sm90_tma.hpp" //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::detail { @@ -199,9 +203,16 @@ template constexpr auto stride_to_layout_tag_A() { + using InternalStrideA = cute::remove_pointer_t; if constexpr (is_major<0, StrideA>()) { // M major return layout::ColumnMajor{}; } + // Specialize for sparse layout + else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && + cute::rank(cute::get<1>(InternalStrideA{})) == 2 && + cute::is_same_v(InternalStrideA{}))>>) { + return layout::ColumnMajor{}; + } else { // K major return layout::RowMajor{}; } @@ -309,6 +320,10 @@ get_alignment_count_from_gmem_tiled_copy() { else { // For TMA tiled copies, we know the alignment has to be 128 bits if constexpr (is_tma_copy_engine()) { + // For sparse MMA, alignment in logical elements is increased by sparsity factor + if constexpr (cute::is_sparse_v) { + return 128 / sizeof_bits::value * ElementMma::sparsity; + } return 128 / sizeof_bits::value; } else { diff --git a/include/cutlass/detail/mma.hpp b/include/cutlass/detail/mma.hpp index 058f5fd3ea..0e491b9c40 100644 --- a/include/cutlass/detail/mma.hpp +++ b/include/cutlass/detail/mma.hpp @@ -42,6 +42,11 @@ namespace cutlass::detail { template struct IsSparseTensorOp : cute::false_type { }; +// TiledMma for sparse must have ValTypeE +template +struct IsSparseTensorOp> + : cute::true_type { }; + // The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. template struct get_operator_class { diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index ba875a757a..7af5d96cf6 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -56,6 +56,13 @@ namespace cutlass { +template struct Type2Type { using type=T; }; +// using the simple type to replace the complex type to reduce this symbol size +template struct GetUnderlyingKernel : public Type2Type {}; +template class Wrapper > struct GetUnderlyingKernel> : public Wrapper {}; +template using GetUnderlyingKernel_t = typename GetUnderlyingKernel::type; + + //////////////////////////////////////////////////////////////////////////////// /// Generic CUTLASS kernel template. @@ -71,6 +78,7 @@ void Kernel(typename Operator::Params params) { Operator op; op(params, *shared_storage); + cutlass::arch::synclog_print(); } @@ -85,6 +93,8 @@ void Kernel2(typename Operator::Params params) { reinterpret_cast(SharedStorageBase); Operator::invoke(params, *shared_storage); + cutlass::arch::synclog_print(); + } @@ -107,6 +117,8 @@ void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) extern __shared__ char smem[]; Operator op; op(params, smem); + cutlass::arch::synclog_print(); + } //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 90a600028c..759591b5dc 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -258,7 +258,7 @@ struct Sm90TmaBuilderImpl { using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; - + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; @@ -273,6 +273,9 @@ struct Sm90TmaBuilderImpl { // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; + // Get register to register tiled copy that happen before shared memory store. + // Apply void as no register transform op needed currently. + using CopyOpR2R = void; // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination @@ -300,7 +303,8 @@ struct Sm90TmaBuilderImpl { CopyOpS2G, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm90_get_smem_store_op_for_accumulator()), - CopyAtomC + CopyAtomC, + CopyOpR2R >; }; @@ -386,6 +390,7 @@ struct AuxStoreDescriptor { // No-smem builder template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -402,7 +407,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -452,6 +457,7 @@ struct CollectiveBuilder< // Tma warp-specialized builder template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -468,7 +474,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -513,6 +519,7 @@ public: // Auto builder template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -528,7 +535,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -552,7 +559,7 @@ private: using EpilogueSchedule = NoSmemWarpSpecialized; using _CollectiveBuilder = CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -574,6 +581,7 @@ public: // DEPRECATED Tma warp-specialized builder for elementwise fusion template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -591,7 +599,7 @@ template < struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -618,7 +626,7 @@ public: using CollectiveOp = typename CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -637,6 +645,7 @@ public: // DEPRECATED Tma warp-specialized builder for bias + elementwise fusion template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -654,7 +663,7 @@ template < struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]] CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -714,6 +723,9 @@ private: // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; + // Get register to register tiled copy that happen before shared memory store. + // Apply void as no register transform op needed. + using CopyOpR2R = void; public: using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise< @@ -733,7 +745,8 @@ public: SM90_TMA_STORE, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm90_get_smem_store_op_for_accumulator()), - CopyAtomC + CopyAtomC, + CopyOpR2R >; }; @@ -741,6 +754,7 @@ public: // since swapping NNN kernels input matrix and transposing its output at the same time then // we can get TTN kernel. template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -756,7 +770,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index a14696b2f8..d54cd0a8f7 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -30,6 +30,9 @@ **************************************************************************************************/ #pragma once +#include // cute::DefaultCopy +#include // cute::is_base_of_v + #include "cutlass/detail/dependent_false.hpp" #include "cutlass/epilogue/fusion/callbacks.hpp" @@ -100,7 +103,7 @@ struct CallbacksBuilder< TileShape_MNK, EpilogueTile_MN, ElementAccumulator, - cute::enable_if_t> + cute::enable_if_t> > { using Callbacks = FusionCallbacks; }; diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index f8179b0a0e..8fb1a9588b 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -53,11 +53,19 @@ class CollectiveEpilogue { ///////////////////////////////////////////////////////////////////////////////////////////////// #include "detail.hpp" + +// +// Gemm +// #include "default_epilogue.hpp" #include "default_epilogue_array.hpp" #include "epilogue_tensor_broadcast.hpp" #include "sm70_epilogue_vectorized.hpp" +#include "sm70_epilogue_vectorized_array.hpp" #include "sm90_epilogue_tma_warpspecialized.hpp" #include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" #include "sm90_epilogue_array_tma_warpspecialized.hpp" +// +// Conv +// ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index b96b13fecc..a6e13bc7e9 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -199,6 +199,14 @@ struct IsThreadEpilogueOpWithActivation +struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {}; + +template +struct IsThreadEpilogueOpWithElementwiseArguments< + ThreadEpilogueOp, + cute::void_t> : cute::true_type {}; + // Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels template class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { @@ -430,7 +438,8 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { // Dummy methods to perform different parts of TMA/Tensormap modifications - template + template CUTLASS_DEVICE void tensormaps_perform_update( diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 69170f75ea..a8083dab1d 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -46,6 +46,25 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + class StrideC, + class StrideD, + class ThreadEpilogueOp, + class SmemLayout, + class CopyAtomR2S, + class TiledCopyS2R, + class CopyAtomR2G, + class EpilogueScheduleType = EpilogueSimtVectorized, + class Enable = void +> +class Epilogue { + static_assert(cute::is_same_v || + cute::is_same_v, + "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Vectorized /// Applies an element wise operation to all elements within the fragment /// and writes it out to destination storage. /// @@ -61,9 +80,22 @@ template < class SmemLayout_, class CopyAtomR2S_, class TiledCopyS2R_, - class CopyAtomR2G_ + class CopyAtomR2G_, + class EpilogueScheduleType_ > -class Epilogue { +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { public: // // Type Aliases @@ -78,15 +110,17 @@ class Epilogue { using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; - + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; using SmemLayout = SmemLayout_; using CopyAtomR2S = CopyAtomR2S_; using TiledCopyS2R = TiledCopyS2R_; using CopyAtomR2G = CopyAtomR2G_; - static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using GmemTiledCopyC = void; + using GmemTiledCopyD = CopyAtomR2G; - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + static constexpr bool IsEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + using StrideBias = cute::conditional_t(), Stride<_1,_0,int64_t>, Stride<_0,_1,int64_t>>; static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -96,9 +130,35 @@ class Epilogue { cute::array_aligned> smem_epilogue; }; + static constexpr bool IsActHasArgs = detail::IsThreadEpilogueOpWithElementwiseArguments::value; + // Host side epilogue arguments + template + struct ThreadEpilogueOpArguments { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + }; + + template + struct ThreadEpilogueOpArguments< + ThreadEpiOp, + cute::enable_if_t::value>> { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + }; + struct Arguments { - typename ThreadEpilogueOp::Params thread{}; + ThreadEpilogueOpArguments thread{}; + using StrideBias = decltype(thread.dBias); ElementC const* ptr_C = nullptr; StrideC dC{}; ElementD* ptr_D = nullptr; @@ -106,7 +166,32 @@ class Epilogue { }; // Device side epilogue params - using Params = Arguments; + template + struct ParamsType { + typename ThreadEpiOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; + + template + struct ParamsType< + ThreadEpiOp, + cute::enable_if_t::value>> { + typename ThreadEpiOp::Params thread{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; + + using Params = ParamsType; // // Methods @@ -117,8 +202,36 @@ class Epilogue { to_underlying_arguments( [[maybe_unused]] ProblemShape const& _, Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; + [[maybe_unused]] void* workspace) { + typename ThreadEpilogueOp::Params thread_op_args; + thread_op_args.alpha = args.thread.alpha; + thread_op_args.beta = args.thread.beta; + thread_op_args.alpha_ptr = args.thread.alpha_ptr; + thread_op_args.beta_ptr = args.thread.beta_ptr; + + if constexpr (IsActHasArgs) { + return { + thread_op_args, + args.thread.activation, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } + else { + return { + thread_op_args, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } } template @@ -169,8 +282,7 @@ class Epilogue { TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, - char* smem_buf) - { + char* smem_buf) { using namespace cute; using X = Underscore; @@ -192,88 +304,112 @@ class Epilogue { auto L = get<3>(problem_shape_mnkl); // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), params.dBias); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + // Construct a tensor in SMEM that we can partition for rearranging data SharedStorage& storage = *reinterpret_cast(smem_buf); - Tensor sC = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) - // Partition sC to match the accumulator partitioning + // Partition sAcc to match the accumulator partitioning auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); - auto tC = tiled_r2s.get_thread_slice(thread_idx); - Tensor tCaC = tC.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor tCsC = tC.partition_D(sC); // ((Atom,AtomNum),PIPE_M,PIPE_N) + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Tile gD and gC by the shape of SmemLayout first - auto tile = make_shape(size<0>(sC), size<1>(sC)); + auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gBiast = flat_divide(gBias, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - // Partition sC, gC, and gD for the output + // Partition sAcc, gC, and gD for the output auto tiled_s2r = TiledCopyS2R{}; - auto tD = tiled_s2r.get_thread_slice(thread_idx); - Tensor tDsC = tD.partition_S(sC); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tDgC = tD.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - Tensor tDgD = tD.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gBias = thread_s2r.partition_D(gBiast); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) // Allocate intermediate registers on the dst tensors - Tensor tDrC = make_tensor(take<0,3>(shape(tDgC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tDrD = make_tensor(shape(tDrC)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rC = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rBias = make_tensor_like(tSR_gBias); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) // Repeat the D-partitioning for coordinates and predication - Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor tDcD = tD.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0); // TILE_M divides MMA_M - CUTE_STATIC_ASSERT(size<2>(tCaC) % size<4>(tDgC) == 0); // TILE_N divides MMA_N - CUTE_STATIC_ASSERT(typename TiledCopyS2R::TiledNumThr{} == size<0>(typename TiledMma::AtomLayoutC_TV{})); + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N #if 0 if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { print("aC : "); print(accumulators.layout()); print("\n"); print("gC : "); print(gC.layout()); print("\n"); print("gD : "); print(gD.layout()); print("\n"); - print("sC : "); print(sC.layout()); print("\n"); + print("gBias : "); print(gBias.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); print("\n"); - print("tCsC : "); print(tCsC.layout()); print("\n"); - print("tCaC : "); print(tCaC.layout()); print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); print("\n"); print("gDt : "); print(gDt.layout()); print("\n"); - print("tDsC : "); print(tDsC.layout()); print("\n"); - print("tDrC : "); print(tDrC.layout()); print("\n"); + print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); print("\n"); - print("tDrD : "); print(tDrD.layout()); print("\n"); - print("tDgC : "); print(tDgC.layout()); print("\n"); - print("tDgD : "); print(tDgD.layout()); print("\n"); + print("tSR_rC : "); print(tSR_rC.layout()); print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); print("\n"); + print("gBiast : "); print(gBiast.layout()); print("\n"); + print("tSR_gBias : "); print(tSR_gBias.layout()); print("\n"); + print("tSR_rBias : "); print(tSR_rBias.layout()); print("\n"); } #endif + if constexpr (IsEpilogueBiasSupported) { + if (params.ptr_Bias) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + Tensor tSR_gBias_flt = filter_zeros(tSR_gBias); + Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); + Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride()); + + // Step 0. Copy Bias from GMEM to fragment + auto pred_fn = [&] (auto const&... coords) { return elem_less(tSR_cD_flt(coords...), take<0, 2>(residue_mnk)); }; + copy_if(pred_fn, tSR_gBias_flt, tSR_rBias_flt); + } + } + // For each tiling needed for SmemLayout to cover shape(gD) CUTLASS_PRAGMA_UNROLL - for (int step_m = 0; step_m < size<2>(cDt); ++step_m) - { + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { CUTLASS_PRAGMA_UNROLL - for (int step_n = 0; step_n < size<3>(cDt); ++step_n) - { + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { // Step 1. Copy to SMEM CUTLASS_PRAGMA_UNROLL - for (int pipe_m = 0; pipe_m < size<1>(tCsC); ++pipe_m) { + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { CUTLASS_PRAGMA_UNROLL - for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) { - int mma_m = step_m * size<1>(tCsC) + pipe_m; - int mma_n = step_n * size<2>(tCsC) + pipe_n; + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; - copy(tiled_r2s, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); } } @@ -281,59 +417,115 @@ class Epilogue { synchronize(); // Step 3. Copy from SMEM into a fragment - copy(tiled_s2r, tDsC, tDrC); + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); // Step 4. Wait for SMEM reads to complete synchronize(); - Tensor tDgDmn = tDgD(_,_,_,step_m,step_n); - Tensor tDcDmn = tDcD(_,_,_,step_m,step_n); + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if constexpr (IsEpilogueBiasSupported) { + Tensor tSR_rBiasmn = tSR_rBias(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i)); + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i)); + } + } + } - if (epilogue_op.is_source_needed()) { - // source is needed - Tensor tDgCmn = tDgC(_,_,_,step_m,step_n); CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tDgDmn); ++m) - { + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tDgDmn); ++n) - { + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { // Predication - if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && - get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) - { - // Step 5. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tDrC); ++i) { - tDrD(i,m,n) = epilogue_op(tDrC(i,m,n), tDgCmn(i,m,n)); + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } else { + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } } - // Step 6. Copy to GMEM - copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); } } + + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i), tSR_rC(i)); + } } - } - else { - // source is not needed, avoid load and lift compute + else { + // source is not needed, avoid load and lift compute - // Step 5. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tDrC); ++i) { - tDrD(i) = epilogue_op(tDrC(i)); + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } } CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tDgDmn); ++m) - { + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tDgDmn); ++n) - { + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { // Predication - if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && - get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) - { - // Step 6. Copy to GMEM - copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); } } } diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp new file mode 100644 index 0000000000..8a70370b21 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Ptr Array Epilogue Vectorized +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. +/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_, + class EpilogueScheduleType_ +> +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + using GmemTiledCopyC = TiledCopyS2R; + using GmemTiledCopyD = TiledCopyS2R; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + using TensorMapStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) + : params(params_) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. + return true; + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + stride_c = params.dC[l_coord]; + } + stride_d = params.dD[l_coord]; + } + else { + stride_c = params.dC; + stride_d = params.dD; + } + + // Represent the full output tensor + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sAcc to match the accumulator partitioning + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); + Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sAcc, gC, and gD for the output + auto tiled_s2r = TiledCopyS2R{}; + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); + print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); + print("\n"); + print("gDt : "); print(gDt.layout()); print("\n"); + print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); + print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); + print("\n"); + } +#endif + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; + + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); + + // Step 4. Wait for SMEM reads to complete + synchronize(); + + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + Tensor tSR_rCmn = make_tensor(shape(tSR_gCmn)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rCmn(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rD(i,m,n) = epilogue_op(tSR_rAcc(i,m,n), tSR_rCmn(i,m,n)); + } + // Step 7. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 87b6786721..56bdd84344 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -77,7 +77,8 @@ template < class CopyOpS2G_, class SmemLayoutAtomD_, class CopyOpR2S_, - class CopyAtomC_ + class CopyAtomC_, + class CopyOpR2R_ > class CollectiveEpilogue< Sm90PtrArrayTmaWarpSpecialized { public: // @@ -129,7 +131,7 @@ class CollectiveEpilogue< using SmemLayoutAtomD = SmemLayoutAtomD_; using CopyOpR2S = CopyOpR2S_; using CopyAtomC = CopyAtomC_; - + using CopyOpR2R = CopyOpR2R_; using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = CopyOpG2S; @@ -164,6 +166,9 @@ class CollectiveEpilogue< constexpr static bool is_im2col_C = cute::is_same_v; constexpr static bool is_im2col_D = cute::is_same_v; + // Check if register transformation is needed before copying register to shared memory. + constexpr static bool IsUseR2R = !cute::is_void_v; + using SmemLayoutC = decltype(tile_to_shape( SmemLayoutAtomC{}, make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), @@ -233,7 +238,7 @@ class CollectiveEpilogue< FusionStorage thread; } tensors; - struct TensorMapStorage : cute::aligned_struct<128> { + struct TensorMapStorage : cute::aligned_struct<128, _0> { cute::TmaDescriptor smem_tensormap_C; cute::array smem_tensormap_D; } tensormaps; @@ -265,7 +270,7 @@ class CollectiveEpilogue< take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{})); - + using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), @@ -333,7 +338,6 @@ class CollectiveEpilogue< take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{}); - } typename Params::TMA_D tma_store_d; @@ -369,16 +373,18 @@ class CollectiveEpilogue< template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); auto descriptors_shape = cute::make_shape(sm_count, Int{}); constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); } template static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); } @@ -408,10 +414,10 @@ class CollectiveEpilogue< constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); } - + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); } - } + } else { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); } @@ -507,9 +513,9 @@ class CollectiveEpilogue< auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; static_assert(!is_im2col_D, "Do not support im2col"); - + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); - + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); @@ -542,12 +548,8 @@ class CollectiveEpilogue< // Predication for TMA load (one thread issues TMA load) bool issue_tma_load = cute::elect_one_sync(); - // Acquire the lock for the first stage - load_pipeline.producer_acquire(load_pipe_producer_state); - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - // Pre-loop fusion callback entry point - pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); + pld_callbacks.begin(); LoadPipelineState prior_state = load_pipe_producer_state; @@ -560,9 +562,11 @@ class CollectiveEpilogue< if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { continue; } + // Acquire the lock for this stage constexpr uint16_t mcast_mask = 0; uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); // Loop fusion callback entry point @@ -589,7 +593,7 @@ class CollectiveEpilogue< pld_callbacks.end(); if (wait_until_load_finishes && did_load) { - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); } @@ -661,6 +665,7 @@ class CollectiveEpilogue< // Represent the full output tensor, slice to get the tile this CTA is responsible for Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) @@ -677,8 +682,27 @@ class CollectiveEpilogue< TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + TiledCopy tiled_r2s = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) @@ -733,6 +757,8 @@ class CollectiveEpilogue< CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); // Get the fusion callbacks for the consumer store warps constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ @@ -741,7 +767,7 @@ class CollectiveEpilogue< tile_coord_mnkl, tiled_mma, EpilogueTile{}, - tiled_r2s, + tiled_copy_partition_ref, cD, residue_cD, tRS_cD, @@ -774,7 +800,7 @@ class CollectiveEpilogue< // Sync requirements of smem reuse may preclude this optimization // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD int epi_m_prev = 0, epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock"); + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); // The TMA store sequence for one subtile iteration auto tma_store_fn = [&] (int epi_m, int epi_n) { @@ -886,6 +912,16 @@ class CollectiveEpilogue< cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output needs register shuffling before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + // Copy tile from register to smem if constexpr (is_destination_supported) { copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); @@ -905,6 +941,7 @@ class CollectiveEpilogue< } // for epi_m } // for epi_n + if constexpr (DelayTmaStore) { // Issue TMA stores for the last subtile tma_store_fn(epi_m_prev, epi_n_prev); @@ -991,6 +1028,7 @@ class CollectiveEpilogue< } __syncwarp(); return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); + } TmaDescriptor* null_tma_desc = nullptr; return cute::make_tuple(null_tma_desc); @@ -1065,7 +1103,7 @@ class CollectiveEpilogue< prob_shape, prob_stride); } - } + } else if constexpr (is_destination_supported) { ElementD const* ptr_D = nullptr; @@ -1096,8 +1134,8 @@ class CollectiveEpilogue< ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch, int32_t warp_group_idx) { - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); @@ -1106,6 +1144,7 @@ class CollectiveEpilogue< tensormaps_replace_global_tensor_properties( shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx); } + } } @@ -1117,6 +1156,7 @@ class CollectiveEpilogue< cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate, int32_t warp_group_idx = 0) { + // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { if constexpr (is_source_supported) { @@ -1136,7 +1176,7 @@ class CollectiveEpilogue< if constexpr (not cute::is_void_v) { cute::tma_descriptor_fence_acquire(tensormap); } - } + } else { cute::tma_descriptor_fence_acquire(tensormap); } diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 6d173bfe23..b2fa4e3573 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -75,7 +75,8 @@ template < class CopyOpS2G_, class SmemLayoutAtomD_, class CopyOpR2S_, - class CopyAtomC_ + class CopyAtomC_, + class CopyOpR2R_ > class CollectiveEpilogue< Sm90TmaWarpSpecialized, @@ -92,7 +93,8 @@ class CollectiveEpilogue< CopyOpS2G_, SmemLayoutAtomD_, CopyOpR2S_, - CopyAtomC_ + CopyAtomC_, + CopyOpR2R_, > { public: // @@ -113,6 +115,7 @@ class CollectiveEpilogue< using SmemLayoutAtomD = SmemLayoutAtomD_; using CopyOpR2S = CopyOpR2S_; using CopyAtomC = CopyAtomC_; + using CopyOpR2R = CopyOpR2R_; using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = CopyOpG2S; @@ -147,6 +150,9 @@ class CollectiveEpilogue< constexpr static bool is_im2col_C = cute::is_same_v; constexpr static bool is_im2col_D = cute::is_same_v; + // Check if register transformation is needed before copying register to shared memory. + constexpr static bool IsUseR2R = !cute::is_void_v; + using SmemLayoutC = decltype(tile_to_shape( SmemLayoutAtomC{}, make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), @@ -454,12 +460,8 @@ class CollectiveEpilogue< // Predication for TMA load (one thread issues TMA load) bool issue_tma_load = cute::elect_one_sync(); - // Acquire the lock for the first stage - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - load_pipeline.producer_acquire(load_pipe_producer_state); - // Pre-loop fusion callback entry point - pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); + pld_callbacks.begin(); CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { @@ -568,8 +570,27 @@ class CollectiveEpilogue< TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + TiledCopy tiled_r2s = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) @@ -581,7 +602,7 @@ class CollectiveEpilogue< // Allocate D registers Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); - Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) // Vectorized fragment view constexpr int FragmentSize = DispatchPolicy::FragmentSize; @@ -624,15 +645,17 @@ class CollectiveEpilogue< CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); // Get the fusion callbacks for the consumer store warps - constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + constexpr bool RefSrc = true; // Register tensors reference tiled copy src layout auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( problem_shape_mnkl, CtaTileMNK{}, tile_coord_mnkl, tiled_mma, EpilogueTile{}, - tiled_r2s, + tiled_copy_partition_ref, cD, residue_cD, tRS_cD, @@ -647,7 +670,7 @@ class CollectiveEpilogue< using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); constexpr bool IsDirectR2S = cute::is_same_v>; using RegisterElementD = cute::conditional_t; - Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) Tensor tRS_rCompute_frg = recast>(tRS_rCompute); // Thread synchronizer for previously issued waits or fences @@ -672,7 +695,7 @@ class CollectiveEpilogue< // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD [[maybe_unused]] int epi_m_prev = 0; [[maybe_unused]] int epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock"); + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); // The TMA store sequence for one subtile iteration auto tma_store_fn = [&] (int epi_m, int epi_n) { @@ -784,6 +807,16 @@ class CollectiveEpilogue< cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output register transformation before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tRS_rD_frg); ++i) { tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp index b67c229c27..9749040081 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -62,7 +62,8 @@ template < class CopyOpS2G_, class SmemLayoutAtomD_, class CopyOpR2S_, - class CopyAtomC_ + class CopyAtomC_, + class CopyOpR2R_ > class Sm90EpilogueTmaWarpSpecializedBiasElementwise : public CollectiveEpilogue< @@ -80,7 +81,8 @@ class Sm90EpilogueTmaWarpSpecializedBiasElementwise CopyOpS2G_, SmemLayoutAtomD_, CopyOpR2S_, - CopyAtomC_ + CopyAtomC_, + CopyOpR2R_ > { private: using Impl = @@ -99,7 +101,8 @@ class Sm90EpilogueTmaWarpSpecializedBiasElementwise CopyOpS2G_, SmemLayoutAtomD_, CopyOpR2S_, - CopyAtomC_ + CopyAtomC_, + CopyOpR2R_ >; public: using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index e96f413445..f829a2ff5d 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -46,12 +46,13 @@ namespace cutlass::epilogue { ////////////////////////////////////////////////////////////////////////////// struct PtrArrayDefault {}; +struct EpilogueSimtVectorized {}; +struct EpiloguePtrArraySimtVectorized {}; struct NoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecialized {}; struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; - struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; }; diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 0bfacf34cc..3aed32710f 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -33,6 +33,7 @@ #include #include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -123,6 +124,19 @@ struct LinCombEltAct static constexpr bool IsEltActSupported = true; }; +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombTopKSoftmaxCol + : LinearCombination { +}; + // D = alpha * acc + beta * C + per-row bias template< @@ -131,7 +145,7 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombPerRowBias @@ -141,39 +155,39 @@ struct LinCombPerRowBias static constexpr bool IsPerRowBiasSupported = true; }; -// D = activation(alpha * acc + beta * C + per-row bias) +// D = alpha * acc + beta * C + per-column bias template< - template class ActivationFn_, class ElementOutput_, class ElementCompute_, class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > -struct LinCombPerRowBiasEltAct - : LinCombPerRowBias { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; +struct LinCombPerColBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; }; -// D = alpha * acc + beta * C + per-column bias +// D = activation(alpha * acc + beta * C + per-row bias) template< + template class ActivationFn_, class ElementOutput_, class ElementCompute_, class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > -struct LinCombPerColBias - : LinearCombination { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsPerColBiasSupported = true; +struct LinCombPerRowBiasEltAct + : LinCombPerRowBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; }; // D = activation(alpha * acc + beta * C + per-row bias) @@ -187,8 +201,8 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombPerRowBiasEltActAux @@ -208,8 +222,8 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / sizeof_bits_v, - int AlignmentScalar_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct PerRowLinCombPerRowBiasEltAct @@ -231,7 +245,7 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct ScaledLinCombPerRowBiasEltAct @@ -261,8 +275,8 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct ScaledLinCombPerRowBiasEltActAmaxAux @@ -288,7 +302,7 @@ template< class ElementAux_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombDeEltAct @@ -315,8 +329,8 @@ template< class ElementBias_ = ElementCompute_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombDeEltActDePerRowBias diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index ece5ac542e..e028846a4f 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -46,6 +46,8 @@ #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::epilogue::fusion { @@ -75,12 +77,12 @@ struct FusionCallbacks< CtaTileShapeMNK, EpilogueTile > : Sm90EVT, - Sm90ScalarBroadcast, + Sm90ScalarBroadcast>, Sm90AccFetch > { using Impl = Sm90EVT, - Sm90ScalarBroadcast, + Sm90ScalarBroadcast>, Sm90AccFetch >; using Operation = fusion::ScaledAcc; @@ -92,12 +94,15 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + // Conversion to the args expected by the visitor implementation // to_underlying_arguments will implicitly call this operator typename Impl::Arguments() const { return { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }; // end binary op @@ -120,10 +125,10 @@ template< > using Sm90LinearCombination = Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcast, // beta + Sm90ScalarBroadcast>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc - Sm90ScalarBroadcast, // alpha + Sm90ScalarBroadcast>, // alpha Sm90AccFetch // acc > >; @@ -158,13 +163,18 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + operator typename Impl::Arguments() const { return { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -189,10 +199,10 @@ template< > using Sm90LinearCombinationPtrArray = Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcastPtrArray>, // beta + Sm90ScalarBroadcastPtrArray>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc - Sm90ScalarBroadcastPtrArray>, // alpha + Sm90ScalarBroadcastPtrArray>, // alpha Sm90AccFetch // acc > >; @@ -236,8 +246,8 @@ struct FusionCallbacks< ElementScalar const* const* alpha_ptr_array = nullptr; ElementScalar const* const* beta_ptr_array = nullptr; - using StrideAlpha = Stride<_0,_0,int>; - using StrideBeta = Stride<_0,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; StrideAlpha dAlpha = {_0{}, _0{}, 0}; StrideBeta dBeta = {_0{}, _0{}, 0}; @@ -307,6 +317,11 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + using ActivationArguments = typename Sm90Compute::Arguments; ActivationArguments activation = ActivationArguments(); @@ -314,10 +329,96 @@ struct FusionCallbacks< return { // unary op: activation(beta * C + (alpha * acc)) { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C), where beta and alpha can be vectors for each batch +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltActPtrArray = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltActPtrArray { + + using Impl = Sm90LinCombEltActPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -347,12 +448,12 @@ template< > using Sm90LinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, // beta + Sm90ScalarBroadcast>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, // alpha + Sm90ScalarBroadcast>, // alpha Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias > >; @@ -390,17 +491,22 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; operator typename Impl::Arguments() const { return { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -431,12 +537,12 @@ template< > using Sm90LinCombPerColBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, // beta + Sm90ScalarBroadcast>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, // alpha + Sm90ScalarBroadcast>, // alpha Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_0,_1,int>, AlignmentBias> // bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias > >; @@ -474,17 +580,22 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_0,_1,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; operator typename Impl::Arguments() const { return { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -560,7 +671,12 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -571,10 +687,10 @@ struct FusionCallbacks< return { // unary op : activation(beta * C + (alpha * acc + bias)) { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -673,7 +789,12 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -689,10 +810,10 @@ struct FusionCallbacks< { // unary op : activation(store(beta * C + (alpha * acc + bias))) { // unary op : store(beta * C + (alpha * acc + bias)) { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -725,12 +846,12 @@ template< > using Sm90PerRowLinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, AlignmentScalar>, // beta + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // beta, dynamic scalar/vector broadcast Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, AlignmentScalar>, // alpha + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias > >; @@ -792,16 +913,16 @@ struct FusionCallbacks< >; struct Arguments { - using StrideAlpha = Stride<_1,_0,int>; - using StrideBeta = Stride<_1,_0,int>; + using StrideAlpha = Stride; + using StrideBeta = Stride; ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - StrideAlpha dAlpha = {}; - StrideBeta dBeta = {}; + StrideAlpha dAlpha = {bool(1), _0{}, 0}; + StrideBeta dBeta = {bool(1), _0{}, 0}; - using StrideBias = Stride<_1,_0,int>; + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -864,12 +985,12 @@ template< > using Sm90ScaledLinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90ScalarBroadcast, 2>, // scale_c * beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias > >; @@ -950,7 +1071,12 @@ struct FusionCallbacks< ElementScalar const* scale_c_ptr = nullptr; ElementScalar const* scale_d_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -962,13 +1088,15 @@ struct FusionCallbacks< { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{scale_c, beta}, - {scale_c_ptr, beta_ptr} + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} }, // leaf args : (scale_c * beta) {}, // leaf args : C { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{scale_a, scale_b, alpha}, - {scale_a_ptr, scale_b_ptr, alpha_ptr} + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias @@ -1184,7 +1312,12 @@ struct FusionCallbacks< ElementScalar scale_aux = ElementScalar(1); ElementScalar const* scale_aux_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -1213,13 +1346,15 @@ struct FusionCallbacks< Z_args = { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{scale_c, beta}, - {scale_c_ptr, beta_ptr} + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} }, // leaf args : (scale_c * beta) {}, // leaf args : C { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{scale_a, scale_b, alpha}, - {scale_a_ptr, scale_b_ptr, alpha_ptr} + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias @@ -1269,13 +1404,15 @@ struct FusionCallbacks< { // unary op : activation(Z) { // unary op : store(Z) { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{scale_c, beta}, - {scale_c_ptr, beta_ptr} + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} }, // leaf args : (scale_c * beta) {}, // leaf args : C { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{scale_a, scale_b, alpha}, - {scale_a_ptr, scale_b_ptr, alpha_ptr} + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias @@ -1377,6 +1514,11 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + using ActivationArguments = typename Sm90Compute::Arguments; ActivationArguments activation = ActivationArguments(); @@ -1388,10 +1530,10 @@ struct FusionCallbacks< return { // binary op : activation(beta * C + (alpha * acc), aux) { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -1430,7 +1572,7 @@ template< using Sm90LinCombDeEltActDePerRowBias = Sm90EVT, // Identity for final conversion Sm90EVT, AlignmentBias>, + ElementBias, ElementCompute, RoundStyle, Stride<_1,_0,int64_t>, AlignmentBias>, Sm90LinCombDeEltAct > @@ -1490,6 +1632,11 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + using ActivationArguments = typename Sm90Compute::Arguments; ActivationArguments activation = ActivationArguments(); @@ -1497,7 +1644,7 @@ struct FusionCallbacks< ElementAux const* aux_ptr = nullptr; StrideAux dAux = {}; - using StrideBias = Stride<_1,_0,int>; + using StrideBias = Stride<_1,_0,int64_t>; ElementBias* dbias_ptr = nullptr; StrideBias dDbias = {}; @@ -1507,10 +1654,10 @@ struct FusionCallbacks< { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) { // binary op : activation(beta * C + (alpha * acc), aux) { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -1532,6 +1679,78 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombTopKSoftmaxCol = + Sm90EVT, // softmax(top_k(beta * C + (alpha * acc))) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int TopK, + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombTopKSoftmaxCol, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombTopKSoftmaxCol { + + using Impl = Sm90LinCombTopKSoftmaxCol::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombTopKSoftmaxCol; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace detail { template > struct get_element_aux { diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 2ae10a688a..131d0ba5b9 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -263,8 +263,16 @@ struct Sm90TreeVisitor< CUTLASS_DEVICE bool is_producer_load_needed() const { + auto const& scale_op = get<0>(Impl::ops); auto const& added_op = get<2>(Impl::ops); - return is_C_load_needed() || added_op.is_producer_load_needed(); + if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { + return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || + is_C_load_needed() || + added_op.is_producer_load_needed(); + } + else { + return is_C_load_needed() || added_op.is_producer_load_needed(); + } } CUTLASS_DEVICE bool @@ -296,7 +304,7 @@ struct Sm90TreeVisitor< Array frg_I = convert_Z(frg_added); - if (is_C_load_needed) { + if constexpr (!is_void_v) { Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); @@ -323,8 +331,12 @@ struct Sm90TreeVisitor< CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); + bool is_C_load_needed = this->is_C_load_needed(); + if (not is_C_load_needed) { + cute::clear(args.tCrC); + } return ConsumerStoreCallbacks( - is_C_load_needed(), std::move(callbacks_tuple)); + is_C_load_needed, std::move(callbacks_tuple)); } }; @@ -497,7 +509,18 @@ struct Sm90TreeVisitor< else { frg_compute[i] = relu(frg_compute[i]); } - frg_aux[i] = frg_compute[i] == pre_relu; + if constexpr (cute::is_same_v) { + uint32_t aux; + asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else if constexpr (cute::is_same_v) { + uint32_t aux; + cutlass::half_t compute = frg_compute[i]; + asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else { + frg_aux[i] = frg_compute[i] == pre_relu; + } } static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index aedacb552e..a22bed4e0d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -378,6 +378,174 @@ struct Sm90AuxLoad { } }; +template < + class Element, + class EpilogueTile, // Unused + class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpS2R, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + 0, EpilogueTile, Element, LayoutOrStrideMNL, + SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorG2R, + class RTensor, + class CTensorG2R, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensorG2R&& tC_gAux, + RTensor&& tC_rAux, + CTensorG2R&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorG2R tC_gAux; + RTensor tC_rAux; + CTensorG2R tC_cAux; + ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + auto pred_fn = [&] (auto const&... coords) { + return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn, tC_gAux_vec, tC_rAux_vec); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tC_rAux)(epi_v); + } + }; + + template < + bool ReferenceSrc, + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Broadcast Load Operations @@ -388,11 +556,12 @@ struct Sm90AuxLoad { // Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors template< class Element, - class StrideMNL = Stride<_0,_0,_0>, + class StrideMNL_ = Stride<_0,_0,_0>, int BroadcastCount = 1, template class ReductionFn = multiplies > struct Sm90ScalarBroadcast { + using StrideMNL = StrideMNL_; static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); @@ -401,7 +570,7 @@ struct Sm90ScalarBroadcast { struct Arguments { Element scalars[BroadcastCount] = {}; Element const* scalar_ptrs[BroadcastCount] = {}; - StrideMNL dScalar = {}; + StrideMNL dScalar[BroadcastCount] = {}; }; using Params = Arguments; @@ -444,7 +613,21 @@ struct Sm90ScalarBroadcast { // This must be called after update_scalar is called CUTLASS_DEVICE bool is_zero() const { - return scalar == Element(0); + if (get<2>(params_ptr->dScalar[0]) == 0) { + // Only 1 batch + return scalar == Element(0); + } + else { + // multiple batch + if (valid_scalar == false) { + // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. + return params_ptr->scalar_ptrs[0] == nullptr; + } + else { + // Check whether each batch is ZERO or not. + return scalar == Element(0); + } + } } CUTLASS_HOST_DEVICE @@ -454,19 +637,20 @@ struct Sm90ScalarBroadcast { Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { // Get the scalar for non-batched broadcast - if (get<2>(params_ptr->dScalar) == 0) { + if (size<2>(params_ptr->dScalar[0]) == 0) { update_scalar(); } } Element scalar; + bool valid_scalar = false; Params const* params_ptr; template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar) != 0) { + if (size<2>(params_ptr->dScalar[0]) != 0) { auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -500,7 +684,7 @@ struct Sm90ScalarBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar) != 0) { + if (get<2>(params_ptr->dScalar[0]) != 0) { auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -511,11 +695,12 @@ struct Sm90ScalarBroadcast { private: CUTLASS_DEVICE void update_scalar(int l_coord = 0) { - int l_offset = l_coord * size<2>(params_ptr->dScalar); + valid_scalar = true; + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); if (params_ptr->scalar_ptrs[0] != nullptr) { scalar = params_ptr->scalar_ptrs[0][l_offset]; - } + } else { // batch stride is ignored for nullptr fallback scalar = params_ptr->scalars[0]; @@ -526,8 +711,10 @@ struct Sm90ScalarBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 1; i < BroadcastCount; ++i) { if (params_ptr->scalar_ptrs[i] != nullptr) { - scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][l_offset]); - } else { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { // batch stride is ignored for nullptr fallback scalar = reduction_fn(scalar, params_ptr->scalars[i]); } @@ -538,8 +725,8 @@ struct Sm90ScalarBroadcast { CUTLASS_DEVICE void update_scalar(cute::tuple) { // Only support multiple L-modes with fully-broadcast scalar - static_assert(cute::is_same_v>); scalar = params_ptr->scalars[0]; + valid_scalar = true; } }; @@ -706,6 +893,7 @@ struct Sm90ScalarBroadcastPtrArray { } }; + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { @@ -722,32 +910,40 @@ compute_row_broadcast_stages() { template< int Stages, class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v, + class ElementInput, + class ElementCompute = ElementInput, + class StrideMNL_ = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct Sm90RowBroadcast { - static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + using StrideMNL = StrideMNL_; + static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining"); + + static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // row vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); struct SharedStorage { - array_aligned(CtaTileShapeMNK{})> smem; + array_aligned(CtaTileShapeMNK{})> smem; }; struct Arguments { - Element const* ptr_row = nullptr; - Element null_default = Element(0); + ElementInput const* ptr_row = nullptr; + ElementInput null_default = ElementInput(0); StrideMNL dRow = {}; }; - using Params = Arguments; + struct Params { + ElementInput const* ptr_row = nullptr; + ElementCompute null_default = ElementCompute(0); + StrideMNL dRow = {}; + }; template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; + return {args.ptr_row, ElementCompute(args.null_default), args.dRow}; } template @@ -774,11 +970,22 @@ struct Sm90RowBroadcast { CUTLASS_HOST_DEVICE Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params) - , smem(const_cast(shared_storage.smem.data())) { } + : params(params), is_zero_(false), + smem(const_cast(shared_storage.smem.data())) { + auto const& [stride_M, stride_N, stride_L] = params.dRow; + // Nullptr default + if (EnableNullptr && params.ptr_row == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) { + is_zero_ = params.ptr_row[0] == ElementInput(0); + } + } Params params; - Element *smem = nullptr; + bool is_zero_ = false; + ElementInput *smem = nullptr; CUTLASS_DEVICE bool is_producer_load_needed() const { @@ -792,7 +999,7 @@ struct Sm90RowBroadcast { CUTLASS_DEVICE bool is_zero() const { - return (params.ptr_row == nullptr && params.null_default == Element(0)); + return is_zero_; } template @@ -801,24 +1008,27 @@ struct Sm90RowBroadcast { return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE ConsumerStoreCallbacks( GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, - CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + Residue residue_cRow_, ThrNum thr_num_, Params const& params_) : tGS_gRow(tGS_gRow_) , tGS_sRow(tGS_sRow_) , tGS_cRow(tGS_cRow_) , tiled_G2S(tiled_g2s_) , tSR_sRow(tSR_sRow_) , tSR_rRow(tSR_rRow_) - , tCcRow(tCcRow_) - , residue_tCcRow(residue_tCcRow_) + , residue_cRow(residue_cRow_) , params(params_) - , is_nullptr(EnableNullptr && params_.ptr_row == nullptr) {} + , is_nullptr(EnableNullptr && params_.ptr_row == nullptr) { + if (is_nullptr) { + fill(tSR_rRow, params.null_default); + } + } GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) @@ -828,35 +1038,31 @@ struct Sm90RowBroadcast { SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tCcRow; // (m, n) + Residue residue_cRow; // (m, n) ThrNum thr_num; Params const& params; bool is_nullptr; CUTLASS_DEVICE void begin() { - if constexpr (EnableNullptr) { - if (params.ptr_row == nullptr) { - fill(tSR_rRow, params.null_default); - return; - } + if (is_nullptr) { + return; } auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + Tensor tGS_cRow_flt = filter_zeros(tGS_cRow, tGS_gRow.stride()); for (int i = 0; i < size(tGS_gRow_flt); ++i) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { continue; // OOB of SMEM, } - if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + if (elem_less(tGS_cRow_flt(i), residue_cRow)) { tGS_sRow_flt(i) = tGS_gRow_flt(i); } else { - tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + tGS_sRow_flt(i) = ElementInput(0); // Set to Zero when OOB so LDS can be issued without any preds. } } synchronize(); @@ -864,18 +1070,28 @@ struct Sm90RowBroadcast { CUTLASS_DEVICE void begin_loop(int epi_m, int epi_n) { - if (epi_m == 0) { // Assumes M-major subtile loop - if (is_nullptr) return; // Do not issue LDS when bias is nullptr + if (epi_m == 0 and not is_nullptr) { // Assumes M-major subtile loop Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); - copy(tSR_sRow_flt, tSR_rRow_flt); + Tensor tSR_rRow_flt = make_tensor_like(tSR_sRow_flt); + copy_aligned(tSR_sRow_flt, tSR_rRow_flt); + + constexpr int FrgSize = size(tSR_rRow_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tSR_rRow_input_frg = recast(coalesce(tSR_rRow_flt)); + Tensor tSR_rRow_compute_frg = recast(filter(tSR_rRow)); + ConvertInput convert_input{}; + + tSR_rRow_compute_frg(_0{}) = convert_input(tSR_rRow_input_frg(_0{})); } } template - CUTLASS_DEVICE Array + CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_row; + Array frg_row; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { @@ -896,12 +1112,30 @@ struct Sm90RowBroadcast { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + auto layout_N = [&] () { + auto shape_N = get<1>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_N = repeat_like(shape_N, int(0)); + if (get<1>(params.dRow) == bool(1)) { + stride_N = transform_leaf(compact_major(shape_N), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_N, stride_N); + } + else { + return make_layout(shape_N); + } + }(); + + auto layout_M = make_layout(M, repeat_like(M, _0{})); + auto layout_L = make_layout(L, get<2>(params.dRow)); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L)); Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) Tensor sRow = make_tensor(make_smem_ptr(smem), make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) //// G2S: Gmem to Smem - auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, Layout< Shape<_1, ThreadCount>, Stride<_0, _1>>{}, Layout<_1>{}); @@ -910,20 +1144,18 @@ struct Sm90RowBroadcast { Tensor tGS_sRow = thr_g2s.partition_D(sRow); //// G2S: Coord - auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); - Tensor tGS_cRow = thr_g2s.partition_S(cRow); + Tensor tGS_cRow = thr_g2s.partition_S(args.cD); //// S2R: Smem to Reg Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) - return ConsumerStoreCallbacks( + return ConsumerStoreCallbacks( tGS_gRow, tGS_sRow, tGS_cRow, tiled_g2s, tSR_sRow, tSR_rRow, - args.tCcD, args.residue_cD, ThreadCount{}, params); @@ -936,31 +1168,39 @@ struct Sm90RowBroadcast { template< int Stages, class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v, + class ElementInput, + class ElementCompute = ElementInput, + class StrideMNL_ = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct Sm90ColBroadcast { - static_assert(Stages == 0, "Column broadcast doesn't support smem usage"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); + using StrideMNL = StrideMNL_; + static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining"); + + static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // Column vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{} || IsDynamicBroadcast); // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem struct SharedStorage { }; struct Arguments { - Element const* ptr_col = nullptr; - Element null_default = Element(0); + ElementInput const* ptr_col = nullptr; + ElementInput null_default = ElementInput(0); StrideMNL dCol = {}; }; - using Params = Arguments; + struct Params { + ElementInput const* ptr_col = nullptr; + ElementCompute null_default = ElementCompute(0); + StrideMNL dCol = {}; + }; template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; + return {args.ptr_col, ElementCompute(args.null_default), args.dCol}; } template @@ -994,7 +1234,7 @@ struct Sm90ColBroadcast { CUTLASS_DEVICE bool is_zero() const { - return (params.ptr_col == nullptr && params.null_default == Element(0)); + return is_zero_; } CUTLASS_HOST_DEVICE @@ -1002,9 +1242,20 @@ struct Sm90ColBroadcast { CUTLASS_HOST_DEVICE Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params) { } + : params(params), is_zero_(false) { + auto const& [stride_M, stride_N, stride_L] = params.dCol; + // Nullptr default + if (EnableNullptr && params.ptr_col == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) { + is_zero_ = params.ptr_col[0] == ElementInput(0); + } + } Params params; + bool is_zero_; template CUTLASS_DEVICE auto @@ -1015,12 +1266,16 @@ struct Sm90ColBroadcast { template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, CTensor tCcCol, ThrResidue residue_tCcCol, Params const& params) - : tCgCol(cute::forward(tCgCol)), - tCrCol(cute::forward(tCrCol)), - tCcCol(tCcCol), - residue_tCcCol(residue_tCcCol), - params(params) {} + ConsumerStoreCallbacks(GTensor tCgCol_, RTensor tCrCol_, CTensor tCcCol_, ThrResidue residue_tCcCol_, Params const& params_) + : tCgCol(tCgCol_), + tCrCol(tCrCol_), + tCcCol(tCcCol_), + residue_tCcCol(residue_tCcCol_), + params(params_) { + if (EnableNullptr && params.ptr_col == nullptr) { + fill(tCrCol, params.null_default); + } + } GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) @@ -1030,23 +1285,20 @@ struct Sm90ColBroadcast { CUTLASS_DEVICE void begin() { - if constexpr (EnableNullptr) { - if (params.ptr_col == nullptr) { - fill(tCrCol, params.null_default); - return; - } + if (EnableNullptr && params.ptr_col == nullptr) { + return; } // Filter so we don't issue redundant copies over stride-0 modes // (only works if 0-strides are in same location, which is by construction) Tensor tCgCol_flt = filter_zeros(tCgCol); - Tensor tCrCol_flt = filter_zeros(tCrCol); - Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); + Tensor tCrCol_flt = make_tensor_like(filter_zeros(tCrCol)); + Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride()); constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){}; constexpr int V = cute::min(Alignment, size(MCL)); if constexpr (V > 1) { - using VecType = uint_bit_t>; + using VecType = uint_bit_t>; Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); @@ -1057,12 +1309,23 @@ struct Sm90ColBroadcast { auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_flt(coords...), residue_tCcCol); }; copy_if(pred_fn, tCgCol_flt, tCrCol_flt); } + + constexpr int FrgSize = size(tCrCol_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tCrCol_input_frg = recast(coalesce(tCrCol_flt)); + Tensor tCrCol_compute_frg = recast(filter(tCrCol)); + ConvertInput convert_input{}; + + tCrCol_compute_frg(_0{}) = convert_input(tCrCol_input_frg(_0{})); } template - CUTLASS_DEVICE Array + CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_col; + Array frg_col; Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL @@ -1083,13 +1346,34 @@ struct Sm90ColBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + auto layout_M = [&] () { + auto shape_M = get<0>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_M = repeat_like(shape_M, int(0)); + if (get<0>(params.dCol) == bool(1)) { + stride_M = transform_leaf(compact_major(shape_M), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_M, stride_M); + } + else { + return make_layout(shape_M); + } + }(); + + auto layout_N = make_layout(N, repeat_like(N, _0{})); + auto layout_L = make_layout(L, get<2>(params.dCol)); + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L)); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - return ConsumerStoreCallbacks( - cute::move(tCgCol), cute::move(tCrCol), args.tCcD, args.residue_tCcD, params); + Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L)); + Tensor tCgCol_static = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks(tCgCol, tCrCol, args.tCcD, args.residue_tCcD, params); } }; @@ -1110,6 +1394,20 @@ template < using Sm90MatrixBroadcast = Sm90AuxLoad; +namespace detail { + +template +struct IsScalarBroadcast { + static constexpr bool value = false; +}; + +template +struct IsScalarBroadcast(typename Operation::StrideMNL{})), Stride<_0,_0>>>> { + static constexpr bool value = true; +}; + +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index ae7b42b2bd..060f8d1594 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -286,6 +286,185 @@ struct Sm90AuxStore { } }; +template < + class Element, + class EpilogueTile, // Unused + FloatRoundStyle RoundStyle, + class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpR2S, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxStore< + 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, + SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorR2G, + class RTensor, + class CTensorR2G, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensorR2G&& tC_gAux, + RTensor&& tC_rAux, + CTensorR2G&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorR2G tC_gAux; + RTensor tC_rAux; + CTensorR2G tC_cAux; + ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + return; + } + } + + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + auto pred_fn = [&] (auto const&... coords) { + return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn, tC_rAux_vec, tC_gAux_vec); + } + }; + + template < + bool ReferenceSrc, + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + + } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Reduction Store Operations @@ -304,10 +483,8 @@ template < > struct Sm90ScalarReduction { private: - static_assert( - (cute::is_same_v>) || // scalar reduction, e.g. tensor max element - (cute::is_same_v>) || // batched scalar reduction, e.g. per-batch max element - (cute::is_same_v>)); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); static constexpr bool IsAtomic = is_atomic>::value; static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); @@ -344,6 +521,7 @@ struct Sm90ScalarReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + #if !defined(CUTLASS_SKIP_REDUCTION_INIT) if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); @@ -351,6 +529,7 @@ struct Sm90ScalarReduction { return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); } } + #endif return cutlass::Status::kSuccess; } @@ -480,15 +659,18 @@ template < // tensor of ElementCompute. It is the user's responsibility to reduce this to a (N, L) tensor of ElementOutput bool FinalReduction = true, // False means skip OOB predication if OOB inputs are known to be the reduction identity - bool VisitCheckOOB = true + bool VisitCheckOOB = true, + // Indicate the parameter order when calling RegReduceFn + // Seq length equals the number of RegReduceFn parameters + // No.0 represents tCrRow; No.1 and subsequent numbers sequentially represent frg_inputs in `visit` + class RegReduceSeq = cute::seq<0, 1> > struct Sm90RowReduction { private: static_assert(Stages == 0, "Smem usage not supported yet"); static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector reduction, e.g. per-col sum over all batches - (cute::is_same_v>)); // batched row vector reduction, e.g. per-col sum per batch + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); static constexpr bool IsAtomic = is_atomic>::value; static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); @@ -567,6 +749,7 @@ struct Sm90RowReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { +#if !defined(CUTLASS_SKIP_REDUCTION_INIT) if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); @@ -575,7 +758,9 @@ struct Sm90RowReduction { } return Status::kSuccess; } - else if constexpr (FinalReduction) { + else +#endif + if constexpr (FinalReduction) { auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); @@ -626,14 +811,13 @@ struct Sm90RowReduction { Params const& params; bool do_final_reduction = false; - - template + template CUTLASS_DEVICE auto visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { + Array const&... frg_inputs) { if constexpr (EnableNullptr) { if (params.ptr_row == nullptr) { - return frg_input; + return cute::get<0>(cute::make_tuple(frg_inputs...)); } } @@ -643,21 +827,50 @@ struct Sm90RowReduction { Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); - using ConvertInput = NumericArrayConverter; - using ReduceInput = RegReduceFn; - ConvertInput convert_input{}; - ReduceInput reduce_input{}; + if constexpr (VisitCheckOOB) { + using ReduceInput = RegReduceFn; + ReduceInput reduce_input{}; - Array frg_I = convert_input(frg_input); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (!VisitCheckOOB || elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { - ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); - tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { + ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); + tCrRow_vmn = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + return ElementCompute(frg_input[i]); + }, + [&] (auto&&... cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } } } + else { + constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); + using ReduceInput = RegReduceFn>; + ReduceInput reduce_input{}; + Tensor tCrRow_mn_frg = recast>(tCrRow_mn); - return frg_input; + constexpr int RegFragArraySize = FragmentSize / RegFragSize; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < RegFragArraySize; ++i) { + Array& tCrRow_vmn_frg = tCrRow_mn_frg(epi_v * RegFragArraySize + i); + tCrRow_vmn_frg = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + using RegFragArr = Array, RegFragArraySize>; + ConvertInput convert_input{}; + return convert_input(reinterpret_cast(frg_input)[i]); + }, + [&] (auto&&... cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn_frg, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } + } + return cute::get<0>(cute::make_tuple(frg_inputs...)); } template @@ -683,23 +896,70 @@ struct Sm90RowReduction { return; } + int lane_m = get<0>(lane_mn); + [[maybe_unused]] bool is_reduced_lane = lane_m == 0; + // // 1. Warp shuffle reduction // using FragmentShuffle = Array; + Tensor tCrRow_frg = recast(filter(tCrRow)); using ReduceShuffle = ShuffleReduceFn; ReduceShuffle reduce_shuffle{}; - Tensor tCrRow_frg = recast(filter(tCrRow)); - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { + + auto FrgSizePerLaneM = size(tCrRow_frg) / size<0>(lane_layout_MN); + constexpr bool SwapShuffle = FrgSizePerLaneM > 0; + + // + // Swap Shuffle + // + // The normal way to reduction among threads: + // use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads. + // After each step of reduction, a half of threads won't work in the following steps. + // That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case). + // + // To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors, + // we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads. + // After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step. + // We can recursively do this until the problem size is 1. + // + if constexpr (SwapShuffle) { // for a NxN matrix to be reduced among N threads as a 1XN vectors + Tensor tCrRow_frg_ = logical_divide(tCrRow_frg, FrgSizePerLaneM); // (FrgSizePerLaneM, M) + CUTLASS_PRAGMA_UNROLL + for (int m = size<1>(tCrRow_frg_) / 2; m > 0; m /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int r = 0; r < m; ++r) { + auto frg_A = tCrRow_frg_(_,r); + auto frg_B = tCrRow_frg_(_,r + m); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size(frg_A); ++v) { + // Step1: swap + if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second + swap(frg_A(v), frg_B(v)); + } + + // Step2: shuffle + uint64_t frg_shfl = reinterpret_cast(frg_A(v)); + // each half of threads get a half of data from the other half of threads + frg_shfl = __shfl_xor_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(m, _0{})); + + // Step3: reduction + frg_A(v) = reduce_shuffle(frg_B(v), reinterpret_cast(frg_shfl)); + } + } + } + } + else { CUTLASS_PRAGMA_UNROLL - for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { - uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); - frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); - tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); + for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); + tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); + } } } - bool is_reduced_lane = get<0>(lane_mn) == 0; // // 2. Atomic reduction @@ -708,6 +968,7 @@ struct Sm90RowReduction { // Filter so we don't issue redunant copies over stride-0 modes Tensor tCrRow_flt = filter_zeros(tCrRow); Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCcRow.stride())); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); Tensor tCgRow = sm90_partition_for_epilogue(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx); Tensor tCgRow_flt = filter_zeros(tCgRow); @@ -717,11 +978,23 @@ struct Sm90RowReduction { ConvertOutput convert_output{}; ReduceOutput reduce_output{}; - if (is_reduced_lane) { + if constexpr (SwapShuffle) { CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrRow_flt); ++i) { - if (elem_less(tCcRow_flt(i), residue_tCcRow)) { - reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); + for (int i = 0; i < FltFrgSizePerLaneM; ++i) { + int idx = lane_m * FltFrgSizePerLaneM + i; + // Only care about OOB for N mode + if (get<1>(tCcRow_flt(idx)) < get<1>(residue_tCcRow)) { + reduce_output(&tCgRow_flt(idx), convert_output(tCrRow_flt(i))); + } + } + } + else { + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrRow_flt); ++i) { + if (elem_less(tCcRow_flt(i), residue_tCcRow)) { + reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); + } } } } @@ -735,10 +1008,21 @@ struct Sm90RowReduction { // Dump warp reduction to gmem workspace using ElementGmem = cute::conditional_t; Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_ml(_,_,m,l), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), recast(filter(tCgBuf))); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCgBuf_flt = recast(filter(tCgBuf)); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCgBuf_flt_ = logical_divide(tCgBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCgBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrRow), recast(filter(tCgBuf))); + } } sync_fn(); } @@ -755,10 +1039,21 @@ struct Sm90RowReduction { // Dump warp reduction to smem workspace Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - // Filter so we don't issue redunant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), filter(tCsBuf)); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCsBuf_flt = filter(tCsBuf); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCsBuf_flt_ = logical_divide(tCsBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCsBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + // Filter so we don't issue redunant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrRow), filter(tCsBuf)); + } } sync_fn(); @@ -772,25 +1067,30 @@ struct Sm90RowReduction { Tensor sBuf_vec = recast(filter_zeros(sBuf)); constexpr int FragsPerRow = decltype(size<1>(sBuf_frg))::value; - // Do the threadblock smem reduction - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = size<0>(warp_layout_MN) / 2; reduction_rows > 1; reduction_rows /= 2) { - int FragsPerReduction = reduction_rows * FragsPerRow; - CUTLASS_PRAGMA_NO_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); - sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); - } - sync_fn(); - } + constexpr int RowNum = decltype(size<0>(warp_layout_MN))::value; + using FragmentSmemArray = Array; - // Do final smem reduction and dump to gmem workspace + // Do the threadblock smem reduction using VectorGmem = cute::conditional_t; Tensor gBuf_vec = recast(filter(gBuf_ml(_,_,m,l))); - CUTLASS_PRAGMA_NO_UNROLL + CUTLASS_PRAGMA_UNROLL for (int frg_idx = thread_idx; frg_idx < FragsPerRow; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerRow)); - gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + FragmentSmemArray frg_smem; + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = 0; reduction_rows < RowNum; ++reduction_rows) { + int FragsCurrRows = reduction_rows * FragsPerRow; + frg_smem[reduction_rows] = sBuf_frg(FragsCurrRows + frg_idx); + } + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = RowNum / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int row_idx = 0; row_idx < reduction_rows; ++row_idx) { + frg_smem[row_idx] = reduce_smem(frg_smem[row_idx], frg_smem[row_idx + reduction_rows]); + } + } + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem[0]); } sync_fn(); } @@ -959,9 +1259,8 @@ struct Sm90ColReduction { private: static_assert(Stages == 0, "Smem usage not supported yet"); static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // col vector reduction, e.g. per-row sum over all batches - (cute::is_same_v>)); // batched col vector reduction, e.g. per-row sum per batch + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); static constexpr bool IsAtomic = is_atomic>::value; static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); @@ -1042,6 +1341,7 @@ struct Sm90ColReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { +#if !defined(CUTLASS_SKIP_REDUCTION_INIT) if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); @@ -1050,7 +1350,9 @@ struct Sm90ColReduction { } return Status::kSuccess; } - else if constexpr (FinalReduction) { + else +#endif + if constexpr (FinalReduction) { auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 843640127d..4f7d99fa32 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -170,7 +170,7 @@ struct ConsumerStoreArgs { Residue residue_cD; ThrCoordTensor tCcD; ThrResidue residue_tCcD; - ThrSrcTensor const& tCrC; + ThrSrcTensor & tCrC; int thread_idx; CUTLASS_DEVICE @@ -185,7 +185,7 @@ struct ConsumerStoreArgs { Residue residue_cD, ThrCoordTensor tCcD, ThrResidue residue_tCcD, - ThrSrcTensor const& tCrC, + ThrSrcTensor & tCrC, int thread_idx) : problem_shape_mnkl(problem_shape_mnkl), tile_shape_mnk(tile_shape_mnk), @@ -361,14 +361,12 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables CallbacksTuple callbacks_tuple; - // Before entry of the subtile load loop. Bulk copies usually performed here. - // Upon entry the producer_acquire of the first subtile lock has completed. - // full_mbarrier_ptr is the corresponding barrier for the subsequent producer_commit arrival + // Before entry of the subtile load loop CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { + begin() { for_each(callbacks_tuple, [&] (auto& callbacks) { - callbacks.begin(full_mbarrier_ptr, load_iteration, issue_tma_load); + callbacks.begin(); } ); } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp new file mode 100644 index 0000000000..53c0dce8ba --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp @@ -0,0 +1,759 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree Top-K + Softmax fusion operation for sm90 TMA warp-specialized epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Top-K + Softmax reduction across columns +// Performs a reduction of top-K values across N, and finally performs a softmax on them, +// and sets values not in the top-K to 0. +// +// Assumptions: +// 1. CTA_N >= N (single tile across N, the mode which is reduced) +// 2. EPI_N >= N (single epilogue tile across N, because we can reduce and revisit one +// epilogue tile at a time.) +// 3. Top-K value is either 2 or 4. +// + +namespace detail { + +// Implementations for add to sorted list and merging sorted lists, +// with fast paths for lists of size 2 and 4 (Top-2 and Top-4). +// Generic implementations may result in greater register use and branching, +// and should be avoided. +// Fast paths for Top-2 and Top-4 are written in inline PTX directly. + +CUTLASS_DEVICE +Array top_2_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx, %3, %4;\n" + " setp.gtu.f32 p, %2, %4;\n" + " selp.f32 %1, mx, %2, p;\n" + " selp.f32 %0, %2, %4, p;\n" + "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_2_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .v2 .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx.x, %3, %4;\n" // max(a1, b0) + " max.f32 mx.y, %2, %5;\n" // max(a0, b1) + " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 + " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) + " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 + "}\n" : "=f"(out[0]), "=f"(out[1]) : + "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" // max(a3, b) + " .reg .pred p0;\n" // a0 > b + " .reg .pred p1;\n" // a1 > b + " .reg .pred p2;\n" // a2 > b + " max.f32 mx, %7, %8;\n" // max(a3, b) + " setp.gtu.f32 p0, %4, %8;\n" // a0 > b + " setp.gtu.f32 p1, %5, %8;\n" // a1 > b + " setp.gtu.f32 p2, %6, %8;\n" // a2 > b + " selp.f32 %3, mx, %6, p2;\n" // a2 > b ? max(a3, b) : a2 + " selp.f32 %2, %6, %8, p2;\n" // a1 = a2 > b ? a2 : b + " selp.f32 %2, %2, %5, p1;\n" // a1 > b ? max(a2, b) : a1 == a1 > b ? a1 : old_a1 + " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b + " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 + " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mxa0b1;\n" // max(a0, b1) + " .reg .f32 mxa1b0;\n" // max(a1, b0) + + " .reg .f32 mxa2b0;\n" // max(a2, b0) + " .reg .f32 mxa1b1;\n" // max(a1, b1) + " .reg .f32 mxa0b2;\n" // max(a1, b1) + + " .reg .f32 mxa1b2;\n" // max(a1, b2) + " .reg .f32 mxa2b1;\n" // max(a2, b1) + " max.f32 mxa1b2, %5, %10;\n" + " max.f32 mxa2b1, %6, %9;\n" + + " .reg .f32 mxa3b0;\n" // max(a1, b2) + " .reg .f32 mxa0b3;\n" // max(a2, b1) + " max.f32 mxa3b0, %7, %8;\n" + " max.f32 mxa0b3, %4, %11;\n" + + " .reg .pred pa0b0;\n" // a0 > b0 + " .reg .pred pa1b0;\n" // a1 > b0 + " .reg .pred pa2b0;\n" // a2 > b0 + " .reg .pred pa0b1;\n" // a0 > b1 + " .reg .pred pa1b1;\n" // a1 > b1 + " .reg .pred pa0b2;\n" // a0 > b2 + " .reg .pred pb2a0;\n" // b1 > a0 + " .reg .pred pb1a0;\n" // b1 > a0 + + " setp.gtu.f32 pa0b0, %4, %8;\n" // a0 > b0 + " setp.gtu.f32 pa1b0, %5, %8;\n" // a1 > b0 + " setp.gtu.f32 pa2b0, %6, %8;\n" // a2 > b0 + " setp.gtu.f32 pa0b1, %4, %9;\n" // a0 > b1 + " setp.gtu.f32 pa1b1, %5, %9;\n" // a1 > b1 + " setp.gtu.f32 pa0b2, %4, %10;\n" // a0 > b2 + + " not.pred pb2a0, pa0b2;\n" + " not.pred pb1a0, pa0b1;\n" + + " selp.f32 mxa1b0, %5, %8, pa1b0;\n" // max(a1, b0) + " selp.f32 mxa0b1, %4, %9, pa0b1;\n" // max(a0, b1) + + " selp.f32 mxa1b1, %5, %9, pa1b1;\n" // max(a1, b1) + " selp.f32 mxa2b0, %6, %8, pa2b0;\n" // max(a2, b0) + " selp.f32 mxa0b2, %4, %10, pa0b2;\n" // max(a0, b2) + + // a0 + " selp.f32 %0, %4, %8, pa0b0;\n" // a0 = a0 > b0 ? a0 : b0 + + // a1 + " selp.f32 %1, mxa1b0, mxa0b1, pa0b0;\n" // a1 = a0 > b0 ? max(a1, b0) : max(a0, b1) + + // a2 + " mov.f32 %2, mxa1b1;\n" // a2 = max(a1, b1) ** most likely case + " selp.f32 %2, mxa2b0, %2, pa1b0;\n" // a0 > a1 > b0 + " selp.f32 %2, mxa0b2, %2, pb1a0;\n" // b0 > b1 > a0 + + // a3 + " mov.f32 %3, mxa1b2;\n" // a3 = max(a1, b2) ** one of the most likely cases + " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case + " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 + " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), + "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); + return out; +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) +template +CUTLASS_DEVICE +void add_element_to_desc_sorted_array(cutlass::Array& a, Element b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce_scalar(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce_scalar(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] <= b) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] and b[0] are the largest elements in a[] and b[].) +template +CUTLASS_DEVICE +void merge_desc_sorted_arrays(cutlass::Array& a, const cutlass::Array& b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + int j = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] <= b[j]) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b[j]; + ++j; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) +template +CUTLASS_DEVICE +Element topk_logsumexp(cutlass::Array a) { + // Do one less `exp`, because we know what its result will be. + // Assume x is a set of `x_i`s, and `x_m` is the maximum of that set. + // logsumexp(x) = log(sum(x_i)) = m + log(sum(x_i - m)) = m + log(1 + sum_{i != m}(x_i - x_m)) + // Compute m + log(1 + sum_{i != m}(x_i - x_m)) + Element sum = Element(1.0); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N; ++i) { + sum += fast_exp(a[i] - a[0]); + } + return a[0] + fast_log(sum); +} + +CUTLASS_DEVICE +float fast_masked_softmax(float value, float minimum, float logsumexp) { + float new_value; + asm volatile( + "{\n" + " .reg .pred p0;\n" + // value >= minimum + " setp.geu.f32 p0, %1, %2;\n" + + " .reg .f32 x_lse;\n" + " .reg .f32 %%f<11>;\n" + " .reg .b32 %%r<3>;\n" + + // x_lse = value - minimum + " sub.rn.f32 x_lse, %1, %3;\n" + + // exp(x_lse) + // The following is derived from a ptx dump of expf. + // exp requires a base conversion from exp2. + " fma.rn.f32 %%f1, x_lse, 0f3BBB989D, 0f3F000000;\n" + " cvt.sat.f32.f32 %%f2, %%f1;\n" + " fma.rm.f32 %%f3, %%f2, 0f437C0000, 0f4B400001;\n" + " add.f32 %%f4, %%f3, 0fCB40007F;\n" + " neg.f32 %%f5, %%f4;\n" + " fma.rn.f32 %%f6, x_lse, 0f3FB8AA3B, %%f5;\n" + " fma.rn.f32 %%f7, x_lse, 0f32A57060, %%f6;\n" + " mov.b32 %%r1, %%f3;\n" + " shl.b32 %%r2, %%r1, 23;\n" + " mov.b32 %%f8, %%r2;\n" + " ex2.approx.ftz.f32 %%f9, %%f7;\n" + " mul.f32 %%f10, %%f9, %%f8;\n" + + // Mask or softmax + " selp.f32 %0, %%f10, 0f00000000, p0;\n" + "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp)); + return new_value; +} + +template +CUTLASS_DEVICE +Element masked_softmax(Element value, Element minimum, Element logsumexp) { + if constexpr (is_same_v) { + // Inline PTX implementation + // Significantly reduces register requirements + return fast_masked_softmax(value, minimum, logsumexp); + } + else { + return value < minimum ? Element(0.0) : fast_exp(value - logsumexp); + } +} + +} // namespace detail + +template < + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + int Alignment = 128 / sizeof_bits_v, + bool UseButterflyReduce = true +> +struct Sm90TopKSoftmaxColReduction { +private: + static_assert(is_same_v, "Fused Top-K + Softmax reduction requires FP32 accumulation."); + static_assert(TopK == 2 || TopK == 4, "Fused Top-K + Softmax reduction only supports K=2 and K=4."); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + // Reduction tensors + // We have two tensors for this EVT node: a reduction tensor and a tensor holding + // final reduction values (tCrSoftmax). The reason for this is that Top-K and Softmax + // require different reductions, but those luckily overlap. Top-K obviously needs at least + // two values (K >= 2), and softmax needs one value: logsumexp. Logsumexp is simply the log + // of sum of exponents over the set, and is equivalent to m + sum(exp(x_i - m)), where m is the + // maximum of all x_i elements. Since safe softmax for any element x_i is computed as + // softmax(x_i) = exp(x_i - m) / sum_j(exp(x_j - max)) + // we can track logsumexp instead of tracking two variables (sum of exps and the max). + // In addition, subtracting logsumexp from any element and taking its exp is equivalent to + // computing its softmax. + // + // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the + // way at all, because any element not in the top-K is going to be masked out and set to 0. + // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and + // keep it, and the smallest element in the top-K for masking out non-top-K elements. + // + // This means that our final reduction result will always be 2 elements, regardless of the value + // of K: minimum of top-K, and logsumexp. + // + // For each reduction tensor, we define a new struct for readability. + + struct ReductionResult { + ElementCompute min_; + ElementCompute logsumexp_; + + CUTLASS_DEVICE + ReductionResult() { } + + CUTLASS_DEVICE + ReductionResult(ElementCompute min, ElementCompute logsumexp): + logsumexp_(logsumexp), min_(min) { } + + // Warp shuffle broadcast + CUTLASS_DEVICE + void shuffle_up_sync(uint32_t delta, int lane_id) { + static_assert(sizeof(ReductionResult) == sizeof(uint64_t)); + uint64_t r = reinterpret_cast(*this); + r = __shfl_up_sync(0xFFFFFFFF, r, delta); + *this = (lane_id - static_cast(delta) >= 0) ? reinterpret_cast(r) : *this; + } + }; + + struct TopKResult { + Array top_k_; + + CUTLASS_DEVICE + TopKResult() { + top_k_.fill(-cutlass::platform::numeric_limits::infinity()); + } + + // This is where we do the "final" reduction, where we compute + // the logsumexp for softmax, keep the smallest value in top-K, + // and discard the rest. + CUTLASS_DEVICE + ReductionResult reduce_final() const { + return ReductionResult(top_k_[TopK - 1], topk_logsumexp(top_k_)); + } + + // Butterfly reduction + CUTLASS_DEVICE + void shuffle_xor_sync(int laneMask) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); + top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + + // Warp shuffle reduction + CUTLASS_DEVICE + void shuffle_down_sync(uint32_t delta) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); + top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + }; + +public: + struct SharedStorage { }; + + struct Arguments { }; + + struct Params { }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return {}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Cross CTA reduction is not possible because there is no guarantee that all CTAs run + // concurrently. + // Cross epilogue tile reduction is possible, but re-visiting and applying reduction + // to accumulators is only possible for the current epilogue tile. + auto [epi_M, epi_N] = EpilogueTile{}; + return N <= tile_N && N <= epi_N && N >= TopK; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + auto thread_crd = tCcCol_mn(epi_v * FragmentSize + i); + if (elem_less(thread_crd, residue_tCcCol)) { + TopKResult& tCrCol_vmn = tCrTopK(epi_v * FragmentSize + i); + detail::add_element_to_desc_sorted_array(tCrCol_vmn.top_k_, frg_I[i]); + } + } + + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { + return; + } + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + // `tCrTopK` and `tCrSoftmax` have 0-strides along modes that correspond to N, + // in order to reduce along modes in the `R2S` sublayout that correspond to N. + // This means we should modify and warp-reduce them according to their co-domain instead of + // their domain. Therefore we keep a filtered view of both and use them as necessary. + auto tCrTopK_f = filter(tCrTopK); + auto tCrSoftmax_f = filter(tCrSoftmax); + + // The pattern here is: reduce Top-K first, then compute logsumexp, keep it and the + // last element of Top-K, use the latter to mask the visited results, and the former + // to apply softmax. + // + // This gives us two options: reduce the Top-K with warp shuffles, have the reduced + // lanes compute logsumexp and pair it with the last Top-K element, and broadcast + // the result back using warp shuffles. + // + // Alternatively, we can do a butterfly reduction over Top-K, and have all lanes + // compute their own logsumexp and skip the broadcast. + if constexpr (UseButterflyReduce) { + // + // 1. Butterfly reduction + // + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < size<1>(lane_layout_MN); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_xor_sync(j); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + else { + // + // 1. Warp shuffle reduction + // + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_down_sync(lane_layout_MN(_0{},reduction_cols)); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + bool is_reduced_lane = get<1>(lane_mn) == 0; + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + + // + // 3. Broadcast reduced values to all participants + // + CUTLASS_PRAGMA_UNROLL + for (int broadcast_cols = 1; broadcast_cols <= size<1>(lane_layout_MN) / 2; broadcast_cols *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i).shuffle_up_sync(lane_layout_MN(_0{},broadcast_cols), get<1>(lane_mn)); + } + } + } + + // + // 4. Re-visit and apply top-K and softmax + // + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(visit_results); ++epi_v) { + auto& visit_frag = visit_results(epi_v); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + visit_frag[i] = detail::masked_softmax( + visit_frag[i], + tCrSoftmax(epi_v * FragmentSize + i).min_, + tCrSoftmax(epi_v * FragmentSize + i).logsumexp_ + ); + } + } + + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // Reset reduced top-K values for next tile + // This must be done because we only assume a single epilogue tile across N, + // but not M. + fill(tCrTopK, TopKResult()); + } + + CUTLASS_DEVICE void + end() { } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + + // Make sure there's only one warp across N so we can use warp shuffle intrinsics for reduction. + static_assert(decltype(size<1>(warp_layout_MN))::value <= 1); + + // Reduction layout + // We're assuming all elements in a row (over which we're performing the reduction) are + // visited in the same corresponding epilogue tile, and this is what allows us to apply the + // top-K + softmax operation within `reduce()`, by re-visiting the accumulated results. + // + // This presents a challenge, because the layout of the accumulated results is typically in + // in the register to shared memory shape, or: (R2S,R2S_M,R2S_N). + // This means that we still need to reduce this tensor along N. + // + // The solution is simple: we need to flatten the layout, identify modes that correspond to + // N and set their strides to 0, in order to map fragment indices corresponding to the same + // row back to the same element in the tensor. + // + // This requires some extra layout manipulation, which is as follows. + + // Create new accumulator layout with column broadcast + auto [M, N, K] = args.tile_shape_mnk; + auto thr_mma = args.tiled_mma.get_thread_slice(args.thread_idx); + auto gColReduce = make_tensor( + make_layout(make_shape(M, N), make_stride(_1{}, 0_c))); // (M,N) + auto tCrColReduce = make_tensor_like( // (FrgV, MMA_M, MMA_N) + thr_mma.partition_C(gColReduce).layout()); + + // Tile the new accumulator tensor according to R2S + ThrCopy thread_r2s = args.tiled_copy.get_slice(args.thread_idx); + Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) + auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) + + // Compose the new accumulator R2S layout with the expected tCrC layout to get final + // reduction tensor layout. + auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) + + Tensor tCrTopK = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + Tensor tCrSoftmax = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + fill(tCrTopK, TopKResult()); + + auto args_tuple = make_tuple( + cute::move(tCrTopK), cute::move(tCrSoftmax), args.tCcD, args.cD, + lane_layout_MN, lane_mn, + args.residue_cD, args.residue_tCcD); + return ConsumerStoreCallbacks(std::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 92407733f8..9f1cd77434 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -178,8 +178,9 @@ struct Clamp { CUTLASS_HOST_DEVICE T operator()(T const& value, T const& lower_bound, T const& upper_bound) const { - maximum mx; - minimum mn; + constexpr bool PropagateNaN = true; + maximum mx; + minimum mn; return mn(mx(value, lower_bound), upper_bound); } @@ -196,8 +197,9 @@ struct Clamp> { CUTLASS_HOST_DEVICE Array operator()(Array const& values, T const& lower_bound, T const& upper_bound) const { - maximum> mx; - minimum> mn; + constexpr bool PropagateNaN = true; + maximum, PropagateNaN> mx; + minimum, PropagateNaN> mn; return mn(mx(values, lower_bound), upper_bound); } @@ -226,7 +228,7 @@ struct LeakyReLU { CUTLASS_HOST_DEVICE T operator()(T const& value, Arguments const& args = Arguments()) const { - this->operator()(value, args.leaky_alpha); + return this->operator()(value, args.leaky_alpha); } }; @@ -696,6 +698,57 @@ struct dReLU_Z> { } }; +// ElementwiseFilter operator +// Filters by a specific value and maps it to 0.0 +// Used in GEMM + comm +template +struct ElementwiseFilter { + + static const bool kIsHeavy = false; + + struct Arguments { + T value_to_filter = T(-0.0); + T filtered_value = T(0.0); + }; + + CUTLASS_HOST_DEVICE + T operator()(T const& value, T const& value_to_filter, T const& filtered_value) const { + T res = value == value_to_filter ? filtered_value : value; + return res; + } + + CUTLASS_HOST_DEVICE + T operator()(T const& value, Arguments const& args = Arguments()) const { + return this->operator()(value, args.value_to_filter, args.filtered_value); + } +}; + +template +struct ElementwiseFilter > { + + static const bool kIsHeavy = false; + + using Arguments = typename ElementwiseFilter::Arguments; + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, T const& value_to_filter, T const& filtered_value) const { + Array y; + ElementwiseFilter filter_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(values.size()); ++i) { + y[i] = filter_op(values[i], value_to_filter, filtered_value); + } + + return y; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, Arguments const& args = Arguments()) const { + return this->operator()(values, args.value_to_filter, args.filtered_value); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 7456ae8df4..c5ffdaa03f 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -127,15 +127,20 @@ class LinearCombinationBiasElementwise { public: using ElementOutput = ElementC_; + using ElementD = ElementOutput; using ElementC = ElementC_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + using ElementScalar = ElementCompute; using ElementZ = ElementZ_; using ElementT = ElementT_; using ElementVector = ElementVector_; static int const kElementsPerAccess = ElementsPerAccess; static int const kCount = kElementsPerAccess; + /// Follow cutlass3x EVT aliases + static bool const IsEltActSupported = true; + using ElementwiseOp = ElementwiseOp_; using BinaryOp = BinaryOp_; @@ -157,7 +162,7 @@ class LinearCombinationBiasElementwise { using FragmentOutput = FragmentZ; using ElementBias = ElementVector; using FragmentBias = Array; - using ActivationFunctor = ElementwiseOp; + using ActivationFn = ElementwiseOp; static const ScaleType::Kind kScale = ScaleType::Default; static bool const kIsHeavy = kIsHeavy_member_or_false::value; @@ -396,6 +401,118 @@ class LinearCombinationBiasElementwise { frag_T = convert_t(result_T); } } + + /// Applies the operation when elementwise_op require arguments and is_source_needed() is true + template + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementC const &C, + ElementCompute const &V, + ElementwiseArgs const &elementwise_args) const { + + ElementwiseOp elementwise_op; + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + ElementCompute tmp_C = NumericConverter()(C); + + ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V); + ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } + + /// Applies the operation when elementwise_op require arguments and is_source_needed() is false + template + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementCompute const &V, + ElementwiseArgs const &elementwise_args) const { + + ElementwiseOp elementwise_op; + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + + ElementCompute z = binary_op(alpha_ * tmp_Accum, V); + ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementC const &C, + ElementCompute const &V) const { + + ElementwiseOpDispatcher elementwise_op(elementwise_); + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + ElementCompute tmp_C = NumericConverter()(C); + + ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V); + ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementCompute const &V) const { + + ElementwiseOpDispatcher elementwise_op(elementwise_); + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + + ElementCompute z = binary_op(alpha_ * tmp_Accum, V); + ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index 1692cc3093..1d62f4fc35 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -225,6 +225,44 @@ struct DefaultIteratorsTensorOp< static int const kFragmentsPerIteration = 2; }; +/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template < + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + bfloat16_t, + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + WarpShape, + InstructionShape, + int32_t, + 32, + 16, + 8, + 8 + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + int32_t, + 32, + 16, + 8, + 8 + >; + + static int const kFragmentsPerIteration = 2; +}; + /// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. template < typename ThreadblockShape, diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 5709ec9fed..38ea4008c2 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -574,6 +574,12 @@ struct alignas(1) float_e4m3_t : float8_base { int mantissa() const { return int(storage & Base::FP8_MANTISSA_MASK); } + + CUTLASS_HOST_DEVICE + friend bool isnan(float_e4m3_t const& x) { + return x.storage == uint8_t(0x7f); + } + }; /////////////////////////////////////////////////////////////// /// @@ -783,6 +789,12 @@ struct alignas(1) float_e5m2_t : float8_base { int mantissa() const { return int(storage & Base::FP8_MANTISSA_MASK); } + + CUTLASS_HOST_DEVICE + friend bool isnan(float_e5m2_t const& x) { + return x.storage == uint8_t(0x7f); + } + }; /////////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 65e49d5290..5b2bc3c67f 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -38,7 +38,6 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/platform/platform.h" - #if defined(__CUDACC_RTC__) #include "cutlass/floating_point_nvrtc.h" #endif @@ -234,7 +233,7 @@ template <> struct inverse_square_root { CUTLASS_HOST_DEVICE half_t operator()(half_t const &lhs) const { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 520 +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 520) auto result = hrsqrt(reinterpret_cast<__half const &>(lhs)); return reinterpret_cast(result); #else @@ -350,7 +349,19 @@ template struct maximum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { - return (lhs < rhs ? rhs : lhs); + if constexpr (PropagateNaN && cutlass::platform::is_floating_point::value) { + using CUTLASS_CMATH_NAMESPACE :: isnan; + + // Call isnan unqualified, so argument-dependent lookup (ADL) + // will find overloads such as cutlass::isnan(half_t). + // Calling ::isnan or std::isnan directly would force + // implicit conversions to float of custom number types + // in the cutlass namespace (e.g., cutlass::half_t). + return lhs > rhs || isnan(lhs) ? lhs : rhs; + } + else { + return (lhs < rhs ? rhs : lhs); + } } }; @@ -363,23 +374,6 @@ template struct maximum_with_default_nan_propagation : public maximum {}; -// Maximum with nan propagation -// To propagate NANs, the "max" of a two element that contains NaNs should also return a NaN -template -struct maximum { - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { - using CUTLASS_CMATH_NAMESPACE :: isnan; - - // Call isnan unqualified, so argument-dependent lookup (ADL) - // will find overloads such as cutlass::isnan(half_t). - // Calling ::isnan or std::isnan directly would force - // implicit conversions to float of custom number types - // in the cutlass namespace (e.g., cutlass::half_t). - return lhs > rhs || isnan(lhs) ? lhs : rhs; - } -}; - template <> struct maximum { CUTLASS_HOST_DEVICE @@ -391,13 +385,14 @@ struct maximum { template <> struct maximum { CUTLASS_HOST_DEVICE - float operator()(float const lhs, float const rhs) const { + float operator()(float lhs, float rhs) const { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) float res; asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); return res; #else using CUTLASS_CMATH_NAMESPACE :: isnan; + return lhs > rhs || isnan(lhs) ? lhs : rhs; #endif } @@ -418,20 +413,17 @@ template using maximum_with_nan_propogation = maximum_with_nan_propagation; template -struct minimum{ - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { - return (rhs < lhs ? rhs : lhs); - } -}; - -template -struct minimum { +struct minimum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { - using CUTLASS_CMATH_NAMESPACE :: isnan; + if constexpr (PropagateNaN && cutlass::platform::is_floating_point::value) { + using CUTLASS_CMATH_NAMESPACE :: isnan; - return lhs < rhs || isnan(lhs) ? lhs : rhs; + return lhs < rhs || isnan(lhs) ? lhs : rhs; + } + else { + return (rhs < lhs ? rhs : lhs); + } } }; @@ -443,6 +435,21 @@ struct minimum { } }; +template <> +struct minimum { + CUTLASS_HOST_DEVICE + float operator()(float lhs, float rhs) const { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + float res; + asm volatile("min.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); + return res; +#else + // No need for ADL; call std::isnan(float) on host and ::isnan(float) on device. + return lhs < rhs || (CUTLASS_CMATH_NAMESPACE :: isnan(lhs)) ? lhs : rhs; +#endif + } +}; + template struct minimum_with_nan_propagation : minimum {}; @@ -819,9 +826,9 @@ struct atomic_add void operator()(half2 *ptr, const half2 &data) { #if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)) - CUTLASS_UNUSED(ptr); - CUTLASS_UNUSED(data); - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); + CUTLASS_NOT_IMPLEMENTED(); #else // Vector-2 atomic reduction requires .target sm_60 or higher uint32_t word = reinterpret_cast(data); @@ -879,7 +886,6 @@ struct is_atomic> : platform::true_type {}; template struct is_atomic> : platform::true_type {}; - ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for nvcuda::wmma::fragment diff --git a/include/cutlass/gemm/collective/builders/sm90_common.inl b/include/cutlass/gemm/collective/builders/sm90_common.inl index 298793e886..8d95967f97 100644 --- a/include/cutlass/gemm/collective/builders/sm90_common.inl +++ b/include/cutlass/gemm/collective/builders/sm90_common.inl @@ -38,6 +38,7 @@ #include "cutlass/detail/dependent_false.hpp" #include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/mma_traits_sm90_gmma_sparse.hpp" #include "cute/atom/copy_traits_sm90_tma.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -123,13 +124,12 @@ sm90_cluster_shape_to_tma_atom(UnimodalClusterShape) { } } -// Generates the most efficient possible TiledCopy with cp.async copy atom given a set of parameters. -template +// Generates the most efficient possible TiledCopy with simt copy atom(e.g. cp.async) given a set of parameters. +template constexpr auto -make_cp_async_gmem_tiled_copy() { +make_simt_gmem_tiled_copy() { using namespace cute; - using AlignmentType = cute::uint_byte_t(sizeof(Element)) * Alignment>; constexpr int TileSizeMN = cute::size(TileMN{}); constexpr int TileSizeK = cute::size(TileK{}); @@ -144,7 +144,7 @@ make_cp_async_gmem_tiled_copy() { static_assert(ThreadCount % threads_major == 0); static_assert(threads_minor == 0 || (TileSizeMN % threads_minor == 0)); return make_tiled_copy( - Copy_Atom, Element>{}, + CopyAtom{}, Layout,Int>, Stride, _1>>{}, Layout>>{}); @@ -157,13 +157,12 @@ make_cp_async_gmem_tiled_copy() { static_assert(ThreadCount % threads_major == 0); static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0)); return make_tiled_copy( - Copy_Atom, Element>{}, + CopyAtom{}, Layout,Int>, Stride< _1,Int>>{}, Layout,_1>>{}); - } - else { - static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); + } else { + static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); } } @@ -319,6 +318,62 @@ ss_smem_selector() } } +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +CUTE_HOST_DEVICE constexpr +auto +ss_smem_selector_sparse() +{ + using namespace cute; + + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + + if constexpr (major == GMMA::Major::MN) { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_SpAtom{}) == 0) { + return GMMA::Layout_MN_SW128_SpAtom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_SpAtom{}) == 0) { + return GMMA::Layout_MN_SW64_SpAtom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_SpAtom{}) == 0) { + return GMMA::Layout_MN_SW32_SpAtom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_SpAtom{}) == 0) { + return GMMA::Layout_MN_INTER_SpAtom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_SpAtom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_SpAtom{}) == 0) { + return GMMA::Layout_K_SW128_SpAtom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_SpAtom{}) == 0) { + return GMMA::Layout_K_SW64_SpAtom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_SpAtom{}) == 0) { + return GMMA::Layout_K_SW32_SpAtom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_SpAtom{}) == 0) { + return GMMA::Layout_K_INTER_SpAtom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_SpAtom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + template constexpr bool is_input_size_two_bytes() { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 0b3ecb15c6..a4cc768638 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -49,21 +49,21 @@ namespace cutlass::gemm::collective { namespace detail { -// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template constexpr int compute_stage_count_or_override(StageCount stage_count) { return stages; } -// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template constexpr int compute_stage_count_or_override(cute::Int stage_count) { return stages; } -// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template constexpr int compute_stage_count_or_override(StageCountAutoCarveout stage_count) { @@ -78,7 +78,7 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_cou return (CapacityBytes - carveout_bytes) / stage_bytes; } -// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCount stage_count) { @@ -86,16 +86,16 @@ compute_stage_count_or_override_single_affine_transformed_input(StageCount -constexpr int get_bits_for_possibly_void_element() { +constexpr int get_bits_for_possibly_void_element() { if constexpr (cute::is_same_v) { return 0; - } + } else { return sizeof_bits::value; } } -// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout stage_count) { @@ -113,7 +113,7 @@ compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCa static_assert(scale_bytes % 128 == 0, "Scale bytes must be a multiple of 128"); static_assert(zero_bytes % 128 == 0, "Zero bytes must be a multiple of 128"); - // When scales are void, s_bits will be 0 so no smem will be allocated for scales. + // When scales are void, s_bits will be 0 so no smem will be allocated for scales. constexpr int stage_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + @@ -140,7 +140,7 @@ is_warpspecialized_transpose_B(){ cutlass::gemm::detail::is_mn_major_B(); constexpr bool IsWarpSpecialized = cute::is_base_of_v || cute::is_base_of_v || - cute::is_base_of_v || + cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v; @@ -240,8 +240,8 @@ struct CollectiveBuilder< MainloopSm90TmaGmmaWarpSpecializedFP8, MainloopSm90TmaGmmaWarpSpecialized>>; - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; using CollectiveOp = CollectiveMma< DispatchPolicy, @@ -296,7 +296,7 @@ struct CollectiveBuilder< (cute::is_same_v || cute::is_same_v || cute::is_same_v) && - detail::is_use_rmem_A()> + detail::is_use_rmem_A()> > { static_assert(is_static::value); static_assert(is_static::value); @@ -335,8 +335,8 @@ struct CollectiveBuilder< using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized< PipelineStages, ClusterShape_MNK, KernelScheduleType>; - using SmemCopyAtomA = cute::conditional_t>; - using SmemCopyAtomB = cute::conditional_t, void>; + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; using CollectiveOp = CollectiveMma< DispatchPolicy, @@ -404,7 +404,7 @@ public: using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>; using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>; static_assert(cute::is_tuple::value ^ cute::is_tuple::value || - (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), + (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), "Either A OR B must be a tuple or the widths of A and B must be different."); static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; @@ -458,8 +458,8 @@ public: static constexpr int PipelineStages = detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}); - using SmemCopyAtomA = cute::conditional_t>; - using SmemCopyAtomB = cute::conditional_t, void>; + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; @@ -794,11 +794,16 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = cute::is_same_v ? 2 : 1; - using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutAtomA = decltype(detail::ss_smem_selector< @@ -895,14 +900,19 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = 1; - using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); @@ -913,8 +923,8 @@ struct CollectiveBuilder< using DispatchPolicy = MainloopSm90CpAsyncGmmaRmemAWarpSpecialized< PipelineStages, ClusterShape_MNK, KernelScheduleType>; - using SmemCopyAtomA = cute::conditional_t>; - using SmemCopyAtomB = cute::conditional_t, void>; + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; using CollectiveOp = CollectiveMma< DispatchPolicy, @@ -1025,3 +1035,4 @@ static constexpr bool IsMixedWidthInput = IsDifferentWidth || (IsDifferentWidth } // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl b/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl new file mode 100644 index 0000000000..f9aa7bab2d --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl @@ -0,0 +1,268 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Sparse configs specific for SM90 structure sparse kernels +*/ + + +#pragma once + +#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major +#include "cute/layout.hpp" // cute::Layout, cute::Shape, cute::Stride +#include "cute/numeric/integral_constant.hpp" // cute::Int +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/pointer_sparse.hpp" // cute::is_sparse +#include "cute/util/type_traits.hpp" // cute::is_same_v, cute::conditional_t +#include "cutlass/fast_math.h" // cutlass::round_up +#include "cutlass/layout/matrix.h" // cutlass::RowMajor, cutlass::ColumnMajor + +namespace cutlass { + +using namespace cute; + +template< + class ElementAMma_, + GMMA::Major GmmaMajorA, + class ElementEMma_, + class MinTileShapeK = Int<32> +> +struct Sm90GemmSparseConfig { + + static_assert(cute::is_sparse::value, "ElementAMma MUST be sparse elem"); + static_assert(cute::is_sparse::value, "ElementEMma MUST be sparse elem"); + + // A + using ElementAMma = ElementAMma_; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using ElementAMmaSparsity = Int; + + // Metadata (E) + using ElementEMma = ElementEMma_; + using ElementEMmaRaw = typename ElementEMma::raw_type; + using ElementEMmaSparsity = Int; + + // MMA type + static constexpr bool IsQmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static constexpr bool IsImma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static constexpr bool IsHmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static constexpr bool IsTfmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static_assert(int(IsQmma) + int(IsImma) + int(IsHmma) + int(IsTfmma) == 1, "Ambigious Input Type Config (failed to choose MMA type)"); + + // Number of ElementARaw stored in ElementAMmaRaw. For Hopper this is always 1. + using ElemsARawPerElementAMmaRaw = _1; + + // ElementA Sparsity Ratio + using ElementASparsity = ElementAMmaSparsity; + static_assert(ElementASparsity{} == _2{}, "ElementASparsity must be 2 for Hopper Sparse Gemm"); + + // Logical/Physical ElementA per Chunk + using LogicalElemsAPerChunk = conditional_t; + using PhysicalElemsAPerChunk = Int; + + // Metadata Bits + using ElementEBitsPerChunk = _4; + using ElementEBitsPerElementAMma = cute::conditional_t; + + // Metadata Layout. Unit in corresbonding logical elements. + // Basic metadata block is (16,64) for 8-bit, (16,32) for 16-bit, (16,16) for 32-bit data types. + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-wgmma-metadata-64n32-f16bf16 + // Tensor E layout atom stacks 4 basic blocks along M mode to align with WGMMA instruction shape and + // stacks 1-4 blocks along K mode and reorders memory layout to allow for vectorized loads from smem. + using BlockK = Int<512 / sizeof_bits_v>; + static_assert(MinTileShapeK{} % BlockK{} == 0, "MinTileShapeK must be a multiple of BlockK"); + using NumK = decltype(MinTileShapeK{} / BlockK{}); + + using TensorEAtom_32bit = decltype(make_ordered_layout(Shape, Shape<_8,_2,NumK>>{}, + Step , Step <_0,_4, _2>>{})); + + using TensorEAtom_16bit = decltype(make_ordered_layout(Shape, Shape<_16,_2,NumK>>{}, + Step , Step < _0,_4, _2>>{})); + + using TensorEAtom_8bit = decltype(make_ordered_layout(Shape<_64,MinTileShapeK>{}, + Step < _1, _0>{})); + + using TensorEAtom = cute::conditional_t<(IsQmma || IsImma), TensorEAtom_8bit, + cute::conditional_t>; + + // Logical elems that construct the atomK for tensorE/A. + using TensorEAtomK = Int(TensorEAtom{})>; + using TensorEAtomM = Int(TensorEAtom{})>; + + // Tensor E alignment requirements + using TensorEAlignmentM = TensorEAtomM; + using TensorEAlignmentK = TensorEAtomK; + + // Tensor A alignment requirements + // When A is MN major, TensorAAlignmentK needs to be multiplier of chunk size + // When A is K major, TensorAAlignmentK needs to be multiplier of TMA requirements times tensorA sparsity + // this is b.c. TensorACompressed needs to satisfy TMA requirements + using TensorAAlignmentK = cute::conditional_t>>; + + // When A is MN Major, TensorAAlignmentM needs to be multiplier of TMA requirements + // When A is K Major, no requirements on TensorAAlignmentM. + using TensorAAlignmentM = cute::conditional_t * ElemsARawPerElementAMmaRaw{}>, + _1>; + + // The following two functions are provided for user determine the static layouts type + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutA() { + using LayoutMMajor = Layout, + int32_t>, + Stride, + int64_t>>; + + using LayoutKMajor = Layout, + int32_t>, + Stride, + int64_t>>; + + if constexpr (GmmaMajorA == GMMA::Major::MN) { + return LayoutMMajor{}; + } + else { + return LayoutKMajor{}; + } + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutE() { + return make_layout( + make_shape(make_shape(shape<0>(TensorEAtom{}), int32_t(0)), + make_shape(shape<1>(TensorEAtom{}), int32_t(0)), + int32_t(0)), + make_stride(make_stride(stride<0>(TensorEAtom{}), cute::Int{}), + make_stride(stride<1>(TensorEAtom{}), int64_t(0)), + int64_t(0)) + ); + } + + // This function is used to revert a CuTe layout to a Cutlass layout tag (RowMajor/ColumnMajor) + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutA_tag(Layout layout_a) { + /* + (m, (2, k/2), l) : (2, (1, m*2), m*k) M-major + (m, (2, k/2), l) : (k, (1, 2), m*k) K-major + */ + // Check if the given layout_a is possibly a sparse tensorA layout. + static_assert(rank_v == 3 && depth_v == 2, "Rank and depth mismatch with the sparse tensorA's layout."); + static_assert(rank(get<1>(ShapeA{})) == 2 && rank(flatten(ShapeA{})) == 4, + "Not likely to be a sparse tensorA's layout."); + static_assert(get<1,0>(StrideA{}) == 1 && get<1,0>(ShapeA{}) == ElementASparsity{}, + "Not likely to be a sparse tensorA's layout."); + static_assert(get<0>(StrideA{}) == ElementASparsity{} || get<1,1>(StrideA{}) == ElementASparsity{}, + "Not likely to be a sparse tensorA's layout."); + + if constexpr (get<0>(StrideA{}) == ElementASparsity{}) { + return cutlass::layout::ColumnMajor{}; + } + else { + return cutlass::layout::RowMajor{}; + } + } + + // Fill tensor A layout from dynamic problem shape + template + CUTE_HOST_DEVICE + static constexpr auto + fill_layoutA(ProblemShape problem_shape) { + + const auto [M, N, K, L] = problem_shape; + + // Round up to satisfy TensorA Alignment requirement + const auto M_AlignedAC = cutlass::round_up(M, TensorAAlignmentM{}); + const auto K_AlignedAC = cutlass::round_up(K, TensorAAlignmentK{}); + + if constexpr (GmmaMajorA == GMMA::Major::MN) { + return make_layout( + make_shape(int32_t(M_AlignedAC), + make_shape(ElementASparsity{}, int32_t(K_AlignedAC) / ElementASparsity{}), + int32_t(L)), + make_stride(ElementASparsity{}, + make_stride(_1{}, int64_t(M_AlignedAC) * ElementASparsity{}), + (L == 1) ? int64_t(0) : int64_t(M_AlignedAC * K_AlignedAC)) + ); + } + else { + return make_layout( + make_shape(int32_t(M_AlignedAC), + make_shape(ElementASparsity{}, int32_t(K_AlignedAC / ElementASparsity{})), + int32_t(L)), + make_stride(int64_t(K_AlignedAC), + make_stride(_1{}, ElementASparsity{}), + (L == 1) ? int64_t(0) : int64_t(M_AlignedAC * K_AlignedAC)) + ); + } + } + + // Fill tensor E layout from dynamic problem shape + template + CUTE_HOST_DEVICE + static constexpr auto + fill_layoutE(ProblemShape problem_shape) { + const auto [M, N, K, L] = problem_shape; + + // Round up to satisfy TensorEAlignment requirement + const auto M_AlignedE = cutlass::round_up(M, TensorEAlignmentM{}); + const auto K_AlignedE = cutlass::round_up(K, TensorEAlignmentK{}); + + // TensorEAtom first along m-dim, then along k-dim, then along batch + static_assert(TensorEAlignmentM{} == TensorEAtomM{}, "Shape below assumes TensorEAlignmentM == TensorEAtomM"); + static_assert(TensorEAlignmentK{} == TensorEAtomK{}, "Shape below assumes TensorEAlignmentK == TensorEAtomK"); + + return make_layout( + make_shape(make_shape(shape<0>(TensorEAtom{}), int32_t(M_AlignedE / TensorEAtomM{})), + make_shape(shape<1>(TensorEAtom{}), int32_t(K_AlignedE / TensorEAtomK{})), + int32_t(L)), + make_stride(make_stride(stride<0>(TensorEAtom{}), cute::Int{}), + make_stride(stride<1>(TensorEAtom{}), int64_t(M_AlignedE * TensorEAtomK{})), + (L == 1) ? int64_t(0) : int64_t(M_AlignedE * K_AlignedE)) + ); + } +}; + +} // namespace cutlass diff --git a/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl new file mode 100644 index 0000000000..9b608fe022 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_sparse(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_sparse(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_sparse(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto e_bits = cute::sizeof_bits_v; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(e_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS_SPARSE +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + not detail::is_use_rmem_A()> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input, "FP8 sparse collective currently only supports FastAccum schedules"); + + // For fp32 types, map to tf32 MMA value type + using ElementAMmaRaw = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector_sparse< + ElementAMmaRaw, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaSparsity = Int; + using ElementEMma = typename TiledMma::ValTypeE; + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape_MNK{}),_128{}))>; + + using LayoutA = decltype(SparseConfig::deduce_layoutA()); + using LayoutE = decltype(SparseConfig::deduce_layoutE()); + using LayoutPairAE = decltype(cute::make_tuple(LayoutA{}, LayoutE{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector_sparse< + GmmaMajorA, ElementAMmaRaw, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), ElementAMmaSparsity>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_sparse(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparse; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + LayoutPairAE, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_SS_FP8_FAST_ACCUM_SPARSE +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v)> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector_sparse< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaSparsity = Int; + using ElementEMma = typename TiledMma::ValTypeE; + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape_MNK{}),_128{}))>; + + using LayoutA = decltype(SparseConfig::deduce_layoutA()); + using LayoutE = decltype(SparseConfig::deduce_layoutE()); + using LayoutPairAE = decltype(cute::make_tuple(LayoutA{}, LayoutE{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector_sparse< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), ElementAMmaSparsity>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_sparse(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparse; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + LayoutPairAE, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_RS_SPARSE +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + detail::is_use_rmem_A()> +> { + static_assert(cutlass::detail::dependent_false, "Mainloop with sparse A sourced from RF is not implemented."); +}; + +// Sparse GMMA auto kernel schedule +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr bool IsFP8Input = detail::is_input_fp8(); + + using KernelSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + cute::conditional_t, + cute::conditional_t>; + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelSchedule + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 532bfecfb1..ccd8d8b3c7 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -38,4 +38,5 @@ #include "cutlass/gemm/collective/collective_builder_decl.hpp" #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" +#include "cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 7bcc075782..103da9af7b 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -43,6 +43,7 @@ #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 75d7bb39e9..9825a16571 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -166,12 +166,12 @@ struct CollectiveMma< size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; - struct TensorMapStorage : cute::aligned_struct<128> { + struct TensorMapStorage : cute::aligned_struct<128, _0> { cute::TmaDescriptor smem_tensormap_A; cute::TmaDescriptor smem_tensormap_B; } tensormaps; @@ -720,7 +720,6 @@ struct CollectiveMma< ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch) { if (cute::elect_one_sync()) { - // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp index 4b291db358..69b31fdabe 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -187,7 +187,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<256> { + struct TensorStorage : cute::aligned_struct<256, _0> { cute::array_aligned, 256> smem_A; cute::array_aligned, 256> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp index 90e7acd38c..e336bd4755 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -135,7 +135,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index 43e05afa07..b30fed1c85 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -213,7 +213,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct { + struct TensorStorage : cute::aligned_struct { cute::array_aligned, SmemAlignmentA> smem_A; cute::array_aligned, SmemAlignmentB> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 1f679c88ca..8c98d15c29 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -174,7 +174,7 @@ struct CollectiveMma< using SmemCopyAtomA = SmemCopyAtomA_; using SmemCopyAtomB = SmemCopyAtomB_; - using SmemCopyAtomScale = Copy_Atom; + using SmemCopyAtomScale = Copy_Atom; // We must ensure the type to be scaled goes to RF static constexpr bool SwapAB = !IsATransformed; @@ -202,6 +202,7 @@ struct CollectiveMma< static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization using ArchTag = typename DispatchPolicy::ArchTag; @@ -273,6 +274,8 @@ struct CollectiveMma< static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; static constexpr auto elements_per_smem_scale() { @@ -304,22 +307,30 @@ struct CollectiveMma< // These methods use some the public members of the class. For that reason, we define them after the public section. static constexpr uint32_t compute_tma_transaction_bytes_mk() { - constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + return cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + static constexpr uint32_t + compute_tma_transaction_bytes_extra() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return baseline_bytes; + return 0; } else if constexpr (ModeHasScales) { constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return baseline_bytes + scale_tx_bytes; + return scale_tx_bytes; } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { // Scale and zero share smem layout constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA - return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + return scale_tx_bytes + zero_tx_bytes; } else { static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); @@ -330,11 +341,6 @@ struct CollectiveMma< } } - static constexpr uint32_t - compute_tma_transaction_bytes_nk() { - return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); - } - public: static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); @@ -349,7 +355,7 @@ struct CollectiveMma< { static constexpr int scale_elements = elements_per_smem_scale(); static constexpr int zero_elements = elements_per_smem_zero(); - struct TensorStorage : cute::aligned_struct { + struct TensorStorage : cute::aligned_struct { cute::ArrayEngine> smem_A; cute::ArrayEngine> smem_B; cute::ArrayEngine smem_scale; @@ -389,14 +395,14 @@ struct CollectiveMma< public: // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy( + using TMA_A = decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + TileShape{}, + ClusterShape{})); // mcast along N mode for this M load, if any - using TMA_Scale = decltype(make_tma_copy( + using TMA_Scale = decltype(make_tma_copy( GmemTiledCopyScale{}, make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), SmemLayoutScale{}(_,_,cute::Int<0>{}), @@ -411,12 +417,12 @@ struct CollectiveMma< _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy( + using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{})); // mcast along M mode for this N load, if any TMA_A tma_load_a; TMA_B tma_load_b; TMA_Scale tma_load_scale; @@ -424,8 +430,7 @@ struct CollectiveMma< int64_t scale_k; int group_size; uint32_t tma_transaction_bytes = TmaTransactionBytes; - uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); }; // @@ -466,31 +471,33 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA)); Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB)); - typename Params::TMA_A tma_load_a = make_tma_copy( + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + TileShape{}, + ClusterShape{}); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b = make_tma_copy( + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{}); // mcast along M mode for this N load, if any - typename Params::TMA_Scale tma_load_scale; - typename Params::TMA_Zero tma_load_zero; + typename Params::TMA_Scale tma_load_scale{}; + typename Params::TMA_Zero tma_load_zero{}; + + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1 }; } else if constexpr (ModeHasScales) { auto scale_k = (K + args.group_size - 1) / args.group_size; ElementScale const* ptr_S = args.ptr_S; StrideScale dS = args.dS; Tensor tensor_scale = make_tensor(get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS)); - tma_load_scale = make_tma_copy( + tma_load_scale = make_tma_copy( GmemTiledCopyScale{}, tensor_scale, SmemLayoutScale{}(_,_,cute::Int<0>{}), @@ -498,7 +505,7 @@ struct CollectiveMma< _1{}); // mcast along N mode for this M load, if any if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) }; } else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS)); @@ -508,7 +515,7 @@ struct CollectiveMma< SmemLayoutScale{}(_,_,cute::Int<0>{}), ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) }; } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); } @@ -571,7 +578,8 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + static constexpr uint32_t TmaTransactionBytesExtra = compute_tma_transaction_bytes_extra(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -674,122 +682,117 @@ struct CollectiveMma< int lane_predicate = cute::elect_one_sync(); - if (lane_predicate) { - Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) - Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) - // - // Prepare the TMA loads for A, B and Scales - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - // Partition the inputs based on the current block coordinates. - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - uint16_t mcast_mask_s = 0; + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; - // Issue TmaLoads - // Maps the tile -> block, value - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); - } + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } + } - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); - } + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); } + } - auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); - // - // Copy gmem to smem for *k_tile_iter - // + // + // Copy gmem to smem for *k_tile_iter + // - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // Nothing extra to do. - } - else if constexpr (ModeHasScales) { - auto tSgS = get<0>(extra_input_partitions); - auto tSsS = get<1>(extra_input_partitions); - - // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes - // on the fly. - // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K - // is a multiple of the threadblock tile K - const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); - const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. - copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - // Nothing extra to do - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - auto tZgZ = get<2>(extra_input_partitions); - auto tZsZ = get<3>(extra_input_partitions); - copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); - } + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + int const scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This will always be 0 when group_size == K. + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); - } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } - ++k_tile_iter; + ++k_tile_iter; - // Advance smem_pipe_write - ++smem_pipe_write; - } + // Advance smem_pipe_write + ++smem_pipe_write; } } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { - int lane_predicate = cute::elect_one_sync(); - // Issue the epilogue waits - if (lane_predicate) { + if (cute::elect_one_sync()) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used @@ -868,13 +871,6 @@ struct CollectiveMma< Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) - // Compute the max vector length that can be used to copy A. This will match the vector width of the - // conversions used. It helps by allowing the compiler to convert using the same register that was used - // to load the data from smem. This significantly reduces the need to move data among registers. - // Note that this is correct even if copy fails to vectorize, since the granularity at which we perform - // the conversion does not impact correctness. - using A_CPY_VEC = decltype(max_common_vector(tCsA, tCrA_copy_view)); - // Partition of thread -> shared and thread -> RF auto partitioned_extra_info = partition_extra_mma_info(mma_thread_slice, shared_tensors); auto copy_partitions_extra_info = retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); @@ -915,16 +911,21 @@ struct CollectiveMma< // copy smem->rmem for A operand copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); - - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); + } + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { - if (k_block < K_BLOCK_MAX - 1) { + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, - partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) @@ -936,11 +937,15 @@ struct CollectiveMma< --k_tile_count; if (k_tile_count > 0) { // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. - warpgroup_wait(); pipeline.consumer_wait(smem_pipe_read, barrier_token); copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + } + warpgroup_wait(); + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); } } @@ -971,9 +976,8 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); if (k_block == K_BLOCK_MAX - 1) { - // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it ++smem_pipe_release; } @@ -986,12 +990,18 @@ struct CollectiveMma< pipeline.consumer_wait(smem_pipe_read, barrier_token); copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + } + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); } else { - copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, - partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } } warpgroup_fence_operand(accum); @@ -1018,17 +1028,20 @@ struct CollectiveMma< cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); - if (k_block == K_BLOCK_MAX - 1) { - // release prior barrier + if (k_block == K_BLOCK_MAX - 1) { // release prior barrier + warpgroup_wait(); pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it ++smem_pipe_release; } + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } if (k_block < K_BLOCK_MAX - 1) { copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } } } @@ -1110,10 +1123,20 @@ struct CollectiveMma< // nothing to do return cute::make_tuple(); } + else if constexpr (UseScaleLookupTable) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos); + } + } else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); @@ -1121,7 +1144,7 @@ struct CollectiveMma< else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsZ = mma_thread_slice.partition_A(sZ); - Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); } else { @@ -1210,159 +1233,293 @@ struct CollectiveMma< } } } + + // Helper functions to select packing for conversion + template + struct select_packing { // Naive packing policy + static constexpr auto value() { + return Int, sizeof_bits_v))>{}; + } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + + CUTLASS_DEVICE + static uint32_t to_reg(Array const& source) { + return static_cast( + reinterpret_cast(source)); + } + CUTLASS_DEVICE + static uint32_t to_reg(Array const& source) { + return reinterpret_cast(source); + } + // The core converter uses a lookup table to converts i4 -> 8 bit value. + template + CUTLASS_DEVICE + static Array lookup_table_convert( + cute::Int _, + Array const& source, + TensorPos const& scale_neg, + TensorNeg const& scale_pos, + int scale_idx) { + + static_assert(N == 4 || N == 8); + uint32_t res[N / 4]; + + // View the input as reg + uint32_t reg = to_reg(source); + + // Determines if to get from the signed or unsigned candidates + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" \ + "}\n" + : "=r"(sign) + : "r"(reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut) + ); + sign = sign >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = reg & 0x77777777; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i, lut_idx >>=16, sign >>=16) { + Array const& _scale_neg = reinterpret_cast const&>(scale_neg[scale_idx + i * 4]); + Array const& _scale_pos = reinterpret_cast const&>(scale_pos[scale_idx + i * 4]); + asm volatile( + "{\n" + " .reg .b32 pos, neg ;\n" \ + " prmt .b32 neg, %3, %4, %1 ;\n" \ + " prmt .b32 pos, %5, %6, %1 ;\n" \ + " prmt .b32 %0, pos, neg, %2 ;\n" \ + "}\n" + : "=r"(res[i]) + : "r"(lut_idx), "r"(sign), "r"(_scale_neg[0]), "r"(_scale_neg[1]), "r"(_scale_pos[0]), "r"(_scale_pos[1]) + ); + } + return reinterpret_cast&>(res); + } + + template + CUTLASS_DEVICE + static void static_check_scale(Layout const& tensor) { + static_assert(shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0, "At least 4 adjacent weights in a thread must share the same scale."); + } + template + CUTLASS_DEVICE + static void static_check_scale(Tensor const& tensor) { + static_check_scale(flatten(Layout{})); + } /// Utilities to transform A. - template CUTLASS_DEVICE void transform_A_kblock( - TCrA_load const& tCrA_load, - cute::Int vec_A, - TCrA_mma& tCrA_mma, + Tensor const& tCrA_load, + Tensor& tCrA_mma, cute::tuple const& partitioned_extra_info, int const k_block) { + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + auto const& src = tCrA_load(_, _, k_block); + auto const& dst = tCrA_mma(_, _, k_block); + auto pSrc = raw_pointer_cast(src.data()); + auto pDst = const_cast(raw_pointer_cast(dst.data())); + constexpr int num_elements = decltype(size(src))::value; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - transform_internal_A(tCrA_load(_, _, k_block), vec_A, tCrA_mma(_, _, k_block)); + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); + } } + else if constexpr (UseScaleLookupTable) { + static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); + static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); + static_assert(num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting."); + constexpr int pack = num_elements % 8 == 0? 8 : 4; + constexpr int iters = num_elements / pack; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + auto const& tCrS_neg = cute::get<1>(partitioned_extra_info); + auto const& tCrS_pos = cute::get<2>(partitioned_extra_info); + auto const& scale_neg = tCrS_neg(_, _, k_block); + auto const& scale_pos = tCrS_pos(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scale_neg)); + + static_check_scale(scale_neg); + static_check_scale(scale_pos); + if (k_block == 0) { + auto pNeg = raw_pointer_cast(tCrS_neg.data()); + auto pPos = const_cast(raw_pointer_cast(tCrS_pos.data())); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cosize(tCrS_neg.layout()); ++i) + { + // pPos[i] = pNeg[i] & 0x7F7F7F7F7F7F7F00; + cutlass::Array const& _scale_neg = reinterpret_cast const&>(pNeg[i]); + cutlass::Array & _scale_pos = reinterpret_cast &>(pPos[i]); + asm volatile( + "{\n" + " and .b32 %0, %2, %4 ;\n" \ + " and .b32 %1, %3, %5 ;\n" \ + "}\n" + : "=r"(_scale_pos[0]), "=r"(_scale_pos[1]) + : "r"(_scale_neg[0]), "r"(_scale_neg[1]), "n"(0x7F7F7F00), "n"(0x7F7F7F7F) + ); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; i ++) { + SrcArray const* pSrcArr = reinterpret_cast(raw_pointer_cast(src.data())) + i; + DstArray* pDstArr = reinterpret_cast(raw_pointer_cast(dst.data())) + i; + + *pDstArr = lookup_table_convert(Int{}, *pSrcArr, scale_neg, scale_pos, i * pack); + } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - auto tCrS = cute::get<1>(partitioned_extra_info); - transform_internal_A(tCrA_load(_, _, k_block), vec_A, make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrA_mma(_, _, k_block)); + auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + + if constexpr (is_same_v) { + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + (*pDstArr)[j] = (*pDstArr)[j] * scales[i*pack + j]; + } + } + } + else { + constexpr int pack1 = decltype(select_packing::value())::value; + constexpr int pack2 = decltype(select_packing::value())::value; + constexpr int pack = cute::gcd(pack1, pack2); + using Converter1 = cutlass::NumericArrayConverter; + using Converter2 = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using StageArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + StageArray stageArr; + stageArr = Converter1::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + stageArr[j] = stageArr[j] * scales[i*pack + j]; + } + *pDstArr = Converter2::convert(stageArr); + } + } } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - auto tCrS = cute::get<1>(partitioned_extra_info); - auto tCrZ = cute::get<3>(partitioned_extra_info); - transform_internal_A(tCrA_load(_, _, k_block), - vec_A, - make_fragment_like(tCrA_mma)(_, _, k_block), - tCrS(_, _, 0), - tCrZ(_, _, 0), - make_fragment_like(tCrZ)(_, _, 0), - tCrA_mma(_, _, k_block)); + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + + if constexpr (is_same_v) { + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + (*pDstArr)[j] = (*pDstArr)[j] * scales[i*pack + j] + zeros[i*pack + j]; + } + } + } + else { + constexpr int pack1 = decltype(select_packing::value())::value; + constexpr int pack2 = decltype(select_packing::value())::value; + constexpr int pack = cute::gcd(pack1, pack2); + using Converter1 = cutlass::NumericArrayConverter; + using Converter2 = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using StageArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + StageArray stageArr; + stageArr = Converter1::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + stageArr[j] = stageArr[j] * scales[i*pack + j] + zeros[i*pack + j]; + } + *pDstArr = Converter2::convert(stageArr); + } + } + return; } else { static_assert(cutlass::detail::dependent_false, "No A data is loaded."); } } - - /// Utilities for transforming the A operand prior to issuing tensorcore math. - template > - CUTLASS_DEVICE void - convert_tensor( - Tensor const& in, - Tensor& out, - cute::Int width = {}) { - - /// This is an element-wise conversion where we expect both tensors to have the same layout. - /// As a result, we can cast as a cutlass array to use the fast numeric converters without - /// worrying about indexing into the layout. - constexpr int N = cosize_v; - - /// The inputs must be backed by registers & be statically sized. - static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); - static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); - static_assert(is_static_v, "Tensor layout for the conversion must be static"); - static_assert(cosize_v == size(TensorLayout{}), "Cosize and size of the layout must be equal."); - static_assert(N % ConversionVectorWidth == 0, "Conversion vector width must divide cosize of the tensor layout."); - - using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; - - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - using Converter = cutlass::NumericArrayConverter; - - constexpr int NumIterations = N / ConversionVectorWidth; - - for (int ii = 0; ii < NumIterations; ++ii) { - SrcArray const* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())) + ii; - DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())) + ii; - *dst_array_ptr = Converter::convert(*src_array_ptr); - } - } - - template - CUTLASS_DEVICE void - transform_internal_A( - Tensor&& in, - cute::Int a_vec_width, - Tensor&& out) { - - convert_tensor(in, out, a_vec_width); - } - - template - CUTLASS_DEVICE void - transform_internal_A( - Tensor&& in, - cute::Int a_vec_width, - Tensor&& converted_inputs, - Tensor&& scales, - Tensor&& out) { - - static_assert(cute::is_same_v, - "Type of the engine input buffer must equal the scale buffer"); - - // First, we upcast the inputs to the scale type - convert_tensor(in, converted_inputs, a_vec_width); - - // Apply scales and broadcast across inputs, store in converted_inputs - cute::transform(converted_inputs, scales, converted_inputs, cute::multiplies{}); - - // Finally, we convert the scaled inputs to the mma type. - convert_tensor(converted_inputs, out); - } - - template - CUTLASS_DEVICE void - transform_internal_A( - Tensor&& in, - cute::Int a_vec_width, - Tensor&& converted_inputs, - Tensor&& scales, - Tensor&& zeros, - Tensor&& converted_zeros, - Tensor&& out) { - - static_assert(cute::is_same_v, - "Type of the engine input buffer must equal the scale buffer"); - - static_assert(cute::is_same_v, - "Type of the engine zero buffer must equal the scale buffer"); - - // First, we upcast the inputs to the scale type - convert_tensor(in, converted_inputs, a_vec_width); - convert_tensor(zeros, converted_zeros); - - // Apply scales and broadcast across inputs, store in converted_inputs - cute::transform(converted_inputs, scales, converted_inputs, cute::multiplies{}); - cute::transform(converted_inputs, converted_zeros, converted_inputs, cute::plus{}); - - // Finally, we convert the scaled inputs to the mma type. - convert_tensor(converted_inputs, out); - } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 24af314d5f..b370dc70b5 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -150,7 +150,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index 6c02979996..da5274469f 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -144,7 +144,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000..01e83bdf54 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,724 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class LayoutPairAE_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedSparse, + TileShape_, + ElementA_, + LayoutPairAE_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparse; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutPairAE = LayoutPairAE_; + using LayoutA = remove_cvref_t(LayoutPairAE{}))>; + using LayoutE = remove_cvref_t(LayoutPairAE{}))>; + using StrideA = decltype(cute::stride(LayoutA{})); + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + static_assert(is_sparse::value, "ElementAMma is sparse"); + static_assert(!is_sparse::value, "ElementA is not sparse"); + + static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; + static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; + + // LayoutA is nested in the stride due to the sparsity. + static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{}.stride())), Int>; + static constexpr bool is_B_mn_major = cutlass::gemm::detail::is_major<0,StrideB>(); + + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape{}),_128{}))>; + + // The offline permutation for the metadata. + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + // Metadata pathways + using SmemCopyAtomE = AutoVectorizingCopy; + using GmemCopyAtomE = GmemTiledCopyA; + + using CtaShape_MNK = TileShape; + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + using TmaInternalElementA = cute::sparse_elem, + cutlass::tfloat32_t, + uint_bit_t>>>; + using TmaInternalElementB = cute::conditional_t, + tfloat32_t, + uint_bit_t>>; + + struct SharedStorage + { + struct TensorStorage { + alignas(128) cute::ArrayEngine> smem_A; + alignas(128) cute::ArrayEngine> smem_B; + alignas(128) cute::ArrayEngine> smem_E; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 0; + + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutE{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{}; + LayoutA layout_a{}; + ElementB const* ptr_B{}; + StrideB dB{}; + ElementE const* ptr_E{}; + LayoutE layout_e{}; + }; + + // Device side kernel params + struct Params { + + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_E = decltype(make_tma_copy( // use uint64_t to get the largest loading box. + GmemCopyAtomE{}, + make_tensor(recast_ptr>(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + TMA_A tma_load_a; + TMA_E tma_load_e; + TMA_B tma_load_b; + LayoutA layout_a; + LayoutE layout_e; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + auto ptr_E = recast_ptr>(args.ptr_E); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + + typename Params::TMA_E tma_load_e = make_tma_copy( // use uint64_t to get the largest loading box. + GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + return { + tma_load_a, + tma_load_e, + tma_load_b, + args.layout_a, + args.layout_e + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool size_check = true; + // Check Alignment A + if constexpr (is_A_mn_major) { + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(_1{}, M, M*K/2)); + } + else { // If A is K-major + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(K/2, _1{}, M*K/2)); + } + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!size_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + // Check if layout_a and layout_e is filled correctly + auto layout_a_ref = SparseConfig::fill_layoutA(problem_shape_MNKL); + auto layout_e_ref = SparseConfig::fill_layoutE(problem_shape_MNKL); + bool layout_check = true; + layout_check = layout_check && (layout_a_ref == args.layout_a); + layout_check = layout_check && (layout_e_ref == args.layout_e); + + if (!layout_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Layout_a/e mismatch.\n"); + } + + return size_check && layout_check; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_e.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(mainloop_params.layout_a.shape()); // (m,k,l) + Tensor mE_mkl = mainloop_params.tma_load_e.get_tma_tensor(mainloop_params.layout_e.shape()); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gE_mkl = local_tile(mE_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, gE_mkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, class TensorE, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + auto [gA_mkl, gB_nkl, gE_mkl] = load_inputs; + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_e = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_e = mainloop_params.tma_load_e.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(get<0>(cta_coord_mnk)); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gE = gE_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tEgE = block_tma_e.partition_S(gE); // (TMA,TMA_M,TMA_K,k) + Tensor tEsE = block_tma_e.partition_D(sE); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_e.with(*tma_barrier, mcast_mask_e), tEgE(_,_,_,*k_tile_iter), tEsE(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutE{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + Tensor sE_ = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = as_position_independent_swizzle_tensor(sE_); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + auto copy_atom_E = Copy_Atom{}; + + Tensor tCsE = partition_E(thread_mma, sE(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrE = make_fragment_like(tCsE); // (MMA,MMA_M,MMA_K) + + auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(thread_idx); + + Tensor tEsE = smem_thr_copy_E.partition_S(sE); // (ECPY,ECPY_M,ECPY_K) + Tensor tErE = smem_thr_copy_E.retile_D(tCrE); // (ECPY,ECPY_M,ECPY_K) + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + +private: + + template + CUTE_HOST_DEVICE static constexpr + auto + thrfrg_E(TiledMMA const& mma, ETensor&& etensor) + { + using TiledMma = TiledMMA; + + CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto e_tile = make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), + make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); + auto e_tensor = zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; + auto tv_tensor = e_tensor.compose(AtomLayoutE_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), + make_layout(size<3>(mma.thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + get_layoutE_TV(TiledMMA const& mma) + { + // (M,K) -> (M,K) + auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); + // (ethrid,val) -> (M,K) + auto layoutE_TV = thrfrg_E(mma, ref_E); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(mma.thr_layout_vmnk_), size<2>(mma.thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + partition_E(ThrMMA const& thr_mma, ETensor&& etensor) + { + auto thr_tensor = make_tensor(static_cast(etensor).data(), thrfrg_E(thr_mma, etensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_mma.thr_vmnk_), make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_tiled_copy_E(Copy_Atom const& copy_atom, + TiledMMA const& mma) + { + return make_tiled_copy_impl(copy_atom, get_layoutE_TV(mma), make_shape(tile_size<0>(mma),tile_size<2>(mma))); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/base_grouped.h b/include/cutlass/gemm/device/base_grouped.h index 51b9d3dc10..eec61981f8 100644 --- a/include/cutlass/gemm/device/base_grouped.h +++ b/include/cutlass/gemm/device/base_grouped.h @@ -432,6 +432,7 @@ class BaseGrouped { // // Launch + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); // diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h index c9e7cc76d1..e7ed2da940 100644 --- a/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/include/cutlass/gemm/device/default_gemm_configuration.h @@ -764,50 +764,19 @@ struct DefaultGemmConfiguration< //////////////////////////////////////////////////////////////////////////////// -/// Base configuration for all {fe4m3, fe5m2} x {fe4m3, fe5m2} combinations on SM89 template < - typename ElementA, - typename ElementB, - typename ElementC, - typename ElementAccumulator> -struct DefaultGemmConfigurationSm89F8 { - static_assert((platform::is_same::value || - platform::is_same::value), - "ElementA must be of type float_e4m3_t or float_e5m2_t"); - static_assert((platform::is_same::value || - platform::is_same::value), - "ElementB must be of type float_e4m3_t or float_e5m2_t"); - - static int const kAlignmentA = 128 / sizeof_bits::value; - static int const kAlignmentB = 128 / sizeof_bits::value; - - using ThreadblockShape = GemmShape<128, 256, 64>; - using WarpShape = GemmShape<64, 64, 64>; - using InstructionShape = GemmShape<16, 8, 32>; - static int const kStages = 3; - - using EpilogueOutputOp = epilogue::thread::LinearCombination< - ElementC, 128 / sizeof_bits::value, ElementAccumulator, - ElementAccumulator>; - - using Operator = arch::OpMultiplyAdd; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < typename ElementC> struct DefaultGemmConfiguration< - arch::OpClassTensorOp, - arch::Sm80, - int4b_t, - int8_t, - ElementC, + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + int8_t, + ElementC, int32_t> { - + static int const kAlignmentA = 128 / sizeof_bits::value; static int const kAlignmentB = 128 / sizeof_bits::value; - + using ThreadblockShape = GemmShape<128, 256, 64>; using WarpShape = GemmShape<64, 64, 64>; using InstructionShape = GemmShape<16, 8, 32>; @@ -821,19 +790,19 @@ struct DefaultGemmConfiguration< //////////////////////////////////////////////////////////////////////////////// -template < +template < typename ElementC> struct DefaultGemmConfiguration< - arch::OpClassTensorOp, - arch::Sm80, - int8_t, - int4b_t, - ElementC, + arch::OpClassTensorOp, + arch::Sm80, + int8_t, + int4b_t, + ElementC, int32_t> { - + static int const kAlignmentA = 128 / sizeof_bits::value; static int const kAlignmentB = 128 / sizeof_bits::value; - + using ThreadblockShape = GemmShape<128, 256, 64>; using WarpShape = GemmShape<64, 64, 64>; using InstructionShape = GemmShape<16, 8, 32>; @@ -847,6 +816,35 @@ struct DefaultGemmConfiguration< //////////////////////////////////////////////////////////////////////////////// +/// Base configuration for all {fe4m3, fe5m2} x {fe4m3, fe5m2} combinations on SM89 +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator> +struct DefaultGemmConfigurationSm89F8 { + static_assert((platform::is_same::value || + platform::is_same::value), + "ElementA must be of type float_e4m3_t or float_e5m2_t"); + static_assert((platform::is_same::value || + platform::is_same::value), + "ElementB must be of type float_e4m3_t or float_e5m2_t"); + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 128 / sizeof_bits::value, ElementAccumulator, + ElementAccumulator>; + + using Operator = arch::OpMultiplyAdd; +}; + /// Partial specialization for SM89 fe4m3 x fe4m3 template struct DefaultGemmConfiguration< diff --git a/include/cutlass/gemm/device/ell_gemm.h b/include/cutlass/gemm/device/ell_gemm.h index f5b65cea29..54ddab4007 100644 --- a/include/cutlass/gemm/device/ell_gemm.h +++ b/include/cutlass/gemm/device/ell_gemm.h @@ -517,6 +517,7 @@ class EllGemm { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index f0226354de..c6f488b146 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -491,6 +491,7 @@ class Gemm { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index 6bbd90c1cd..1ae2db467f 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -446,6 +446,7 @@ class GemmArray { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index 3be34c808d..5981457c73 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -424,6 +424,7 @@ class GemmBatched { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 36f57d6469..e36c69cefb 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -445,6 +445,7 @@ class GemmComplex { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h index 1b1d27bda5..ac453c63b5 100644 --- a/include/cutlass/gemm/device/gemm_sparse.h +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -479,6 +479,7 @@ class SparseGemm { int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_sparse_with_absmax.h b/include/cutlass/gemm/device/gemm_sparse_with_absmax.h index e6db107604..e599217a13 100644 --- a/include/cutlass/gemm/device/gemm_sparse_with_absmax.h +++ b/include/cutlass/gemm/device/gemm_sparse_with_absmax.h @@ -324,6 +324,7 @@ class SparseGemmWithAbsmax { int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_splitk_parallel.h b/include/cutlass/gemm/device/gemm_splitk_parallel.h index 2c9408df0e..f78c5a2169 100644 --- a/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -357,6 +357,7 @@ class GemmSplitKParallel { } } + cutlass::arch::synclog_setup(); Kernel<<>>(gemm_params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 40094dcb10..73564d3c65 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -44,6 +44,7 @@ #include "cutlass/detail/mma.hpp" #include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/kernel_launch.h" #if !defined(__CUDACC_RTC__) #include "cutlass/cluster_launch.hpp" #include "cutlass/trace.h" @@ -211,9 +212,10 @@ class GemmUniversalAdapter< workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); } + workspace_bytes += GemmKernel::get_workspace_size(args); + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - workspace_bytes += GemmKernel::get_workspace_size(args); return workspace_bytes; } @@ -350,8 +352,12 @@ class GemmUniversalAdapter< Status launch_result{ Status::kSuccess }; // Use extended launch API only for mainloops that use it if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { - constexpr bool is_static_1x1x1 = cute::is_static_v and - cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); +#endif + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); @@ -363,12 +369,14 @@ class GemmUniversalAdapter< // CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { - if (launch_with_pdl) { CUTLASS_TRACE_HOST( "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); return Status::kErrorInternal; } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif launch_result = cuda_adapter->launch(grid, cluster, block, @@ -378,6 +386,7 @@ class GemmUniversalAdapter< 0); } else { + CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); return Status::kErrorInternal; } } @@ -385,10 +394,25 @@ class GemmUniversalAdapter< CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { - if (is_static_1x1x1 && not launch_with_pdl) { - device_kernel<<>>(params); + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif } else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching dynamic cluster kernel"); +#endif launch_result = ClusterLauncher::launch( grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); } @@ -397,28 +421,48 @@ class GemmUniversalAdapter< } else { launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + if constexpr (kEnableCudaHostAdapter) { CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { void* kernel_params[] = {¶ms}; - +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif launch_result = cuda_adapter->launch( grid, block, smem_size, stream, kernel_params, 0 ); } else { + CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); return Status::kErrorInternal; } } else { CUTLASS_ASSERT(cuda_adapter == nullptr); - device_kernel<<>>(params); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif } } cudaError_t result = cudaGetLastError(); if (cudaSuccess == result && Status::kSuccess == launch_result) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: cudaGetLastError reports success"); +#endif return Status::kSuccess; } else { diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 63da07b418..e23191eae5 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -443,6 +443,8 @@ class GemmUniversalBase { "block: (" << block << "), " "SMEM: (" << kSharedStorageSize << ")"); + cutlass::arch::synclog_setup(); + if constexpr (kEnableCudaHostAdapter) { CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { diff --git a/include/cutlass/gemm/device/gemv.h b/include/cutlass/gemm/device/gemv.h index 341124942a..5e181743ef 100644 --- a/include/cutlass/gemm/device/gemv.h +++ b/include/cutlass/gemm/device/gemv.h @@ -141,6 +141,7 @@ class Gemv { int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); // Launch + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); // diff --git a/include/cutlass/gemm/device/rank_2k.h b/include/cutlass/gemm/device/rank_2k.h index d12621e6b9..296f38cad2 100644 --- a/include/cutlass/gemm/device/rank_2k.h +++ b/include/cutlass/gemm/device/rank_2k.h @@ -319,6 +319,7 @@ class Rank2K { int smem_size = int(sizeof(typename Rank2Kkernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/rank_k.h b/include/cutlass/gemm/device/rank_k.h index e6e9d025a4..ae18a11b80 100644 --- a/include/cutlass/gemm/device/rank_k.h +++ b/include/cutlass/gemm/device/rank_k.h @@ -296,6 +296,7 @@ class RankK { int smem_size = int(sizeof(typename RankKkernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/symm.h b/include/cutlass/gemm/device/symm.h index 223e1b0d10..c36ef959b1 100755 --- a/include/cutlass/gemm/device/symm.h +++ b/include/cutlass/gemm/device/symm.h @@ -337,6 +337,7 @@ class Symm { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/trmm.h b/include/cutlass/gemm/device/trmm.h index e354e7a132..09b9152cbb 100644 --- a/include/cutlass/gemm/device/trmm.h +++ b/include/cutlass/gemm/device/trmm.h @@ -495,6 +495,7 @@ class Trmm { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index c1c2308b9d..904e6af3cc 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -34,7 +34,7 @@ #include "cutlass/gemm/gemm.h" #include "cute/layout.hpp" -#include "cute/numeric/integral_constant.hpp" +#include "cute/numeric/integral_constant.hpp" // cute::false_type ////////////////////////////////////////////////////////////////////////////// namespace cutlass::detail { @@ -48,6 +48,16 @@ struct is_kernel_tag_of, U> : cute::true_type {}; template class U> constexpr bool is_kernel_tag_of_v = is_kernel_tag_of::value; +template class U> +struct is_asymmetric_dma_kernel_tag_of : cute::false_type {}; + +template