diff --git a/CMakeLists.txt b/CMakeLists.txt
index 51b49a18dddf2..83c8033434f3b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -273,15 +273,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                      " in CUDA target architectures")
   endif()

-  #
-  # The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
-  # For Hopper (c3x, i.e. CUTLASS 3.x) require
+  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
-             "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
-             "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -290,12 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
   else()
     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-      message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
+      message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
-                     "later if you intend on running FP8 sparse or quantized models on "
+                     "later if you intend on running FP8 quantized models on "
                      "Hopper.")
     else()
-      message(STATUS "Not building cutlass_c3x as no compatible archs found "
+      message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
                      "in CUDA target architectures")
     endif()

@@ -329,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

+  #
+  # 2:4 Sparse Kernels
+
+  # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
+  # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+    set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
+             "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
+    message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+      message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
+                     "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
+                     "if you intend on running FP8 sparse quantized models on Hopper.")
+    else()
+      message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
+  endif()
+
   #
   # Machete kernels
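The `-DENABLE_SPARSE_SCALED_MM_C3X=1` define added above is what lets the C++ side degrade gracefully when the sparse sources are skipped. As a rough illustration (not vLLM's actual implementation; the function name mirrors the `cutlass_sparse_scaled_mm_supported` declaration that appears in `csrc/ops.h` further down), a capability probe guarded by that macro might look like:

```cpp
#include <cstdint>

// Illustrative sketch only: consuming the -DENABLE_SPARSE_SCALED_MM_C3X=1
// flag that CMake sets when the sparse .cu files are compiled. If CMake
// skipped them (CUDA < 12.2), the macro is absent and the probe reports the
// kernels as unavailable instead of failing at link or run time.
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
#if defined(ENABLE_SPARSE_SCALED_MM_C3X) && ENABLE_SPARSE_SCALED_MM_C3X
  // Sparse kernels were built; per the comments above they target Hopper
  // (compute capability 9.0 / 9.0a) only.
  return cuda_device_capability == 90;
#else
  return false;  // Built without CUDA >= 12.2: kernels not in this binary.
#endif
}
```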
""" prompt: str prompt_len: int expected_output_len: int multi_modal_data: Optional[MultiModalDataDict] = None + lora_request: Optional[LoRARequest] = None def _get_prompt_for_image_model(question: str, *, model: str) -> str: @@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str: raise ValueError(f"Unsupported model {model}") +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +lora_tokenizer_cache: Dict[int, AnyTokenizer] = {} + + +def get_random_lora_request( + args: argparse.Namespace +) -> Tuple[LoRARequest, Optional[AnyTokenizer]]: + global lora_tokenizer_cache + lora_id = random.randint(1, args.max_loras) + lora_request = LoRARequest(lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(args.lora_path)) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + return lora_request, lora_tokenizer_cache[lora_id] + + def sample_requests(tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace) -> List[SampleRequest]: + dataset_path: str = args.dataset num_requests: int = args.num_prompts fixed_output_len: Optional[int] = args.output_len @@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, # Filter out sequences that are too long or too short filtered_dataset: List[SampleRequest] = [] - for data in dataset: + for data in tqdm(dataset, + total=len(filtered_dataset), + desc="sampling requests"): if len(filtered_dataset) == num_requests: break @@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, continue prompt = _get_prompt_for_image_model(question=prompt, model=model) + request_tokenizer = tokenizer + lora_request: Optional[LoRARequest] = None + if args.enable_lora: + lora_request, lora_tokenizer = get_random_lora_request(args) + if lora_tokenizer: + request_tokenizer = lora_tokenizer + # Tokenize the prompts and completions. - prompt_token_ids = tokenizer(prompt).input_ids - completion_token_ids = tokenizer(completion).input_ids + prompt_token_ids = request_tokenizer(prompt).input_ids + completion_token_ids = request_tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) output_len = len(completion_token_ids ) if fixed_output_len is None else fixed_output_len @@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, SampleRequest(prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - multi_modal_data=multi_modal_data)) + multi_modal_data=multi_modal_data, + lora_request=lora_request)) return filtered_dataset @@ -146,14 +184,21 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, )) + lora_requests: Optional[List[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] use_beam_search = False if not use_beam_search: start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) + llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) end = time.perf_counter() else: + assert lora_requests is None, "BeamSearch API does not support LoRA" prompts = [request.prompt for request in requests] # output_len should be the same for all requests. output_len = requests[0][2] @@ -185,6 +230,7 @@ async def run_vllm_async( # Add the requests to the engine. 
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
+    lora_requests: List[Optional[LoRARequest]] = []
     for request in requests:
         prompts.append(
             TextPrompt(prompt=request.prompt,
@@ -197,11 +243,16 @@ async def run_vllm_async(
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
             ))
+        lora_requests.append(request.lora_request)

     generators = []
     start = time.perf_counter()
-    for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
-        generator = llm.generate(prompt, sp, request_id=f"test{i}")
+    for i, (prompt, sp,
+            lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
+        generator = llm.generate(prompt,
+                                 sp,
+                                 lora_request=lr,
+                                 request_id=f"test{i}")
         generators.append(generator)
     all_gens = merge_async_iterators(*generators)
     async for i, res in all_gens:
@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
         vocab_size = tokenizer.vocab_size
         requests = []
         for _ in range(args.num_prompts):
+
+            request_tokenizer = tokenizer
+            lora_request: Optional[LoRARequest] = None
+            if args.enable_lora:
+                lora_request, lora_tokenizer = get_random_lora_request(args)
+                if lora_tokenizer:
+                    request_tokenizer = lora_tokenizer
+
             # Synthesize a prompt with the given input length.
             candidate_ids = [
                 random.randint(0, vocab_size - 1)
@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
             # As tokenizer may add additional tokens like BOS, we need to try
             # different lengths to get the desired input length.
             for _ in range(5):  # Max attempts to correct
-                candidate_prompt = tokenizer.decode(candidate_ids)
-                tokenized_len = len(tokenizer.encode(candidate_prompt))
+                candidate_prompt = request_tokenizer.decode(candidate_ids)
+                tokenized_len = len(request_tokenizer.encode(candidate_prompt))

                 if tokenized_len == args.input_len:
                     break
@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
-                              expected_output_len=args.output_len))
+                              expected_output_len=args.output_len,
+                              lora_request=lora_request))
     else:
         requests = sample_requests(tokenizer, args)

@@ -422,6 +482,14 @@ def main(args: argparse.Namespace):
         action='store_true',
         default=False,
         help="Disable decoupled async engine frontend.")
+    # LoRA
+    parser.add_argument(
+        "--lora-path",
+        type=str,
+        default=None,
+        help="Path to the lora adapters to use. This can be an absolute path, "
+        "a relative path, or a Hugging Face model identifier.")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     if args.tokenizer is None:
@@ -431,6 +499,8 @@ def main(args: argparse.Namespace):
         assert args.output_len is not None
     else:
         assert args.input_len is None
+    if args.enable_lora:
+        assert args.lora_path is not None

     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
@@ -440,6 +510,9 @@ def main(args: argparse.Namespace):
             raise ValueError("HF max batch size is required for HF backend.")
         if args.quantization is not None:
             raise ValueError("Quantization is only for vLLM backend.")
+        if args.enable_lora:
+            raise ValueError("LoRA benchmarking is only supported for vLLM"
+                             " backend")
     elif args.backend == "mii":
         if args.dtype != "auto":
             raise ValueError("dtype must be auto for MII backend.")
@@ -452,4 +525,7 @@ def main(args: argparse.Namespace):
         if args.tokenizer != args.model:
             raise ValueError("Tokenizer must be the same as the model for MII "
                              "backend.")
+        if args.enable_lora:
+            raise ValueError("LoRA benchmarking is only supported for vLLM"
+                             " backend")
     main(args)
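The next two hunks add `#pragma once` to the shared epilogue headers. That becomes necessary once the sparse kernels move into their own translation units, since these type-defining headers then have more than one includer; pulling such a header into the same translation unit twice triggers redefinition errors. A minimal sketch of the equivalent classic include guard (names are illustrative):

```cpp
// Illustrative only: the portable spelling of what "#pragma once" does for
// these headers. A second inclusion in the same translation unit becomes a
// no-op instead of redefining the types.
#ifndef SCALED_MM_EPILOGUES_EXAMPLE_HPP
#define SCALED_MM_EPILOGUES_EXAMPLE_HPP

struct ExampleEpilogueTag {};  // defined exactly once per translation unit

#endif  // SCALED_MM_EPILOGUES_EXAMPLE_HPP
```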
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
index c69e87999ae71..26f7423fd7455 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

 /*
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
index fcc17c7727f94..c723adf126422 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

 /*
diff --git a/csrc/ops.h b/csrc/ops.h
index c145e4eda0845..347c502845d8f 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -163,6 +163,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            c10::optional<torch::Tensor> const& azp,
                            c10::optional<torch::Tensor> const& bias);

+bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
+
 void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b, torch::Tensor const& e,
                               torch::Tensor const& a_scales,
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index 75681f7f37820..f2fae4b66d651 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -28,9 +28,9 @@ using namespace cute;

 /*
-   Epilogue functions can be defined to post-process the output before it is
-   written to GPU memory.
-   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+   Epilogues defined in
+   csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+   must contain a public type named EVTCompute of type Sm80EVT,
    as well as a static prepare_args function that constructs an
    EVTCompute::Arguments struct.
 */
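The comment rewritten above describes a compile-time contract between the epilogue structs and the GEMM callers. A stripped-down sketch of that shape, with the CUTLASS machinery removed (all names here are illustrative; the real epilogues build `EVTCompute` out of `Sm80EVT`/`Sm90EVT` visitor trees):

```cpp
#include <utility>

// Stand-in for an Sm80EVT/Sm90EVT instantiation; the only property the
// contract relies on is the nested Arguments struct.
struct ToyEVT {
  struct Arguments {
    float scale;
  };
};

// An epilogue per the contract: a public EVTCompute type plus a static
// prepare_args() that builds an EVTCompute::Arguments.
struct ToyScaledEpilogue {
  using EVTCompute = ToyEVT;
  static EVTCompute::Arguments prepare_args(float scale) { return {scale}; }
};

// Mirrors how the GEMM callers stay generic over epilogues: they only ever
// touch Epilogue::EVTCompute and Epilogue::prepare_args.
template <typename Epilogue, typename... Args>
typename Epilogue::EVTCompute::Arguments make_epilogue_args(Args&&... args) {
  return Epilogue::prepare_args(std::forward<Args>(args)...);
}
```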
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index 8190277997161..123f4359c0d1a 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -1,385 +1,18 @@
-// clang-format will break include orders
-// clang-format off
 #include <cudaTypedefs.h>

 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
-#include <torch/all.h>
+  #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
+  #include "scaled_mm_c3x_sm90_int8_dispatch.cuh"

-#include <ATen/cuda/CUDAContext.h>
-
-#include <iostream>
-#include <sstream>
-#include <vector>
-
-#include "cutlass/cutlass.h"
-
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "core/math.hpp"
-#include "cutlass_extensions/common.hpp"
-// clang-format on
-
-using namespace cute;
+  #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

 using namespace vllm;

 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
    NVIDIA GPUs with sm90a (Hopper) or later.
-
-   Epilogue functions can be defined to post-process the output before it is
-   written to GPU memory.
-   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
-   as well as a static prepare_args function that constructs an
-   EVTCompute::Arguments struct.
 */

-namespace {
-
-// A wrapper for the GEMM kernel that is used to guard against compilation on
-// architectures that will never use the kernel. The purpose of this is to
-// reduce the size of the compiled binary.
-// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
-// into code that will be executed on the device where it is defined.
-template <typename Kernel>
-struct enable_sm90_or_later : Kernel {
-  template <typename... Args>
-  CUTLASS_DEVICE void operator()(Args&&... args) {
-  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
-    Kernel::operator()(std::forward<Args>(args)...);
-  #endif
-  }
-};
-
-template <typename ElementAB_, typename ElementD_,
-          template <typename, typename, typename> typename Epilogue_,
-          typename TileShape, typename ClusterShape, typename KernelSchedule,
-          typename EpilogueSchedule>
-struct cutlass_3x_gemm {
-  using ElementAB = ElementAB_;
-  using ElementD = ElementD_;
-  using ElementAcc =
-      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
-                                float>::type;
-
-  using EpilogueDescriptor =
-      cutlass::epilogue::collective::detail::EpilogueDescriptor<
-          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
-          ElementD, EpilogueSchedule>;
-
-  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
-
-  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
-  using ElementC = void;
-  using StrideC = StrideD;
-
-  using EVTCompute = typename Epilogue::EVTCompute;
-
-  using CollectiveEpilogue =
-      typename cutlass::epilogue::collective::CollectiveBuilder<
-          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
-          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
-          EpilogueSchedule, EVTCompute>::CollectiveOp;
-
-  static constexpr size_t CEStorageSize =
-      sizeof(typename CollectiveEpilogue::SharedStorage);
-  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
-      static_cast<int>(CEStorageSize)>;
-
-  // clang-format off
-  using CollectiveMainloop =
-      typename cutlass::gemm::collective::CollectiveBuilder<
-          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
-          ElementAB, cutlass::layout::RowMajor, 16,
-          ElementAB, cutlass::layout::ColumnMajor, 16,
-          ElementAcc, TileShape, ClusterShape,
-          Stages,
-          KernelSchedule>::CollectiveOp;
-  // clang-format on
-
-  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
-      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
-
-  struct GemmKernel : public KernelType {};
-};
-
-template <typename Gemm, typename... EpilogueArgs>
-void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
-                         torch::Tensor const& b,
-                         EpilogueArgs&&... epilogue_params) {
-  using ElementAB = typename Gemm::ElementAB;
-  using ElementD = typename Gemm::ElementD;
-
-  int32_t m = a.size(0);
-  int32_t n = b.size(1);
-  int32_t k = a.size(1);
-
-  int64_t lda = a.stride(0);
-  int64_t ldb = b.stride(1);
-  int64_t ldc = out.stride(0);
-
-  using StrideA = Stride<int64_t, Int<1>, int64_t>;
-  using StrideB = Stride<int64_t, Int<1>, int64_t>;
-  using StrideC = typename Gemm::StrideC;
-
-  StrideA a_stride{lda, Int<1>{}, 0};
-  StrideB b_stride{ldb, Int<1>{}, 0};
-  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
-
-  using GemmKernel = typename Gemm::GemmKernel;
-  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
-
-  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
-  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
-  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
-                                                       b_stride};
-
-  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
-  typename GemmKernel::EpilogueArguments epilogue_args{
-      Gemm::Epilogue::prepare_args(
-          std::forward<EpilogueArgs>(epilogue_params)...),
-      c_ptr, c_stride, c_ptr, c_stride};
-
-  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
-                                      prob_shape, mainloop_args, epilogue_args};
-
-  // Launch the CUTLASS GEMM kernel.
-  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-  GemmOp gemm_op;
-  CUTLASS_CHECK(gemm_op.can_implement(args));
-
-  size_t workspace_size = gemm_op.get_workspace_size(args);
-  auto const workspace_options =
-      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
-  auto workspace = torch::empty(workspace_size, workspace_options);
-
-  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
-
-  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
-  CUTLASS_CHECK(status);
-}
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config_default {
-  // M in (128, inf)
-  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-  using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_2, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config_M128 {
-  // M in (64, 128]
-  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-  using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _128, _128>;
-  using ClusterShape = Shape<_2, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config_M64 {
-  // M in [1, 64]
-  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-  using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _64, _128>;
-  using ClusterShape = Shape<_1, _8, _1>;
-
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_default {
-  // For M > 128 and any N
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule =
-      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_2, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M128 {
-  // For M in (64, 128] and any N
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule =
-      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _128, _128>;
-  using ClusterShape = Shape<_2, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M64 {
-  // For M in (32, 64] and any N
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _64, _256>;
-  using ClusterShape = Shape<_1, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M32_NBig {
-  // For M in [1, 32] and N >= 8192
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _128, _256>;
-  using ClusterShape = Shape<_1, _4, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M32_NSmall {
-  // For M in [1, 32] and N < 8192
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _64, _256>;
-  using ClusterShape = Shape<_1, _8, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-}  // namespace
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                    torch::Tensor const& b,
-                                    EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
-
-  using Cutlass3xGemmDefault =
-      typename sm90_fp8_config_default<InType, OutType,
-                                       Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM64 =
-      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM128 =
-      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
-
-  uint32_t const m = a.size(0);
-  uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
-
-  if (mp2 <= 64) {
-    // m in [1, 64]
-    return cutlass_gemm_caller<Cutlass3xGemmM64>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 128) {
-    // m in (64, 128]
-    return cutlass_gemm_caller<Cutlass3xGemmM128>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else {
-    // m in (128, inf)
-    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  }
-}
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& b,
-                                     EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, int8_t>());
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
-
-  using Cutlass3xGemmDefault =
-      typename sm90_int8_config_default<InType, OutType,
-                                        Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM128 =
-      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM64 =
-      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM32NBig =
-      typename sm90_int8_config_M32_NBig<InType, OutType,
-                                         Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM32NSmall =
-      typename sm90_int8_config_M32_NSmall<InType, OutType,
-                                           Epilogue>::Cutlass3xGemm;
-
-  uint32_t const n = out.size(1);
-  bool const is_small_n = n < 8192;
-
-  uint32_t const m = a.size(0);
-  uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
-
-  if (mp2 <= 32) {
-    // m in [1, 32]
-    if (is_small_n) {
-      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
-          out, a, b, std::forward<EpilogueArgs>(args)...);
-    } else {
-      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
-          out, a, b, std::forward<EpilogueArgs>(args)...);
-    }
-  } else if (mp2 <= 64) {
-    // m in (32, 64]
-    return cutlass_gemm_caller<Cutlass3xGemmM64>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 128) {
-    // m in (64, 128]
-    return cutlass_gemm_caller<Cutlass3xGemmM128>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else {
-    // m in (128, inf)
-    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  }
-}
-
-template