From 1825ef883bbfdb8f2546a02a25609a37ac2ecdc5 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 6 Dec 2024 14:36:05 +0000 Subject: [PATCH 1/9] Cutlass grouped gemm files Signed-off-by: ElizaWszola --- CMakeLists.txt | 9 +- csrc/cpu/torch_bindings.cpp | 7 + csrc/ops.h | 8 + .../cutlass_w8a8/grouped_gemm_test.cu | 397 ++++++++++++++++++ .../cutlass_w8a8/scaled_mm_entry.cu | 20 + csrc/torch_bindings.cpp | 8 + tests/kernels/test_cutlass.py | 68 +++ 7 files changed, 514 insertions(+), 3 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 5acbd762ee957..9d6185e756338 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -209,13 +209,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + # GIT_TAG v3.5.1 + GIT_TAG dbdae514e03f83968f8b7dd4fb064071b9bfbdd1 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(cutlass) @@ -261,7 +262,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 03beefbc6de7d..d6c32322ff592 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -118,6 +118,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); +// CUTLASS w8a8 grouped GEMM // TODO complete this + ops.def( + "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " + " Tensor b_scales, Tensor problem_sizes, " + " Tensor out_offsets, Tensor a_offsets, " + " Tensor b_offsets) -> ()"); + ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. 
ops.def( diff --git a/csrc/ops.h b/csrc/ops.h index 672e608e9c47e..fce4346fa4218 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,6 +145,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); +void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets); + void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu new file mode 100644 index 0000000000000..8e46b9a33cea3 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -0,0 +1,397 @@ +#include + +#include +#include + +#include "cutlass/cutlass.h" + +// TODO let's see which of these we'll need + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +#include "common.hpp" + +// get rid of these? +// #include "helper.h" +// using namespace cute; + +using namespace cute; + +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 +#define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +namespace { + + // A wrapper for the GEMM kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef +// into code that will be executed on the device where it is defined. +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... 
args) { + #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); + #endif + } +}; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementAB_Type = cutlass::float_e4m3_t; // Element type for A matrix operand +// using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using ElementC_Type = cutlass::half_t; + +// // A matrix configuration +// using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +// constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// // B matrix configuration +// using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +// constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) + +// // C/D matrix configuration +// using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +// constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +// using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Different configs for pingpong/cooperative +// struct CooperativeConfig { +// using KernelSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; +// using EpilogueSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperative; +// using TileShape = cute::Shape; +// using ClusterShape = cute::Shape; +// }; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::ColumnMajor; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_group_gemm { + + using ElementAB = ElementAB_; + using ElementC = ElementC_; + using ElementAccumulator = float; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementC, + ElementC, EpilogueSchedule>; + + using Epilogue = Epilogue_; + + using StrideC = cute::remove_pointer_t, cute::Int<0>>>; + + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using EVTCompute = typename Epilogue::EVTCompute; + // the orig hat cutlass::epilogue::fusion::LinearCombination + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, 4, + ElementC, LayoutC*, 4, + EpilogueSchedule, EVTCompute + >::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementAB, LayoutA*, 16, + ElementAB, LayoutB*, 16, + ElementAccumulator, + 
TileShape, ClusterShape, + Stages, KernelSchedule + >::CollectiveOp; + + using KernelType = enable_sm90_or_later>; + + struct GemmKernel : public KernelType {}; +}; + +template +struct ItemDeleter { + void operator()(T* ptr) { + cudaFree(ptr); // noexcept + } +}; + +template +void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + // using ElementC = typename Gemm::ElementC; + using ElementC = typename Gemm::ElementC; + using ElementAcc = float; + + int groups = problem_sizes.size(0); + std::vector a_ptrs_host(groups); + std::vector b_ptrs_host(groups); + std::vector c_ptrs_host(groups); + std::vector d_ptrs_host(groups); + + for (int g = 0; g < groups; ++g) { + a_ptrs_host.at(g) = (ElementAB*)a.data_ptr();// + a_offsets[g].item(); + b_ptrs_host.at(g) = (ElementAB*)b.data_ptr();// + b_offsets[g].item(); + c_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); + d_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); + } + + // int32_t groups = a.size(0); + // int32_t m = a.size(1); + // int32_t n = b.size(2); + // int32_t k = a.size(2); + + // int64_t lda = a.stride(1); + // int64_t ldb = b.stride(2); + // int64_t ldc = out.stride(1); + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + // StrideA stride_A{lda, cute::Int<1>{}, 0}; + // StrideB stride_B{ldb, cute::Int<1>{}, 0}; + // StrideC stride_C{ldc, cute::Int<1>{}, cute::Int<0>{}}; + + // this should be vector of A ptrs + // auto ptr_A = static_cast(a.data_ptr()); + // auto ptr_B = static_cast(b.data_ptr()); + // auto ptr_C = static_cast(out.data_ptr()); + + cutlass::platform::unique_ptr stride_A; + cutlass::platform::unique_ptr stride_B; + cutlass::platform::unique_ptr stride_C; + cutlass::platform::unique_ptr stride_D; + + cutlass::platform::unique_ptr ptr_A; + cutlass::platform::unique_ptr ptr_B; + cutlass::platform::unique_ptr ptr_C; + cutlass::platform::unique_ptr ptr_D; + + using GemmKernel = typename Gemm::GemmKernel; + + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
+ hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; + + std::vector problem_sizes_host; + problem_sizes_host.reserve(groups); + for (int32_t g = 0; g < groups; ++g) { + int32_t m = problem_sizes[g][0].item(); + int32_t n = problem_sizes[g][1].item(); + int32_t k = problem_sizes[g][2].item(); + problem_sizes_host.push_back({m, n, k}); + } + + SingleProblemShape* problem_sizes_device; + int32_t problem_sizes_size = groups * sizeof(SingleProblemShape); + cudaMalloc(&problem_sizes_device, problem_sizes_size); + cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), groups, + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> problem_sizes_ptr( + problem_sizes_device); + ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; + + const ElementAB** a_ptrs_device; + cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); + cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> a_ptrs_ptr( + a_ptrs_device + ); + + const ElementAB** b_ptrs_device; + cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); + cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> b_ptrs_ptr( + b_ptrs_device + ); + + const ElementC** c_ptrs_device; + cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); + cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> c_ptrs_ptr( + c_ptrs_device + ); + + ElementC** d_ptrs_device; + cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); + cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> d_ptrs_ptr( + d_ptrs_device + ); + + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptrs_ptr.get(), stride_A.get(), b_ptrs_ptr.get(), stride_B.get()}; + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptrs_ptr.get(), stride_C.get(), d_ptrs_ptr.get(), stride_D.get()}; + + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + prob_shape, + mainloop_args, + epilogue_args, + hw_info + }; + + // Launch the CUTLASS GEMM kernel. 
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + // // auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr())); + + // #if defined(ENABLE_SM90_KERNEL_LEVEL) + // printf("did run through\n"); + cutlass::Status status = gemm_op.run(); + CUTLASS_CHECK(status); + // #endif + +} + +// typedef InType = cutlass::float_e4m3_t; +// typedef OutType = torch::half; +// typedef Epilogue = ScaledEpilogueBias; + +template typename Epilogue> +struct sm90_fp8_config_default { + // M in (128, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +} + +// TODO hardcode types here? 
+void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets) { + + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + // int32_t m = a.size(1); + + using Cutlass3xGemmDefault = + typename sm90_fp8_config_default::Cutlass3xGemm; + // using Cutlass3xGemmM64 = + // typename sm90_fp8_config_M64::Cutlass3xGemm; + // using Cutlass3xGemmM128 = + // typename sm90_fp8_config_M128::Cutlass3xGemm; + + + // // uint32_t const m = a.size(0); + // uint32_t const mp2 = + // std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + // if (mp2 <= 64) { + // // m in [1, 64] + // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // } else if (mp2 <= 128) { + // // m in (64, 128] + // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // } else { + // // m in (128, inf) + cutlass_group_gemm_caller(out, a, b, problem_sizes, + out_offsets, a_offsets, b_offsets, a_scales, b_scales); + // } + +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 97a969cf5e3e0..78225f9b0db0a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -27,6 +27,15 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); + +void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets); + #endif void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -151,6 +160,17 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } +void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets) { + cutlass_grouped_mm_sm90(out, a, b, a_scales, b_scales, problem_sizes, + out_offsets, a_offsets, b_offsets); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e4cc7ec951848..a10c661b22a6a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -311,6 +311,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); + // CUTLASS w8a8 grouped GEMM // TODO complete this + ops.def( + "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " + " Tensor b_scales, Tensor problem_sizes, " + " Tensor out_offsets, Tensor a_offsets, " + " Tensor b_offsets) -> ()"); + ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! 
delta," diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index afe53797322f9..6228c908545d0 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -6,6 +6,7 @@ import pytest import torch +import random from tests.kernels.utils import opcheck from vllm import _custom_ops as ops @@ -453,3 +454,70 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + +# TODO fix scales +@pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) +@pytest.mark.parametrize("num_groups", [10]) +@pytest.mark.parametrize("per_act_token", [False])# [True, False]) +@pytest.mark.parametrize("per_out_ch", [True])# [True, False]) +@pytest.mark.parametrize("use_bias", [False])# [True, False]) +@pytest.mark.skipif(not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, + per_act_token: bool, + per_out_ch: bool, use_bias: bool): + + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + device = "cuda" + out_dtype = torch.half + + alignment = 16 # 128 // 8 + problem_sizes = torch.empty((num_groups, 3), device="cpu") + offsets_a = torch.empty((num_groups), device="cpu") + offsets_b = torch.empty((num_groups), device="cpu") + offsets_c = torch.empty((num_groups), device="cpu") + tot_a = 0 + tot_b = 0 + tot_c = 0 + for g in range(num_groups): + m = alignment * random.randint(1, 64) + n = alignment * random.randint(1, 64) + k = alignment * random.randint(1, 64) + tot_a += m * k + tot_b += k * n + tot_c += m * n + offsets_a[g] = m * k + offsets_b[g] = k * n + offsets_c[g] = m * n + problem_sizes[g][0] = m + problem_sizes[g][1] = n + problem_sizes[g][2] = k + + a = to_fp8(torch.randn((tot_a), device=device)) + b = to_fp8(torch.randn((tot_b), device=device).t()) + c = torch.zeros((tot_c), device=device).to(out_dtype) + + m_a_scales = m if per_act_token else 1 + n_b_scales = n if per_out_ch else 1 + + scale_a = (torch.randn((m_a_scales, 1), device=device, + dtype=torch.float32)) + scale_b = (torch.randn((1, n_b_scales), device=device, + dtype=torch.float32)) + if use_bias: + bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 + else: + bias = None + + # TODO strides we can get later the same way as in scaled_mm_c3x.cu + torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, + offsets_c, offsets_a, offsets_b) + # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + print(c) + + # torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) + + # opcheck(torch.ops._C.cutlass_scaled_mm, + # (out, a, b, scale_a, scale_b, bias)) From 5fd48e5b4270cc43428f149eb731ec117b2afec8 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 9 Dec 2024 12:20:50 +0000 Subject: [PATCH 2/9] runs, bad result Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_gemm_test.cu | 107 ++++++++---------- tests/kernels/test_cutlass.py | 34 +++--- 2 files changed, 68 insertions(+), 73 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index 8e46b9a33cea3..004599c2b5d26 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -58,35 +58,14 @@ using ElementAB_Type = cutlass::float_e4m3_t; // using ElementB = 
cutlass::float_e4m3_t; // Element type for B matrix operand using ElementC_Type = cutlass::half_t; -// // A matrix configuration -// using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand -// constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) - -// // B matrix configuration -// using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand -// constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) - -// // C/D matrix configuration -// using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands -// constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) - // Core kernel configurations using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -// using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size - -// Different configs for pingpong/cooperative -// struct CooperativeConfig { -// using KernelSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; -// using EpilogueSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperative; -// using TileShape = cute::Shape; -// using ClusterShape = cute::Shape; -// }; -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::ColumnMajor; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::ColumnMajor; template typename Epilogue_, @@ -107,8 +86,8 @@ struct cutlass_3x_group_gemm { using StrideC = cute::remove_pointer_t, cute::Int<0>>>; - const int AlignmentAB = 128 / cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; // the orig hat cutlass::epilogue::fusion::LinearCombination @@ -172,34 +151,25 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, std::vector d_ptrs_host(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = (ElementAB*)a.data_ptr();// + a_offsets[g].item(); - b_ptrs_host.at(g) = (ElementAB*)b.data_ptr();// + b_offsets[g].item(); - c_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); - d_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); + a_ptrs_host.at(g) = (ElementAB*)a.data_ptr() + a_offsets[g].item(); + b_ptrs_host.at(g) = (ElementAB*)b.data_ptr() + b_offsets[g].item(); + c_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); + d_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); } - // int32_t groups = a.size(0); - // int32_t m = a.size(1); - // int32_t n = b.size(2); - // int32_t k = a.size(2); - - // int64_t lda = a.stride(1); - // int64_t ldb = b.stride(2); - // int64_t ldc = out.stride(1); - using StrideA = typename Gemm::GemmKernel::InternalStrideA; using StrideB = typename Gemm::GemmKernel::InternalStrideB; using StrideC = typename Gemm::GemmKernel::InternalStrideC; using StrideD = typename 
Gemm::GemmKernel::InternalStrideD; - // StrideA stride_A{lda, cute::Int<1>{}, 0}; - // StrideB stride_B{ldb, cute::Int<1>{}, 0}; - // StrideC stride_C{ldc, cute::Int<1>{}, cute::Int<0>{}}; + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); - // this should be vector of A ptrs - // auto ptr_A = static_cast(a.data_ptr()); - // auto ptr_B = static_cast(b.data_ptr()); - // auto ptr_C = static_cast(out.data_ptr()); + std::vector a_stride_host(groups, StrideA{lda, cute::Int<1>{}, cute::Int<0>{}}); + std::vector b_stride_host(groups, StrideB{ldb, cute::Int<1>{}, cute::Int<0>{}}); + // TODO fix + std::vector c_stride_host(groups, StrideC{cute::Int<1>{}, ldc, cute::Int<0>{}}); cutlass::platform::unique_ptr stride_A; cutlass::platform::unique_ptr stride_B; @@ -212,7 +182,7 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, cutlass::platform::unique_ptr ptr_D; using GemmKernel = typename Gemm::GemmKernel; - + cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with multiple GPUs and wish // to use a GPU other than that with device ID 0. @@ -241,38 +211,60 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, const ElementAB** a_ptrs_device; cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> a_ptrs_ptr( a_ptrs_device ); const ElementAB** b_ptrs_device; cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> b_ptrs_ptr( b_ptrs_device ); const ElementC** c_ptrs_device; cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> c_ptrs_ptr( c_ptrs_device ); + // TODO if we start with empty values here, no need to copy ElementC** d_ptrs_device; cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> d_ptrs_ptr( d_ptrs_device ); + StrideA* a_stride_device; + cudaMalloc(&a_stride_device, groups * sizeof(StrideA*)); + cudaMemcpy(a_stride_device, a_stride_host.data(), groups, cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> a_stride_ptr( + a_stride_device + ); + + StrideB* b_stride_device; + cudaMalloc(&b_stride_device, groups * sizeof(StrideB*)); + cudaMemcpy(b_stride_device, b_stride_host.data(), groups, cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> b_stride_ptr( + b_stride_device + ); + + StrideC* c_stride_device; + cudaMalloc(&c_stride_device, groups * sizeof(StrideC*)); + cudaMemcpy(c_stride_device, c_stride_host.data(), groups, cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> c_stride_ptr( + c_stride_device + ); + typename GemmKernel::MainloopArguments mainloop_args{ - a_ptrs_ptr.get(), stride_A.get(), b_ptrs_ptr.get(), stride_B.get()}; + a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ 
Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptrs_ptr.get(), stride_C.get(), d_ptrs_ptr.get(), stride_D.get()}; + c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, @@ -296,11 +288,8 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr())); - // #if defined(ENABLE_SM90_KERNEL_LEVEL) - // printf("did run through\n"); - cutlass::Status status = gemm_op.run(); - CUTLASS_CHECK(status); - // #endif + cutlass::Status status = gemm_op.run(); + CUTLASS_CHECK(status); } @@ -367,7 +356,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - // int32_t m = a.size(1); + // int32_t m = a.size(1); using Cutlass3xGemmDefault = typename sm90_fp8_config_default Date: Tue, 10 Dec 2024 15:24:17 +0000 Subject: [PATCH 3/9] A little closer to working Signed-off-by: ElizaWszola --- csrc/cpu/torch_bindings.cpp | 2 +- .../cutlass_w8a8/grouped_gemm_test.cu | 305 +++++++++--------- .../cutlass_w8a8/scaled_mm_entry.cu | 12 +- tests/kernels/test_cutlass.py | 87 +++-- 4 files changed, 224 insertions(+), 182 deletions(-) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index d6c32322ff592..80a326cdc5ef4 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -118,7 +118,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); -// CUTLASS w8a8 grouped GEMM // TODO complete this + // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " " Tensor b_scales, Tensor problem_sizes, " diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index 004599c2b5d26..db86bd1a4b466 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -33,12 +33,12 @@ using namespace cute; #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 -#define ENABLE_SM90_KERNEL_LEVEL 1 + #define ENABLE_SM90_KERNEL_LEVEL 1 #endif namespace { - // A wrapper for the GEMM kernel that is used to guard against compilation on +// A wrapper for the GEMM kernel that is used to guard against compilation on // architectures that will never use the kernel. The purpose of this is to // reduce the size of the compiled binary. // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef @@ -47,32 +47,36 @@ template struct enable_sm90_or_later : Kernel { template CUTLASS_DEVICE void operator()(Args&&... 
args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); - #endif +#endif } }; -using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group -using ElementAB_Type = cutlass::float_e4m3_t; // Element type for A matrix operand -// using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using ProblemShape = + cutlass::gemm::GroupProblemShape>; // + // per group +using ElementAB_Type = + cutlass::float_e4m3_t; // Element type for A matrix operand +// using ElementB = cutlass::float_e4m3_t; // +// Element type for B matrix operand using ElementC_Type = cutlass::half_t; // Core kernel configurations -using ElementAccumulator = float; // Element type for internal accumulation -using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature -using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::ColumnMajor; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_group_gemm { - using ElementAB = ElementAB_; using ElementC = ElementC_; using ElementAccumulator = float; @@ -84,42 +88,36 @@ struct cutlass_3x_group_gemm { using Epilogue = Epilogue_; - using StrideC = cute::remove_pointer_t, cute::Int<0>>>; + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; - const int AlignmentAB = 128 / cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; - // the orig hat cutlass::epilogue::fusion::LinearCombination - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementAccumulator, - ElementC, LayoutC*, 4, - ElementC, LayoutC*, 4, - EpilogueSchedule, EVTCompute - >::CollectiveOp; + // the orig hat cutlass::epilogue::fusion::LinearCombination + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, 4, ElementC, LayoutC*, 4, + EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< static_cast(CEStorageSize)>; -using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, - ElementAB, LayoutA*, 16, - ElementAB, LayoutB*, 16, - ElementAccumulator, - TileShape, ClusterShape, - Stages, KernelSchedule - >::CollectiveOp; + using CollectiveMainloop = + typename 
cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, 16, ElementAB, LayoutB*, + 16, ElementAccumulator, TileShape, ClusterShape, Stages, + KernelSchedule>::CollectiveOp; using KernelType = enable_sm90_or_later>; + ProblemShape, CollectiveMainloop, CollectiveEpilogue>>; struct GemmKernel : public KernelType {}; }; @@ -127,20 +125,19 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder template struct ItemDeleter { void operator()(T* ptr) { - cudaFree(ptr); // noexcept + cudaFree(ptr); // noexcept } }; template void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets, - EpilogueArgs&&... epilogue_params) { + torch::Tensor const& b, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets, + EpilogueArgs&&... epilogue_params) { using ElementAB = typename Gemm::ElementAB; - // using ElementC = typename Gemm::ElementC; using ElementC = typename Gemm::ElementC; using ElementAcc = float; @@ -151,43 +148,48 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, std::vector d_ptrs_host(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = (ElementAB*)a.data_ptr() + a_offsets[g].item(); - b_ptrs_host.at(g) = (ElementAB*)b.data_ptr() + b_offsets[g].item(); - c_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); - d_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); + a_ptrs_host.at(g) = + static_cast(a.data_ptr()) + a_offsets[g].item(); + b_ptrs_host.at(g) = + static_cast(b.data_ptr()) + b_offsets[g].item(); + c_ptrs_host.at(g) = + static_cast(out.data_ptr()) + out_offsets[g].item(); + d_ptrs_host.at(g) = + static_cast(out.data_ptr()) + out_offsets[g].item(); + printf("%d %d %d\n", a_offsets[g].item(), + b_offsets[g].item(), out_offsets[g].item()); } - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - std::vector a_stride_host(groups, StrideA{lda, cute::Int<1>{}, cute::Int<0>{}}); - std::vector b_stride_host(groups, StrideB{ldb, cute::Int<1>{}, cute::Int<0>{}}); - // TODO fix - std::vector c_stride_host(groups, StrideC{cute::Int<1>{}, ldc, cute::Int<0>{}}); + using GemmKernel = typename Gemm::GemmKernel; - cutlass::platform::unique_ptr stride_A; - cutlass::platform::unique_ptr stride_B; - cutlass::platform::unique_ptr stride_C; - cutlass::platform::unique_ptr stride_D; + using StrideA = typename GemmKernel::InternalStrideA; + using StrideB = typename GemmKernel::InternalStrideB; + using StrideC = typename GemmKernel::InternalStrideC; + // using StrideD = typename GemmKernel::InternalStrideD; - cutlass::platform::unique_ptr ptr_A; - cutlass::platform::unique_ptr ptr_B; - cutlass::platform::unique_ptr ptr_C; - cutlass::platform::unique_ptr ptr_D; + std::vector a_stride_host(groups); + std::vector b_stride_host(groups); + std::vector c_stride_host(groups); - using GemmKernel = typename Gemm::GemmKernel; + for (int g = 0; g < groups; ++g) { + int32_t m = problem_sizes[g][0].item(); + int32_t n = 
problem_sizes[g][1].item(); + int32_t k = problem_sizes[g][2].item(); + a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, + // row + b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, + // col + c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, + // row + } cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with multiple GPUs and wish - // to use a GPU other than that with device ID 0. + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. hw_info.device_id = 0; - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; @@ -203,76 +205,83 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, SingleProblemShape* problem_sizes_device; int32_t problem_sizes_size = groups * sizeof(SingleProblemShape); cudaMalloc(&problem_sizes_device, problem_sizes_size); - cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), groups, - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> problem_sizes_ptr( - problem_sizes_device); - ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; + cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), + groups * sizeof(SingleProblemShape), cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + problem_sizes_ptr(problem_sizes_device); + ProblemShape prob_shape{groups, problem_sizes_ptr.get(), + problem_sizes_host.data()}; + + // ElementAB* a_host_print; + // int numel = a.numel(); + // cudaMalloc(&a_host_print, groups * sizeof(ElementAB)); + // cudaMemcpy(a_host_print, static_cast(a.data_ptr()), numel* + // sizeof(ElementAB), cudaMemcpyDeviceToHost); + // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* + // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); const ElementAB** a_ptrs_device; cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> a_ptrs_ptr( - a_ptrs_device - ); + cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups * sizeof(ElementAB*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + a_ptrs_ptr(a_ptrs_device); const ElementAB** b_ptrs_device; cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> b_ptrs_ptr( - b_ptrs_device - ); + cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups * sizeof(ElementAB*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + b_ptrs_ptr(b_ptrs_device); const ElementC** c_ptrs_device; cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> c_ptrs_ptr( - c_ptrs_device - ); + cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups * sizeof(ElementC*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + c_ptrs_ptr(c_ptrs_device); - // TODO if we start with empty values here, no need to copy ElementC** d_ptrs_device; cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), 
groups, cudaMemcpyHostToDevice); + cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups * sizeof(ElementC*), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> d_ptrs_ptr( - d_ptrs_device - ); + d_ptrs_device); StrideA* a_stride_device; - cudaMalloc(&a_stride_device, groups * sizeof(StrideA*)); - cudaMemcpy(a_stride_device, a_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&a_stride_device, groups * sizeof(StrideA)); + cudaMemcpy(a_stride_device, a_stride_host.data(), groups * sizeof(StrideA), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> a_stride_ptr( - a_stride_device - ); + a_stride_device); StrideB* b_stride_device; - cudaMalloc(&b_stride_device, groups * sizeof(StrideB*)); - cudaMemcpy(b_stride_device, b_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&b_stride_device, groups * sizeof(StrideB)); + cudaMemcpy(b_stride_device, b_stride_host.data(), groups * sizeof(StrideB), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> b_stride_ptr( - b_stride_device - ); + b_stride_device); StrideC* c_stride_device; - cudaMalloc(&c_stride_device, groups * sizeof(StrideC*)); - cudaMemcpy(c_stride_device, c_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&c_stride_device, groups * sizeof(StrideC)); + cudaMemcpy(c_stride_device, c_stride_host.data(), groups * sizeof(StrideC), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> c_stride_ptr( - c_stride_device - ); + c_stride_device); typename GemmKernel::MainloopArguments mainloop_args{ - a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; + a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), + b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; + c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), + c_stride_ptr.get()}; typename GemmKernel::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - prob_shape, - mainloop_args, - epilogue_args, - hw_info - }; + cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, + epilogue_args, hw_info}; // Launch the CUTLASS GEMM kernel. 
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; @@ -284,18 +293,14 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace = torch::empty(workspace_size, workspace_options); - // // auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr())); - - cutlass::Status status = gemm_op.run(); + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); - } // typedef InType = cutlass::float_e4m3_t; // typedef OutType = torch::half; -// typedef Epilogue = ScaledEpilogueBias; template typename Epilogue> @@ -304,12 +309,13 @@ struct sm90_fp8_config_default { static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; template ()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; template ()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; -} +} // namespace // TODO hardcode types here? 
-void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets) { - +void cutlass_grouped_mm_sm90( + torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, torch::Tensor const& b_offsets) { TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); // int32_t m = a.size(1); - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_fp8_config_default< + ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogue>::Cutlass3xGemm; // using Cutlass3xGemmM64 = - // typename sm90_fp8_config_M64::Cutlass3xGemm; + // typename sm90_fp8_config_M64::Cutlass3xGemm; // using Cutlass3xGemmM128 = - // typename sm90_fp8_config_M128::Cutlass3xGemm; - + // typename sm90_fp8_config_M128::Cutlass3xGemm; // // uint32_t const m = a.size(0); // uint32_t const mp2 = @@ -373,14 +378,16 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, // if (mp2 <= 64) { // // m in [1, 64] - // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // cutlass_group_gemm_caller(out, a, b, a_scales, + // b_scales); // } else if (mp2 <= 128) { // // m in (64, 128] - // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // cutlass_group_gemm_caller(out, a, b, a_scales, + // b_scales); // } else { // // m in (128, inf) - cutlass_group_gemm_caller(out, a, b, problem_sizes, - out_offsets, a_offsets, b_offsets, a_scales, b_scales); + cutlass_group_gemm_caller( + out, a, b, problem_sizes, out_offsets, a_offsets, b_offsets, a_scales, + b_scales); // } - } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 78225f9b0db0a..961437893dee0 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -28,13 +28,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets); +void cutlass_grouped_mm_sm90( + torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, torch::Tensor const& b_offsets); #endif diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index a97c8f307df32..563a3f433d98b 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -2,11 +2,11 @@ Run `pytest tests/kernels/test_cutlass.py`. 
""" +import random from typing import Optional, Type import pytest import torch -import random from tests.kernels.utils import opcheck from vllm import _custom_ops as ops @@ -455,41 +455,43 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + # TODO fix scales @pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) @pytest.mark.parametrize("num_groups", [10]) -@pytest.mark.parametrize("per_act_token", [False])# [True, False]) -@pytest.mark.parametrize("per_out_ch", [True])# [True, False]) -@pytest.mark.parametrize("use_bias", [False])# [True, False]) +@pytest.mark.parametrize("per_act_token", [False]) # [True, False]) +@pytest.mark.parametrize("per_out_ch", [True]) # [True, False]) +@pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, - per_act_token: bool, - per_out_ch: bool, use_bias: bool): + per_act_token: bool, per_out_ch: bool, + use_bias: bool): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. device = "cuda" out_dtype = torch.half - alignment = 16 # 128 // 8 + alignment = 16 # 128 // 8 problem_sizes = torch.empty((num_groups, 3), device="cpu") - offsets_a = torch.empty((num_groups), device="cpu") - offsets_b = torch.empty((num_groups), device="cpu") - offsets_c = torch.empty((num_groups), device="cpu") + offsets_a = torch.empty((num_groups), device="cpu", dtype=torch.int32) + offsets_b = torch.empty((num_groups), device="cpu", dtype=torch.int32) + offsets_c = torch.empty((num_groups), device="cpu", dtype=torch.int32) tot_a = 0 tot_b = 0 tot_c = 0 + m = alignment * random.randint(1, 64) + n = alignment * random.randint(1, 64) + k = alignment * random.randint(1, 64) for g in range(num_groups): - m = alignment * random.randint(1, 64) - n = alignment * random.randint(1, 64) - k = alignment * random.randint(1, 64) tot_a += m tot_b += k tot_c += m - offsets_a[g] = m * k - offsets_b[g] = k * n - offsets_c[g] = m * n + print(m, n, k) + offsets_a[g] = g * m * k + offsets_b[g] = g * k * n + offsets_c[g] = g * m * n problem_sizes[g][0] = m problem_sizes[g][1] = n problem_sizes[g][2] = k @@ -497,32 +499,67 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, a = to_fp8(torch.randn((tot_a, k), device=device)) b = to_fp8(torch.randn((tot_b, n), device=device).t()) c = torch.zeros((tot_c, n), device=device).to(out_dtype) + baseline = torch.zeros((tot_c, n), device=device).to(out_dtype) - print(tot_a, tot_b, tot_c) + # print(a) + # print(b) - print(a.stride(), b.stride(), c.stride()) + # print(offsets_a) + # print(offsets_b) + # print(offsets_c) + # print(tot_a, tot_b, tot_c) + + # print(a.stride(), b.stride(), c.stride()) # m_a_scales = m if per_act_token else 1 # n_b_scales = n if per_out_ch else 1 - scale_a = (torch.randn((tot_a if per_act_token else num_groups), - device=device, - dtype=torch.float32)) - scale_b = (torch.randn((tot_b if per_act_token else num_groups), - device=device, - dtype=torch.float32)) + # scale_a = (torch.randn((tot_a if per_act_token else num_groups), + # device=device, + # dtype=torch.float32)) + # scale_b = (torch.randn((tot_b if per_act_token else num_groups), + # device=device, + # dtype=torch.float32)) + + scale_a = (torch.ones((tot_a if per_act_token else 
num_groups), + device=device, + dtype=torch.float32)) + scale_b = (torch.ones((tot_b if per_act_token else num_groups), + device=device, + dtype=torch.float32)) + # if use_bias: # bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 # else: # bias = None + print(a) + # TODO strides we can get later the same way as in scaled_mm_c3x.cu torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, offsets_c, offsets_a, offsets_b) - # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) + # print(a.dtype) + # print(a) + + # torch.set_printoptions(profile='full') + # # print(c[2*m:3*m]) + # print(torch.max(c, dim=1)) + # print(torch.max(c, dim=0)) print(c) + for g in range(num_groups): + baseline[g * m:(g + 1) * m] = baseline_scaled_mm( + a[g * m:(g + 1) * m], + b.t()[g * k:(g + 1) * k], + scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], + scale_b[g * k:(g + 1) * k] if per_act_token else scale_b[g], + out_dtype, None) + print(baseline[g * m:(g + 1) * m]) + print(c[g * m:(g + 1) * m]) + print("*") + # torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) # opcheck(torch.ops._C.cutlass_scaled_mm, From c570c69ed80d0f7e2a2be27ef1f931497bc3e589 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 11 Dec 2024 14:41:46 +0000 Subject: [PATCH 4/9] Working for identical sizes Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_gemm_test.cu | 167 +++++++++--------- tests/kernels/test_cutlass.py | 62 ++++--- 2 files changed, 118 insertions(+), 111 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index db86bd1a4b466..03d23c7739691 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -129,6 +129,31 @@ struct ItemDeleter { } }; +template +cutlass::platform::unique_ptr> make_device_ptr( + std::vector& data_host) { + T* data_device; + int count = data_host.size(); + cudaMalloc(&data_device, count * sizeof(T)); + cudaMemcpy(data_device, data_host.data(), count * sizeof(T), + cudaMemcpyHostToDevice); + return cutlass::platform::unique_ptr>(data_device); +} + +/////////////// +template +void print(const TupType& _tup, std::index_sequence) { + std::cout << "("; + (..., (std::cout << (I == 0 ? 
"" : ", ") << std::get(_tup))); + std::cout << ")\n"; +} + +template +void print(const std::tuple& _tup) { + print(_tup, std::make_index_sequence()); +} +//////////// + template void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, @@ -142,46 +167,67 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, using ElementAcc = float; int groups = problem_sizes.size(0); - std::vector a_ptrs_host(groups); - std::vector b_ptrs_host(groups); - std::vector c_ptrs_host(groups); + std::vector a_ptrs_host(groups); + std::vector b_ptrs_host(groups); + std::vector c_ptrs_host(groups); std::vector d_ptrs_host(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = - static_cast(a.data_ptr()) + a_offsets[g].item(); - b_ptrs_host.at(g) = - static_cast(b.data_ptr()) + b_offsets[g].item(); - c_ptrs_host.at(g) = - static_cast(out.data_ptr()) + out_offsets[g].item(); + a_ptrs_host.at(g) = static_cast(a.data_ptr()) + + a_offsets[g].item(); + b_ptrs_host.at(g) = static_cast(b.data_ptr()) + + b_offsets[g].item(); + c_ptrs_host.at(g) = static_cast(out.data_ptr()) + + out_offsets[g].item(); d_ptrs_host.at(g) = static_cast(out.data_ptr()) + out_offsets[g].item(); - printf("%d %d %d\n", a_offsets[g].item(), + printf("off: %d %d %d\n", a_offsets[g].item(), b_offsets[g].item(), out_offsets[g].item()); } using GemmKernel = typename Gemm::GemmKernel; - using StrideA = typename GemmKernel::InternalStrideA; - using StrideB = typename GemmKernel::InternalStrideB; - using StrideC = typename GemmKernel::InternalStrideC; - // using StrideD = typename GemmKernel::InternalStrideD; + // using StrideA = typename GemmKernel::InternalStrideA; + // using StrideB = typename GemmKernel::InternalStrideB; + // using StrideC = typename GemmKernel::InternalStrideC; + // // using StrideD = typename GemmKernel::InternalStrideD; - std::vector a_stride_host(groups); - std::vector b_stride_host(groups); - std::vector c_stride_host(groups); + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); - for (int g = 0; g < groups; ++g) { - int32_t m = problem_sizes[g][0].item(); - int32_t n = problem_sizes[g][1].item(); - int32_t k = problem_sizes[g][2].item(); - a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, - // row - b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, - // col - c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, - // row - } + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = + typename GemmKernel::InternalStrideC; // typename Gemm::StrideC; + + // StrideA a_stride{lda, Int<1>{}, Int<0>{}}; + // StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + // StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + std::vector a_stride_host(groups, StrideA{lda, Int<1>{}, Int<0>{}}); + std::vector b_stride_host(groups, StrideB{ldb, Int<1>{}, Int<0>{}}); + std::vector c_stride_host(groups, StrideC{ldc, Int<1>{}, Int<0>{}}); + + printf("a: "); + print(a_stride_host[0]); + printf("\nb: "); + print(b_stride_host[0]); + printf("\nc: "); + print(c_stride_host[0]); + printf("\n"); + + // for (int g = 0; g < groups; ++g) { + // int32_t m = problem_sizes[g][0].item(); + // int32_t n = problem_sizes[g][1].item(); + // int32_t k = problem_sizes[g][2].item(); + // a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, + // // row + // b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, + // // col + // c_stride_host[g] = 
StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, + // // row + // } cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with @@ -200,16 +246,11 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, int32_t n = problem_sizes[g][1].item(); int32_t k = problem_sizes[g][2].item(); problem_sizes_host.push_back({m, n, k}); + printf("mnk: %d, %d, %d\n", m, n, k); } - SingleProblemShape* problem_sizes_device; - int32_t problem_sizes_size = groups * sizeof(SingleProblemShape); - cudaMalloc(&problem_sizes_device, problem_sizes_size); - cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), - groups * sizeof(SingleProblemShape), cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - problem_sizes_ptr(problem_sizes_device); + auto problem_sizes_ptr = + make_device_ptr(problem_sizes_host); ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; @@ -221,54 +262,14 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); - const ElementAB** a_ptrs_device; - cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups * sizeof(ElementAB*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - a_ptrs_ptr(a_ptrs_device); - - const ElementAB** b_ptrs_device; - cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups * sizeof(ElementAB*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - b_ptrs_ptr(b_ptrs_device); - - const ElementC** c_ptrs_device; - cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups * sizeof(ElementC*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - c_ptrs_ptr(c_ptrs_device); - - ElementC** d_ptrs_device; - cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups * sizeof(ElementC*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> d_ptrs_ptr( - d_ptrs_device); - - StrideA* a_stride_device; - cudaMalloc(&a_stride_device, groups * sizeof(StrideA)); - cudaMemcpy(a_stride_device, a_stride_host.data(), groups * sizeof(StrideA), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> a_stride_ptr( - a_stride_device); + auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); + auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); + auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); + auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - StrideB* b_stride_device; - cudaMalloc(&b_stride_device, groups * sizeof(StrideB)); - cudaMemcpy(b_stride_device, b_stride_host.data(), groups * sizeof(StrideB), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> b_stride_ptr( - b_stride_device); - - StrideC* c_stride_device; - cudaMalloc(&c_stride_device, groups * sizeof(StrideC)); - cudaMemcpy(c_stride_device, c_stride_host.data(), groups * sizeof(StrideC), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> c_stride_ptr( - c_stride_device); + auto a_stride_ptr = make_device_ptr(a_stride_host); + auto b_stride_ptr = make_device_ptr(b_stride_host); + auto c_stride_ptr = make_device_ptr(c_stride_host); typename GemmKernel::MainloopArguments mainloop_args{ a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), diff --git a/tests/kernels/test_cutlass.py 
b/tests/kernels/test_cutlass.py index 563a3f433d98b..1532feba47d6a 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -62,6 +62,7 @@ def baseline_scaled_mm(a: torch.Tensor, scale_b: torch.Tensor, out_dtype: Type[torch.dtype], bias: Optional[torch.Tensor] = None) -> torch.Tensor: + print(a.shape, b.shape, scale_a.shape, scale_b.shape) output = (scale_a * (scale_b * (torch.mm( a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype) if bias is not None: @@ -458,9 +459,9 @@ def test_cutlass_support_opcheck(): # TODO fix scales @pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) -@pytest.mark.parametrize("num_groups", [10]) -@pytest.mark.parametrize("per_act_token", [False]) # [True, False]) -@pytest.mark.parametrize("per_out_ch", [True]) # [True, False]) +@pytest.mark.parametrize("num_groups", [1, 4, 10]) +@pytest.mark.parametrize("per_act_token", [True, False]) # [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False]) @pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") @@ -486,7 +487,7 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, k = alignment * random.randint(1, 64) for g in range(num_groups): tot_a += m - tot_b += k + tot_b += n tot_c += m print(m, n, k) offsets_a[g] = g * m * k @@ -497,7 +498,13 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, problem_sizes[g][2] = k a = to_fp8(torch.randn((tot_a, k), device=device)) - b = to_fp8(torch.randn((tot_b, n), device=device).t()) + + b_float = torch.randn((tot_b, k), device=device) + # for g in range(num_groups): + # b_float[g * k:(g + 1) * k] = torch.full((k, n), g + 1) + # print(b_float) + + b = to_fp8(b_float.t()) c = torch.zeros((tot_c, n), device=device).to(out_dtype) baseline = torch.zeros((tot_c, n), device=device).to(out_dtype) @@ -511,29 +518,19 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, # print(a.stride(), b.stride(), c.stride()) - # m_a_scales = m if per_act_token else 1 - # n_b_scales = n if per_out_ch else 1 - - # scale_a = (torch.randn((tot_a if per_act_token else num_groups), - # device=device, - # dtype=torch.float32)) - # scale_b = (torch.randn((tot_b if per_act_token else num_groups), - # device=device, - # dtype=torch.float32)) - - scale_a = (torch.ones((tot_a if per_act_token else num_groups), - device=device, - dtype=torch.float32)) - scale_b = (torch.ones((tot_b if per_act_token else num_groups), - device=device, - dtype=torch.float32)) + scale_a = (torch.randn(((m, 1) if per_act_token else (1, 1)), + device=device, + dtype=torch.float32)) + scale_b = (torch.randn(((1, n) if per_out_ch else (1, 1)), + device=device, + dtype=torch.float32)) # if use_bias: # bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 # else: # bias = None - print(a) + # print(a) # TODO strides we can get later the same way as in scaled_mm_c3x.cu torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, @@ -547,20 +544,29 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, # # print(c[2*m:3*m]) # print(torch.max(c, dim=1)) # print(torch.max(c, dim=0)) - print(c) + # print(c) for g in range(num_groups): + print(a[g * m:(g + 1) * m].shape, b[:, g * n:(g + 1) * n].shape) baseline[g * m:(g + 1) * m] = baseline_scaled_mm( a[g * m:(g + 1) * m], - b.t()[g * k:(g + 1) * k], - scale_a[g * m:(g + 1) * m] if 
per_act_token else scale_a[g], - scale_b[g * k:(g + 1) * k] if per_act_token else scale_b[g], - out_dtype, None) + b[:, g * n:(g + 1) * n], + # scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], + # # scale_b[:, g * n:(g + 1) * n] if per_out_ch else scale_b[:, g], + # scale_b[g], + scale_a, + scale_b, + out_dtype, + None) print(baseline[g * m:(g + 1) * m]) print(c[g * m:(g + 1) * m]) print("*") - # torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) + # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) + # print(baseline) + # print(c) + + torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) # opcheck(torch.ops._C.cutlass_scaled_mm, # (out, a, b, scale_a, scale_b, bias)) From 6ed63f2ebae2d2d6742cc08c855ac8e5b6eb7cd1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 17 Dec 2024 16:41:45 +0000 Subject: [PATCH 5/9] Grouped gemm working Co-authored-by: Lucas Wilkinson Signed-off-by: ElizaWszola --- .../broadcast_load_epilogue_array_c3x.hpp | 464 ++++++++++++++++++ .../epilogue/broadcast_load_epilogue_c3x.hpp | 5 + .../epilogue/scaled_mm_epilogues_c3x.hpp | 64 +++ csrc/ops.h | 12 +- .../cutlass_w8a8/grouped_gemm_test.cu | 224 ++++----- .../cutlass_w8a8/scaled_mm_entry.cu | 26 +- csrc/torch_bindings.cpp | 8 +- tests/kernels/test_cutlass.py | 155 ++---- 8 files changed, 704 insertions(+), 254 deletions(-) create mode 100644 csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp new file mode 100644 index 0000000000000..e652179718c95 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -0,0 +1,464 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcastArray { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. 
+ struct Arguments { + const Element* const* ptr_row_array = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, + int group, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , group(group) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + int group; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row_array[group])); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. 
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + if (threadIdx.x ==128){ + printf("ROW M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + } + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + l, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcastArray { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. 
+ struct Arguments { + const Element* const* ptr_col_array = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + int group, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + group(group), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + int group; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col_array[group])); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + // if (threadIdx.x ==128){ + // printf("COL M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + // } + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + l, + params + ); + } +}; + +} diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 58b1e8ff159fb..9f049efd07b46 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -422,6 +422,11 @@ struct Sm90ColOrScalarBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + if (threadIdx.x ==128){ + printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79f..ad7c45a076e68 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,4 +1,5 @@ #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" /* This file defines custom epilogues for fusing channel scales, token scales, @@ -45,6 +46,16 @@ struct ScaledEpilogueBase { 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + template + using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or // scalar cases. 
@@ -72,6 +83,15 @@ struct ScaledEpilogueBase { std::is_same_v>); return Arguments{data_ptr}; } + + template + static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) { + using Arguments = typename Descriptor::Arguments; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr, do_broadcast}; + } + }; /* @@ -312,4 +332,48 @@ struct ScaledEpilogueBiasAzpToken } }; +/* +TODO document +This is an epilogue with ptr arrays to a_scales and b_scales +*/ +template +struct ScaledEpilogueArray + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoadArray; + using ScaleB = typename SUPER::template RowOrScalarLoadArray; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; + using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; + + static ArgumentType prepare_args(const float* const* a_scales_ptr, + const float* const* b_scales_ptr, + bool a_col_broadcast, + bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor(a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor(b_scales_ptr, b_row_broadcast); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + }; // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index fce4346fa4218..b655d3bfab58a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,13 +145,11 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets); +void cutlass_grouped_mm(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index 03d23c7739691..c9d299c111304 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -38,11 +38,6 @@ using namespace cute; namespace { -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. 
template struct enable_sm90_or_later : Kernel { template @@ -54,19 +49,13 @@ struct enable_sm90_or_later : Kernel { }; using ProblemShape = - cutlass::gemm::GroupProblemShape>; // - // per group -using ElementAB_Type = - cutlass::float_e4m3_t; // Element type for A matrix operand -// using ElementB = cutlass::float_e4m3_t; // -// Element type for B matrix operand + cutlass::gemm::GroupProblemShape>; +using ElementAB_Type = cutlass::float_e4m3_t; using ElementC_Type = cutlass::half_t; -// Core kernel configurations -using ElementAccumulator = float; // Element type for internal accumulation -using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature -using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -154,129 +143,109 @@ void print(const std::tuple& _tup) { } //////////// -template -void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets, - EpilogueArgs&&... epilogue_params) { +template +void cutlass_group_gemm_caller(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; - using ElementAcc = float; - int groups = problem_sizes.size(0); + int groups = (int)a_tensors.size(); + TORCH_CHECK((int)b_tensors.size() == groups, + "Number of B tensors must match number of groups."); + TORCH_CHECK((int)out_tensors.size() == groups, + "Number of output tensors must match number of groups."); + std::vector a_ptrs_host(groups); std::vector b_ptrs_host(groups); std::vector c_ptrs_host(groups); std::vector d_ptrs_host(groups); + std::vector a_scales_ptrs_host(groups); + std::vector b_scales_ptrs_host(groups); + + std::vector problem_sizes_host; + problem_sizes_host.reserve(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = static_cast(a.data_ptr()) + - a_offsets[g].item(); - b_ptrs_host.at(g) = static_cast(b.data_ptr()) + - b_offsets[g].item(); - c_ptrs_host.at(g) = static_cast(out.data_ptr()) + - out_offsets[g].item(); - d_ptrs_host.at(g) = - static_cast(out.data_ptr()) + out_offsets[g].item(); - printf("off: %d %d %d\n", a_offsets[g].item(), - b_offsets[g].item(), out_offsets[g].item()); + a_ptrs_host[g] = + reinterpret_cast(a_tensors[g].data_ptr()); + b_ptrs_host[g] = + reinterpret_cast(b_tensors[g].data_ptr()); + c_ptrs_host[g] = + reinterpret_cast(out_tensors[g].data_ptr()); + d_ptrs_host[g] = reinterpret_cast(out_tensors[g].data_ptr()); + a_scales_ptrs_host[g] = + reinterpret_cast(a_scales[g].data_ptr()); + b_scales_ptrs_host[g] = + reinterpret_cast(b_scales[g].data_ptr()); + + int64_t m = a_tensors[g].size(0); + int64_t k = a_tensors[g].size(1); + + int64_t k_b = b_tensors[g].size(0); + int64_t n = b_tensors[g].size(1); + + TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k, + " while B has k=", k_b); + + // Optionally, verify output shape matches (m,n) + TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n, + "Output tensor shape does not match m,n from A,B: ", "Got ", + out_tensors[g].sizes(), 
" expected (", m, ", ", n, ")"); + + problem_sizes_host.push_back({(int)m, (int)n, (int)k}); } using GemmKernel = typename Gemm::GemmKernel; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; - // using StrideA = typename GemmKernel::InternalStrideA; - // using StrideB = typename GemmKernel::InternalStrideB; - // using StrideC = typename GemmKernel::InternalStrideC; - // // using StrideD = typename GemmKernel::InternalStrideD; + std::vector a_stride_host(groups); + std::vector b_stride_host(groups); + std::vector c_stride_host(groups); - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); + for (int32_t g = 0; g < groups; ++g) { + int64_t lda = a_tensors[g].stride(0); // row-major (m x k) + int64_t ldb = b_tensors[g].stride(1); // column-major (k x n) + int64_t ldc = out_tensors[g].stride(0); // row-major (m x n) - using StrideA = Stride, Int<0>>; - using StrideB = Stride, Int<0>>; - using StrideC = - typename GemmKernel::InternalStrideC; // typename Gemm::StrideC; - - // StrideA a_stride{lda, Int<1>{}, Int<0>{}}; - // StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; - // StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - std::vector a_stride_host(groups, StrideA{lda, Int<1>{}, Int<0>{}}); - std::vector b_stride_host(groups, StrideB{ldb, Int<1>{}, Int<0>{}}); - std::vector c_stride_host(groups, StrideC{ldc, Int<1>{}, Int<0>{}}); - - printf("a: "); - print(a_stride_host[0]); - printf("\nb: "); - print(b_stride_host[0]); - printf("\nc: "); - print(c_stride_host[0]); - printf("\n"); - - // for (int g = 0; g < groups; ++g) { - // int32_t m = problem_sizes[g][0].item(); - // int32_t n = problem_sizes[g][1].item(); - // int32_t k = problem_sizes[g][2].item(); - // a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, - // // row - // b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, - // // col - // c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, - // // row - // } + a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}}; + b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}}; + c_stride_host[g] = StrideC{ldc, Int<1>{}, Int<0>{}}; + } cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with - // multiple GPUs and wish to use a GPU other than that with device ID 0. 
hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count( hw_info.device_id); - using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; - - std::vector problem_sizes_host; - problem_sizes_host.reserve(groups); - for (int32_t g = 0; g < groups; ++g) { - int32_t m = problem_sizes[g][0].item(); - int32_t n = problem_sizes[g][1].item(); - int32_t k = problem_sizes[g][2].item(); - problem_sizes_host.push_back({m, n, k}); - printf("mnk: %d, %d, %d\n", m, n, k); - } - - auto problem_sizes_ptr = - make_device_ptr(problem_sizes_host); + auto problem_sizes_ptr = make_device_ptr(problem_sizes_host); ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; - // ElementAB* a_host_print; - // int numel = a.numel(); - // cudaMalloc(&a_host_print, groups * sizeof(ElementAB)); - // cudaMemcpy(a_host_print, static_cast(a.data_ptr()), numel* - // sizeof(ElementAB), cudaMemcpyDeviceToHost); - // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* - // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); + auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); + auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); + auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); + auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); - auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); - auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); - auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); + auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host); + auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host); - auto a_stride_ptr = make_device_ptr(a_stride_host); - auto b_stride_ptr = make_device_ptr(b_stride_host); - auto c_stride_ptr = make_device_ptr(c_stride_host); + auto a_stride_ptr = make_device_ptr(a_stride_host); + auto b_stride_ptr = make_device_ptr(b_stride_host); + auto c_stride_ptr = make_device_ptr(c_stride_host); typename GemmKernel::MainloopArguments mainloop_args{ a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), + a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(), + a_scales[0].numel() != 1, b_scales[0].numel() != 1), c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; @@ -284,30 +253,26 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, epilogue_args, hw_info}; - // Launch the CUTLASS GEMM kernel. 
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; + // std::cout << "gemm_op.can_implement(args): " + // << (int)gemm_op.can_implement(args) << std::endl; CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors[0].device()); auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - + auto stream = at::cuda::getCurrentCUDAStream(a_tensors[0].device().index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } -// typedef InType = cutlass::float_e4m3_t; -// typedef OutType = torch::half; - template typename Epilogue> struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); + static_assert(std::is_same_v); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = @@ -354,18 +319,23 @@ struct sm90_fp8_config_M64 { } // namespace -// TODO hardcode types here? -void cutlass_grouped_mm_sm90( - torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, - torch::Tensor const& a_scales, torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, torch::Tensor const& b_offsets) { - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - // int32_t m = a.size(1); +void cutlass_grouped_mm_sm90(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { + TORCH_CHECK(a_tensors.size() > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size() > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size() > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors[0].dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors[0].dtype() == torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); using Cutlass3xGemmDefault = typename sm90_fp8_config_default< - ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogue>::Cutlass3xGemm; + ElementAB_Type, ElementC_Type, + vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; // using Cutlass3xGemmM64 = // typename sm90_fp8_config_M64::Cutlass3xGemm; @@ -388,7 +358,5 @@ void cutlass_grouped_mm_sm90( // } else { // // m in (128, inf) cutlass_group_gemm_caller( - out, a, b, problem_sizes, out_offsets, a_offsets, b_offsets, a_scales, - b_scales); - // } + out_tensors, a_tensors, b_tensors, a_scales, b_scales); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 961437893dee0..eb5d09a6de7ba 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -28,11 +28,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm_sm90( - torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, - torch::Tensor const& a_scales, torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, torch::Tensor const& b_offsets); +void 
cutlass_grouped_mm_sm90(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales); #endif @@ -158,15 +158,13 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets) { - cutlass_grouped_mm_sm90(out, a, b, a_scales, b_scales, problem_sizes, - out_offsets, a_offsets, b_offsets); +void cutlass_grouped_mm(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { + cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, + b_scales); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a10c661b22a6a..22a1a1a4ae080 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -313,10 +313,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( - "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " - " Tensor b_scales, Tensor problem_sizes, " - " Tensor out_offsets, Tensor a_offsets, " - " Tensor b_offsets) -> ()"); + "cutlass_grouped_mm(Tensor![] out_tensors," + " Tensor[] a_tensors," + " Tensor[] b_tensors, Tensor[] a_scales, " + " Tensor[] b_scales) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); // Mamba selective scan kernel diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 1532feba47d6a..4c909669aa5d3 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -457,116 +457,69 @@ def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) -# TODO fix scales -@pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) -@pytest.mark.parametrize("num_groups", [1, 4, 10]) +@pytest.mark.parametrize("num_groups", [8]) @pytest.mark.parametrize("per_act_token", [True, False]) # [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False]) @pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, - per_act_token: bool, per_out_ch: bool, - use_bias: bool): +def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, + per_out_ch: bool, use_bias: bool): - # Test for a cutlass kernel with per-token activation quantization - # and per-output channel weight quantization. 
+ # Device and dtype setup device = "cuda" out_dtype = torch.half - alignment = 16 # 128 // 8 - problem_sizes = torch.empty((num_groups, 3), device="cpu") - offsets_a = torch.empty((num_groups), device="cpu", dtype=torch.int32) - offsets_b = torch.empty((num_groups), device="cpu", dtype=torch.int32) - offsets_c = torch.empty((num_groups), device="cpu", dtype=torch.int32) - tot_a = 0 - tot_b = 0 - tot_c = 0 - m = alignment * random.randint(1, 64) - n = alignment * random.randint(1, 64) - k = alignment * random.randint(1, 64) - for g in range(num_groups): - tot_a += m - tot_b += n - tot_c += m - print(m, n, k) - offsets_a[g] = g * m * k - offsets_b[g] = g * k * n - offsets_c[g] = g * m * n - problem_sizes[g][0] = m - problem_sizes[g][1] = n - problem_sizes[g][2] = k - - a = to_fp8(torch.randn((tot_a, k), device=device)) - - b_float = torch.randn((tot_b, k), device=device) - # for g in range(num_groups): - # b_float[g * k:(g + 1) * k] = torch.full((k, n), g + 1) - # print(b_float) - - b = to_fp8(b_float.t()) - c = torch.zeros((tot_c, n), device=device).to(out_dtype) - baseline = torch.zeros((tot_c, n), device=device).to(out_dtype) - - # print(a) - # print(b) - - # print(offsets_a) - # print(offsets_b) - # print(offsets_c) - # print(tot_a, tot_b, tot_c) - - # print(a.stride(), b.stride(), c.stride()) - - scale_a = (torch.randn(((m, 1) if per_act_token else (1, 1)), - device=device, - dtype=torch.float32)) - scale_b = (torch.randn(((1, n) if per_out_ch else (1, 1)), - device=device, - dtype=torch.float32)) - - # if use_bias: - # bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 - # else: - # bias = None - - # print(a) - - # TODO strides we can get later the same way as in scaled_mm_c3x.cu - torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, - offsets_c, offsets_a, offsets_b) - # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) - - # print(a.dtype) - # print(a) - - # torch.set_printoptions(profile='full') - # # print(c[2*m:3*m]) - # print(torch.max(c, dim=1)) - # print(torch.max(c, dim=0)) - # print(c) + # Create separate A, B, C tensors for each group + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + out_tensors = [] + baseline_tensors = [] + alignment = 16 # 128 // 8 + # For variation, each group g has dimensions + # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) for g in range(num_groups): - print(a[g * m:(g + 1) * m].shape, b[:, g * n:(g + 1) * n].shape) - baseline[g * m:(g + 1) * m] = baseline_scaled_mm( - a[g * m:(g + 1) * m], - b[:, g * n:(g + 1) * n], - # scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], - # # scale_b[:, g * n:(g + 1) * n] if per_out_ch else scale_b[:, g], - # scale_b[g], - scale_a, - scale_b, - out_dtype, - None) - print(baseline[g * m:(g + 1) * m]) - print(c[g * m:(g + 1) * m]) + m_g = alignment * random.randint(1, 64) + n_g = alignment * random.randint(1, 64) + k_g = alignment * random.randint(1, 64) + + m_a_scales = m_g if per_act_token else 1 + n_b_scales = n_g if per_out_ch else 1 + + print(m_g, n_g, k_g) + + # Create group-specific A and B (FP8) and output (FP16/FP32) + a_g = to_fp8(torch.randn((m_g, k_g), device=device)) + b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) + c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) + # Set up A/B scales + scale_a = torch.randn((m_a_scales, 1), + device=device, + dtype=torch.float32) + scale_b = torch.randn((1, n_b_scales), + device=device, + dtype=torch.float32) + + a_tensors.append(a_g) + 
b_tensors.append(b_g) + out_tensors.append(c_g) + a_scales_tensors.append(scale_a) + b_scales_tensors.append(scale_b) + + # Compute baseline result for this group + baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, + None) + baseline_tensors.append(baseline_g) + + torch.ops._C.cutlass_grouped_mm(out_tensors, a_tensors, b_tensors, + a_scales_tensors, b_scales_tensors) + + # Validate each group's result against the baseline + for c_g, baseline_g in zip(out_tensors, baseline_tensors): + print(baseline_g) + print(c_g) print("*") - - # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) - # print(baseline) - # print(c) - - torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) - - # opcheck(torch.ops._C.cutlass_scaled_mm, - # (out, a, b, scale_a, scale_b, bias)) + torch.testing.assert_close(c_g, baseline_g, rtol=1e-2, atol=5e-2) From e2b1fc05479311f3efac5daf7d34af0b9626ee3c Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 17 Dec 2024 16:53:53 +0000 Subject: [PATCH 6/9] Small cleanup Signed-off-by: ElizaWszola --- CMakeLists.txt | 2 +- .../broadcast_load_epilogue_array_c3x.hpp | 7 --- ...grouped_gemm_test.cu => grouped_mm_c3x.cu} | 45 ++----------------- tests/kernels/test_cutlass.py | 11 ++--- 4 files changed, 10 insertions(+), 55 deletions(-) rename csrc/quantization/cutlass_w8a8/{grouped_gemm_test.cu => grouped_mm_c3x.cu} (90%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d6185e756338..c19812ab54914 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -264,7 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu") + "csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp index e652179718c95..5c1d6e3f46be0 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -241,10 +241,6 @@ struct Sm90RowOrScalarBroadcastArray { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - if (threadIdx.x ==128){ - printf("ROW M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - } - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) Tensor sRow = make_tensor(make_smem_ptr(smem), @@ -435,9 +431,6 @@ struct Sm90ColOrScalarBroadcastArray { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - // if (threadIdx.x ==128){ - // printf("COL M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - // } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu similarity index 90% rename from 
csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu rename to csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index c9d299c111304..b08d67d046643 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -5,7 +5,7 @@ #include "cutlass/cutlass.h" -// TODO let's see which of these we'll need +// TODO clean up the includes we no longer need #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" @@ -26,10 +26,6 @@ #include "common.hpp" -// get rid of these? -// #include "helper.h" -// using namespace cute; - using namespace cute; #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 @@ -129,20 +125,6 @@ cutlass::platform::unique_ptr> make_device_ptr( return cutlass::platform::unique_ptr>(data_device); } -/////////////// -template -void print(const TupType& _tup, std::index_sequence) { - std::cout << "("; - (..., (std::cout << (I == 0 ? "" : ", ") << std::get(_tup))); - std::cout << ")\n"; -} - -template -void print(const std::tuple& _tup) { - print(_tup, std::make_index_sequence()); -} -//////////// - template void cutlass_group_gemm_caller(c10::List const& out_tensors, c10::List const& a_tensors, @@ -242,6 +224,8 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, typename GemmKernel::MainloopArguments mainloop_args{ a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; + // Currently, we are only able to do broadcast on either all or none a_scales + // and on either all or none b_scales typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(), @@ -255,8 +239,6 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; - // std::cout << "gemm_op.can_implement(args): " - // << (int)gemm_op.can_implement(args) << std::endl; CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); @@ -336,27 +318,6 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, using Cutlass3xGemmDefault = typename sm90_fp8_config_default< ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - // using Cutlass3xGemmM64 = - // typename sm90_fp8_config_M64::Cutlass3xGemm; - // using Cutlass3xGemmM128 = - // typename sm90_fp8_config_M128::Cutlass3xGemm; - - // // uint32_t const m = a.size(0); - // uint32_t const mp2 = - // std::max(static_cast(64), next_pow_2(m)); // next power of 2 - - // if (mp2 <= 64) { - // // m in [1, 64] - // cutlass_group_gemm_caller(out, a, b, a_scales, - // b_scales); - // } else if (mp2 <= 128) { - // // m in (64, 128] - // cutlass_group_gemm_caller(out, a, b, a_scales, - // b_scales); - // } else { - // // m in (128, inf) cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales); } diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 4c909669aa5d3..445a06f57a965 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -457,10 +457,11 @@ def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) +# TODO add bias @pytest.mark.parametrize("num_groups", [8]) -@pytest.mark.parametrize("per_act_token", [True, False]) # [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False]) -@pytest.mark.parametrize("use_bias", [False]) # [True, False]) +@pytest.mark.parametrize("per_act_token", [True, False]) 
+@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("use_bias", [False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, @@ -479,9 +480,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, baseline_tensors = [] alignment = 16 # 128 // 8 - # For variation, each group g has dimensions + # For variation, each group has dimensions # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) - for g in range(num_groups): + for _ in range(num_groups): m_g = alignment * random.randint(1, 64) n_g = alignment * random.randint(1, 64) k_g = alignment * random.randint(1, 64) From acfd3ef49ca32c9a70d22d9e592671b51c0212a3 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 13 Jan 2025 15:36:33 +0000 Subject: [PATCH 7/9] Benchmark grouped cutlass against bfloat16 torch.mm Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 199 ++++++++++++++++++ benchmarks/kernels/benchmark_shapes.py | 21 ++ vllm/_custom_ops.py | 7 + 3 files changed, 227 insertions(+) create mode 100644 benchmarks/kernels/benchmark_grouped_gemm_cutlass.py diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py new file mode 100644 index 0000000000000..be401cec03c66 --- /dev/null +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -0,0 +1,199 @@ +from typing import List, Tuple + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES_MOE + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1"] + # "nm-testing/deepseekv2-lite", + # "ibm-granite/granite-3.0-1b-a400m", + # "ibm-granite/granite-3.0-3b-a800m"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] + +NUM_GROUPS_OPTS = [8] +PER_ACT_TOKEN_OPTS = [False, True] +PER_OUT_CH_OPTS = [False, True] + +def grouped_gemm(a_g_tensors: List[torch.Tensor], + b_g_tensors: List[torch.Tensor], + out_g_tensors: List[torch.Tensor], + a_scales_tensors: List[torch.Tensor], + b_scales_tensors: List[torch.Tensor]): + ops.cutlass_grouped_mm(out_g_tensors, a_g_tensors, b_g_tensors, + a_scales_tensors, b_scales_tensors) + +def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], + b_tensors: List[torch.Tensor], + out_tensors: List[torch.Tensor]): + for g in range(num_groups): + a = a_tensors[g] + b = b_tensors[g] + out = torch.mm(a, b) + out_tensors[g] = out + +def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, + per_act_token: bool, per_out_ch: bool, + mkn: List[Tuple[int, int, int]]): + label = "Quant Matmul" + + sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_groups, per_act_token, + per_out_ch, mkn)) + + print(f"Testing: {sub_label}") + + device = "cuda" + out_dtype = torch.half + + def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + a_tensors = [] + b_tensors = [] + a_g_tensors = [] + b_g_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + out_tensors = [] + out_g_tensors = [] + baseline_tensors = [] + + for g in range(num_groups): + m_g = mkn[g][0] + k_g = mkn[g][1] + n_g = mkn[g][2] + + m_a_scales = m_g if per_act_token else 1 + n_b_scales = n_g if per_out_ch else 1 
+ + a = torch.randn((m_g, k_g), device=device) + b = torch.randn((n_g, k_g), device=device).t() + c = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16) + + a_g = to_fp8(a) + b_g = to_fp8(b) + c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) + + scale_a = (torch.randn((m_a_scales, 1), device=device, + dtype=torch.float32)) + scale_b = (torch.randn((1, n_b_scales), device=device, + dtype=torch.float32)) + + a_tensors.append(a.to(dtype=torch.bfloat16)) + b_tensors.append(b.to(dtype=torch.bfloat16)) + out_tensors.append(c) + a_g_tensors.append(a_g) + b_g_tensors.append(b_g) + out_g_tensors.append(c_g) + baseline_tensors.append(c_g) + a_scales_tensors.append(scale_a) + b_scales_tensors.append(scale_b) + + globals = { + # Gen params + "a_tensors": a_tensors, + "b_tensors": b_tensors, + "a_g_tensors": a_g_tensors, + "b_g_tensors": b_g_tensors, + "out_g_tensors": out_g_tensors, + "out_tensors": out_tensors, + "baseline_tensors": baseline_tensors, + "a_scales_tensors": a_scales_tensors, + "b_scales_tensors": b_scales_tensors, + "num_groups": num_groups, + # Kernels + "grouped_gemm": grouped_gemm, + "baseline_gemm": baseline_gemm, + } + + min_run_time = 1 + num_warmup = 5 + + # Warmup pytorch + for _ in range(num_warmup): + grouped_gemm(a_g_tensors, b_g_tensors, out_g_tensors, a_scales_tensors, + b_scales_tensors) + + results.append( + benchmark.Timer( + stmt="grouped_gemm(a_g_tensors, b_g_tensors, out_g_tensors, a_scales_tensors, b_scales_tensors)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup pytorch + for _ in range(num_warmup): + baseline_gemm(num_groups, a_tensors, b_tensors, out_tensors) + + results.append( + benchmark.Timer( + stmt= + "output = baseline_gemm(num_groups, a_tensors, b_tensors, out_tensors)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="baseline_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: List[benchmark.Measurement] = [] + + for model in args.models: + for layer in WEIGHT_SHAPES_MOE[model]: + num_groups = layer[0] + size_k = layer[1] + size_n = layer[2] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in DEFAULT_BATCH_SIZES: + mkn = [(size_m, size_k, size_n)] * num_groups + bench_run(results, model, num_groups, per_act_token, + per_out_ch, mkn) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 ... 
+# +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark grouped CUTLASS GEMM across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 4eeeca35a37cc..9550236aa671f 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -73,3 +73,24 @@ [7168, 8192], ], } + +WEIGHT_SHAPES_MOE = { + "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ + [8, 4096, 28672], + [8, 14336, 4096], + ], + "nm-testing/deepseekv2-lite": [ + [64, 2048, 352], + [64, 1408, 256], + [64, 128, 5632], + [64, 88, 4096], + ], + "ibm-granite/granite-3.0-1b-a400m": [ + [32, 1024, 2048], + [32, 1024, 1024], + ], + "ibm-granite/granite-3.0-3b-a800m": [ + [40, 1536, 2048], + [40, 1024, 1536], + ], +} diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index afb350591e562..7703ec0d966ee 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -490,6 +490,13 @@ def cutlass_scaled_mm(a: torch.Tensor, return out +def cutlass_grouped_mm(out: List[torch.Tensor], a: List[torch.Tensor], + b: List[torch.Tensor], scale_a: List[torch.Tensor], + scale_b: List[torch.Tensor]) -> torch.Tensor: + torch.ops._C.cutlass_grouped_mm(out, a, b, scale_a, scale_b) + return out + + def cutlass_scaled_mm_azp(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, From f1a56669f59cc1bf138523c1247dde9db7f28e56 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 17 Jan 2025 16:27:58 +0000 Subject: [PATCH 8/9] Start working on fused moe cutlass implementation Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 196 +++++++++++++++++- csrc/cpu/torch_bindings.cpp | 8 + .../epilogue/broadcast_load_epilogue_c3x.hpp | 6 +- csrc/ops.h | 6 + .../cutlass_w8a8/grouped_mm_c3x.cu | 47 +++++ .../cutlass_w8a8/scaled_mm_entry.cu | 16 ++ csrc/torch_bindings.cpp | 7 + tests/kernels/test_cutlass_moe.py | 145 +++++++++++++ 8 files changed, 422 insertions(+), 9 deletions(-) create mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index be401cec03c66..67923262a5855 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.utils import FlexibleArgumentParser +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_moe, fused_topk, fused_experts) DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1"] # "nm-testing/deepseekv2-lite", # "ibm-granite/granite-3.0-1b-a400m", # "ibm-granite/granite-3.0-3b-a800m"] @@ -17,6 +19,11 @@ PER_ACT_TOKEN_OPTS = [False, True] PER_OUT_CH_OPTS = [False, True] +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + def grouped_gemm(a_g_tensors: 
List[torch.Tensor], b_g_tensors: List[torch.Tensor], out_g_tensors: List[torch.Tensor], @@ -33,6 +40,30 @@ def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], b = b_tensors[g] out = torch.mm(a, b) out_tensors[g] = out + +def cutlass_fused(a_tensors: List[torch.Tensor], + w1_tensors: List[torch.Tensor], + w2_tensors: List[torch.Tensor], + c1_tensors: List[torch.Tensor], + c2_tensors: List[torch.Tensor], + c2_tensors_fp8: List[torch.Tensor], + c3_tensors: List[torch.Tensor], + a_scales: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + c2_scales: List[torch.Tensor], + num_groups: int): + # output_dtype = c3_tensors[0].dtype + N = c2_tensors[0].shape[1] + ops.cutlass_grouped_mm(c1_tensors, a_tensors, w1_tensors, + a_scales, w1_scales) + # TODO make this work as it should + for idx in range(num_groups): + torch.ops._C.silu_and_mul(c2_tensors[idx], c1_tensors[idx].view(-1, N)) + print(c2_tensors[idx]) + c2_tensors_fp8[idx] = to_fp8(c2_tensors[idx].half()) + ops.cutlass_grouped_mm(c3_tensors, c2_tensors, w2_tensors, + c2_scales, w2_scales) def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, per_act_token: bool, per_out_ch: bool, @@ -47,11 +78,6 @@ def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, device = "cuda" out_dtype = torch.half - - def to_fp8(tensor: torch.Tensor): - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) a_tensors = [] b_tensors = [] @@ -142,6 +168,164 @@ def to_fp8(tensor: torch.Tensor): description="baseline_gemm", ).blocked_autorange(min_run_time=min_run_time)) +def bench_run_moe(results: List[benchmark.Measurement], model: str, num_groups: int, + per_act_token: bool, per_out_ch: bool, + mkn: List[Tuple[int, int, int]]): + label = "Quant Matmul" + + sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_groups, per_act_token, + per_out_ch, mkn)) + + print(f"Testing: {sub_label}") + + device = "cuda" + out_dtype = torch.bfloat16 + + def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + m_tot = sum([elem[0] for elem in mkn]) + k_g = mkn[0][1] + n_g = mkn[0][2] + + a_tensors = [] + w1_tensors = [] + w2_tensors = [] + c1_tensors = [] + c2_tensors = [] + c2_tensors_fp8 = [] + c3_tensors = [] + a_scales = [] + w1_scales = [] + w2_scales = [] + c2_scales = [] + + a = torch.randn((m_tot, k_g), device=device, dtype=out_dtype) + w1 = torch.randn((num_groups, 2 * n_g, k_g), device=device, dtype=out_dtype) + w2 = torch.randn((num_groups, k_g, n_g), device=device, dtype=out_dtype) + scored_output = torch.randn((m_tot, num_groups), device="cuda", dtype=out_dtype) + topk = 2 + # triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + + #TODO grouped topk for deepseek + topk_weights, topk_ids = fused_topk(a, scored_output, topk, renormalize=True) + fused_experts(a, w1, w2, topk_weights, topk_ids) + topk_ids_cpu = topk_ids.cpu() + + occurrences = [0] * num_groups + expert_offsets = [0] * (num_groups + 1) + for id in topk_ids_cpu.flatten(): + occurrences[id] += 1 + + for e in range(num_groups): + expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] + + print(expert_offsets, m_tot) + + a = torch.randn((m_tot, k_g)) + a_group[0] = a[sorted_token_ids[0]] + + # TODO + # create full input tensor m_tot x k_g x topk 
+ # get shuffle data like sorted_token_ids etc. + # create view + + for g in range(num_groups): + m_g = occurrences[g] + a_g = to_fp8(torch.randn((m_g, k_g), device=device)) + w1_g = to_fp8(torch.randn((2 * n_g, k_g), device=device).t()) + w2_g = to_fp8(torch.randn((k_g, n_g), device=device).t()) + c1_g = torch.zeros((m_g, 2 * n_g), device=device, dtype=torch.bfloat16) + c2_g = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16) + c2_g_fp8 = to_fp8(torch.zeros((m_g, n_g), device=device)) + c3_g = torch.zeros((m_g, k_g), device=device, dtype=torch.bfloat16) + # m_a_scales = m_g if per_act_token else 1 + # n_b_scales = n_g if per_out_ch else 1 + m_scales = 1 + n2_scales = 1 + k_scales = 1 + scale_a = (torch.randn((m_scales, 1), device=device, + dtype=torch.float32)) + scale_w1 = (torch.randn((n2_scales, 1), device=device, + dtype=torch.float32)) + scale_w2 = (torch.randn((k_scales, 1), device=device, + dtype=torch.float32)) + scale_c2 = (torch.randn((m_scales, 1), device=device, + dtype=torch.float32)) + + a_tensors.append(a_g) + w1_tensors.append(w1_g) + w2_tensors.append(w2_g) + c1_tensors.append(c1_g) + c2_tensors.append(c2_g) + c2_tensors_fp8.append(c2_g_fp8) + c3_tensors.append(c3_g) + a_scales.append(scale_a) + w1_scales.append(scale_w1) + w2_scales.append(scale_w2) + c2_scales.append(scale_c2) + + globals = { + # Gen params + "num_groups": num_groups, + # Grouped gemm params + "a_tensors": a_tensors, + "w1_tensors": w1_tensors, + "w2_tensors": w2_tensors, + "c1_tensors": c1_tensors, + "c2_tensors": c2_tensors, + "c2_tensors_fp8": c2_tensors_fp8, + "c3_tensors": c3_tensors, + "a_scales": a_scales, + "w1_scales": w1_scales, + "w2_scales": w2_scales, + "c2_scales": c2_scales, + # Triton params (fused_moe) + "a": a, + "w1": w1, + "w2": w2, + "scored_output": scored_output, + "topk": topk, + # Kernels + "fused_moe": fused_moe, + "cutlass_fused": cutlass_fused, + } + + min_run_time = 1 + num_warmup = 5 + + # Warmup triton + for _ in range(num_warmup): + fused_moe(a, w1, w2, scored_output, topk, renormalize=False) + + results.append( + benchmark.Timer( + stmt="fused_moe(a, w1, w2, scored_output, topk, renormalize=False)", + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup cutlass + for _ in range(num_warmup): + cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, + c2_tensors_fp8, c3_tensors, a_scales, w1_scales, + w2_scales, c2_scales, num_groups) + + results.append( + benchmark.Timer( + stmt= + "cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, c2_tensors_fp8, c3_tensors, a_scales, w1_scales, w2_scales, c2_scales, num_groups)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="baseline_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): @@ -165,7 +349,7 @@ def main(args): for per_out_ch in PER_OUT_CH_OPTS: for size_m in DEFAULT_BATCH_SIZES: mkn = [(size_m, size_k, size_n)] * num_groups - bench_run(results, model, num_groups, per_act_token, + bench_run_moe(results, model, num_groups, per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index a720348dee3e2..96ddcab7cea26 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -125,6 +125,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor 
out_offsets, Tensor a_offsets, " " Tensor b_offsets) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + + ops.def( + "compute_expert_offsets(Tensor! trg_a_ptrs," + " Tensor! a, Tensor topk_ids," + " Tensor! expert_offsets, SymInt num_experts) -> ()"); + ops.impl("compute_expert_offsets", torch::kCUDA, + &compute_expert_offsets); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. ops.def( diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 9f049efd07b46..ad33eec9ef8fe 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -424,9 +424,9 @@ struct Sm90ColOrScalarBroadcast { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - if (threadIdx.x ==128){ - printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - } + // if (threadIdx.x ==128){ + // printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + // } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/ops.h b/csrc/ops.h index 736a40091f032..d7ec0e0f91283 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -161,6 +161,12 @@ void cutlass_grouped_mm(c10::List const& out_tensors, c10::List const& a_scales, c10::List const& b_scales); +void compute_expert_offsets(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts); + void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index caa7edb888a36..835a144aed4b1 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -321,3 +321,50 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales); } + +__global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, + cutlass::float_e4m3_t* base_a_ptr, + const int* __restrict__ topk_ids, + int64_t* expert_offsets, + int topk_length) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int64_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + trg_a_ptrs[i] = base_a_ptr + tot_offset; + tot_offset += expert_offsets[i + 1]; + expert_offsets[i + 1] = tot_offset; + } + } +} + +// For a given "a" of size [M,K] performs a permutation of the M rows based +// on the given "perm" indices. 
+__global__ void permute_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, + int const* __restrict__ perm_int_ptr, + cutlass::float_e4m3_t* __restrict__ out_ptr, + int size_m, int size_k, int block_rows) { + // TODO +} + +void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts) { + get_a_expert_offsets<<<1, num_experts>>>((float_e4m3_t**)trg_a_ptrs.data_ptr(), + (cutlass::float_e4m3_t*)a.data_ptr(), + (const int*)topk_ids.data_ptr(), + (int64_t*)expert_offsets.data_ptr(), + topk_ids.numel()); +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index e60b64d7797bf..d9d2a91d0659f 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -36,6 +36,13 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, c10::List const& a_scales, c10::List const& b_scales); + +void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts); + #endif void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -159,6 +166,15 @@ void cutlass_grouped_mm(c10::List const& out_tensors, b_scales); } +void compute_expert_offsets(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts) { + compute_expert_offsets_caller(trg_a_ptrs, a, topk_ids, expert_offsets, + num_experts); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b862144aa16f5..65d48c7f14659 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -329,6 +329,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor[] b_scales) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + ops.def( + "compute_expert_offsets(Tensor! trg_a_ptrs," + " Tensor! a, Tensor topk_ids," + " Tensor! 
expert_offsets, SymInt num_experts) -> ()"); + ops.impl("compute_expert_offsets", torch::kCUDA, + &compute_expert_offsets); + // Check if cutlass sparse scaled_mm is supported for CUDA devices of the // given capability ops.def( diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py new file mode 100644 index 0000000000000..083de75e1d34f --- /dev/null +++ b/tests/kernels/test_cutlass_moe.py @@ -0,0 +1,145 @@ +import pytest +import torch +from transformers import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from typing import List + +import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) +from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +NUM_EXPERTS = [8, 64] +TOP_KS = [2, 6] + +# TODO move to a better file later +# TODO handle scores +def cutlass_moe(a: torch.Tensor, + a_q: torch.Tensor, + a_scale: torch.Tensor, + w1_qs: List[torch.Tensor], + w2_qs: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +): + # TODO look at the code in benchmark_grouped_gemm_cutlass.py + # and get the relevant parts + # (also the fused_moe function) + + num_groups = len(w1_qs) + topk = topk_ids.shape[1] + num_tokens = topk_ids.shape[0] + + # TODO make this GPU only + occurrences = [0] * num_groups + expert_offsets = [0] * (num_groups + 1) + for id in topk_ids.cpu().flatten(): + occurrences[id] += 1 + for e in range(num_groups): + expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] + + # TODO duplicate A rows topk times + # compute sorted_token_ids (argsort?) + # shuffle A according to this so each group input is contiguous + + # print(topk_ids) + # print(expert_offsets) + a_map = topk_ids.flatten().argsort() + rep_a_q = a_q.repeat_interleave(topk, dim=0) + + print(a_map) + print(rep_a_q) + + a_q_s = [] + for e in range(num_groups): + a_q_s.append(rep_a_q[a_map[expert_offsets[e]:expert_offsets[e+1]]]) + print(a_q_s) + return + # get a_map and expert_indices on device + + # TODO shuffle rep_a_q according to a_map + # get a_ptrs = a + expert_indices[:-1] + + a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") + expert_offsets = torch.empty((num_groups + 1), dtype=torch.int64, device="cuda") + # TODO might need to call it from inside cutlass code? 
+ # help(ops) + + # print(a_ptrs) + # print(rep_a_q) + print(topk_ids) + # print(expert_offsets) + # print(num_groups) + torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), + expert_offsets, num_groups) + print(a_ptrs) + print(expert_offsets) + +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) +# @pytest.mark.parametrize("n", [128, 2048]) +# @pytest.mark.parametrize("k", [128, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("m", [10]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +def test_cutlass_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, +): + current_platform.seed_everything(7) + + dtype = torch.bfloat16 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + a_q, a_scale = ops.scaled_fp8_quant(a) + + w1_qs = [] + w2_qs = [] + w1_scales = [] + w2_scales = [] + + for expert in range(e): + w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) + w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) + w1_qs.append(w1_q) + w2_qs.append(w2_q) + w1_scales.append(w1_scale) + w2_scales.append(w2_scale) + + # (assume score is a vector of ones for now) + score = torch.ones((m, e), device="cuda", dtype=dtype) + + e_range = torch.full((m, e), 1.0 / e) + topk_ids = torch.multinomial(e_range, topk).int().sort()[0] + topk_weights = torch.rand((m, topk)) + + torch_output = torch_moe(a, w1, w2, score, topk) + cutlass_output = cutlass_moe(a, a_q, a_scale, w1_qs, w2_qs, w1_scales, + w2_scales, topk_weights, topk_ids) + + # torch.testing.assert_close(torch_output, + # cutlass_output, + # atol=2e-2, + # rtol=0) From 6414e317bae3fb36a0871f5c68c6db517e974008 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 20 Jan 2025 23:49:35 +0000 Subject: [PATCH 9/9] Working halfway Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_mm_c3x.cu | 73 +++++++++++++++---- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 835a144aed4b1..4abb84e3e0bbd 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -163,6 +163,8 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, b_scales_ptrs_host[g] = reinterpret_cast(b_scales[g].data_ptr()); + // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g], + // c_ptrs_host[g], d_ptrs_host[g],) int64_t m = a_tensors[g].size(0); int64_t k = a_tensors[g].size(1); @@ -348,23 +350,68 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, } } -// For a given "a" of size [M,K] performs a permutation of the M rows based -// on the given "perm" indices. -__global__ void permute_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, - int const* __restrict__ perm_int_ptr, - cutlass::float_e4m3_t* __restrict__ out_ptr, - int size_m, int size_k, int block_rows) { - // TODO -} +// // For a given "a" of size [M,K] performs a permutation of the M rows based +// // on the given "perm" indices. 
+// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, +// int const* __restrict__ perm_int_ptr, +// cutlass::float_e4m3_t* __restrict__ out_ptr, +// int size_m, int size_k, int block_rows) { +// int start_row = block_rows * blockIdx.x; +// int finish_row = start_row + block_rows; +// if (finish_row > size_m) { +// finish_row = size_m; +// } +// int cur_block_rows = finish_row - start_row; + +// int row_stride = size_k * sizeof(cutlass::float_e4m3_t) / 16; + +// auto permute_row = [&](int row) { +// int iters = size_k / blockDim.x; +// int rest = size_k % blockDim.x; + +// int a_offset = perm_int_ptr[row] * row_stride; +// int out_offset = row * row_stride; + +// cutlass::float_e4m3_t const* a_row_fp8 = a_ptr + a_offset; +// cutlass::float_e4m3_t* out_fp8 = out_ptr + out_offset; + +// int base_k = 0; + +// for (int i = 0; i < iters; i++) { +// int cur_k = base_k + threadIdx.x; +// out_fp8[cur_k] = a_row_fp8[cur_k]; +// base_k += blockDim.x; +// } + +// if (rest) { +// if (threadIdx.x < rest) { +// int cur_k = base_k + threadIdx.x; +// out_fp8[cur_k] = a_row_fp8[cur_k]; +// } +// } +// }; +// } void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, torch::Tensor& a, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const int64_t num_experts) { - get_a_expert_offsets<<<1, num_experts>>>((float_e4m3_t**)trg_a_ptrs.data_ptr(), - (cutlass::float_e4m3_t*)a.data_ptr(), - (const int*)topk_ids.data_ptr(), - (int64_t*)expert_offsets.data_ptr(), - topk_ids.numel()); + get_a_expert_offsets<<<1, num_experts>>>( + (cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(), + (cutlass::float_e4m3_t*)a.data_ptr(), + (const int*)topk_ids.data_ptr(), + (int64_t*)expert_offsets.data_ptr(), + topk_ids.numel()); } + +// void permute_fp8_rows(torch::Tensor& a_ptr, +// torch::Tensor& perm_ptr, +// torch::Tensor& out_ptr, +// int size_m, int size_k, int topk, int block_rows) { +// permute_fp8_rows_kernel<<>>( +// (cutlass::float_e4m3_t const*)a_ptr.data_ptr(), +// (const int*)perm_ptr.data_ptr(), +// (cutlass::float_e4m3_t const*)out_ptr.data_ptr(), size_m * topk, +// size_k, block_rows); +// }
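Editor's note on the expert-routing bookkeeping in PATCH 8/9: the host-side logic that get_a_expert_offsets implements, plus the row shuffle that the commented-out permute_fp8_rows_kernel is meant to replace, can be summarized in a few lines of PyTorch. The sketch below is illustrative only; the helper name expert_offsets_and_permutation and the toy shapes are not part of the patch. It mirrors the CPU-side code in tests/kernels/test_cutlass_moe.py: count how many rows of the topk-duplicated activation matrix are routed to each expert, turn the counts into exclusive offsets, and use an argsort of the flattened topk_ids to make each expert's rows contiguous.

import torch

def expert_offsets_and_permutation(topk_ids: torch.Tensor, num_experts: int):
    # Rows routed to each expert in the topk-duplicated activation matrix.
    counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    # Exclusive prefix sum: expert e owns rows offsets[e]:offsets[e + 1].
    offsets = torch.zeros(num_experts + 1, dtype=torch.int64)
    offsets[1:] = torch.cumsum(counts, dim=0)
    # Sorting the flattened expert ids groups each expert's rows together.
    a_map = topk_ids.flatten().argsort()
    return offsets, a_map

# Toy usage: 3 tokens, topk=2, 3 experts.
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])
a_q = torch.randn((3, 8))  # stands in for the quantized activations
rep_a_q = a_q.repeat_interleave(topk_ids.shape[1], dim=0)
offsets, a_map = expert_offsets_and_permutation(topk_ids, num_experts=3)
off = offsets.tolist()
a_groups = [rep_a_q[a_map[off[e]:off[e + 1]]] for e in range(3)]

On the device, get_a_expert_offsets does the counting with one thread per expert and a single-threaded prefix sum that also fills the per-expert a pointers; the row shuffle itself is still open (permute_fp8_rows_kernel above is commented out), so for now the test builds the per-expert inputs on the host exactly as in this sketch.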
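For completeness, PATCH 7 also exposes the grouped GEMM through a thin Python wrapper, ops.cutlass_grouped_mm in vllm/_custom_ops.py. A minimal call sketch is below, assuming a Hopper (SM90) build of the extension; the per-group shapes are made up, the scales shown are per-tensor, and the to_fp8 helper mirrors the one in the benchmark, so nothing in this snippet should be read as part of the patch itself.

import torch
from vllm import _custom_ops as ops

def to_fp8(t: torch.Tensor) -> torch.Tensor:
    # Clamp to the fp8 e4m3 range before casting, as the benchmark does.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return t.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)

num_groups = 4
out_dtype = torch.half
a_list, b_list, out_list, a_scales, b_scales = [], [], [], [], []
for _ in range(num_groups):
    m, n, k = 64, 128, 256  # per-group sizes may differ; keep them 16-aligned
    a_list.append(to_fp8(torch.randn((m, k), device="cuda")))
    # b is laid out column-major, as in the tests: create (n, k) and transpose.
    b_list.append(to_fp8(torch.randn((n, k), device="cuda").t()))
    out_list.append(torch.zeros((m, n), device="cuda", dtype=out_dtype))
    a_scales.append(torch.rand((1, 1), device="cuda", dtype=torch.float32))
    b_scales.append(torch.rand((1, 1), device="cuda", dtype=torch.float32))

# Each group g computes out_list[g] = (a_list[g] @ b_list[g]) * a_scales[g] * b_scales[g].
ops.cutlass_grouped_mm(out_list, a_list, b_list, a_scales, b_scales)

The op writes into the preallocated out_list tensors; per-row activation scales of shape (m, 1) and per-channel weight scales of shape (1, n) correspond to the per_act_token / per_out_ch cases exercised in tests/kernels/test_cutlass.py.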