diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py new file mode 100644 index 0000000000000..8350e2705141e --- /dev/null +++ b/.buildkite/generate_index.py @@ -0,0 +1,24 @@ +import argparse +import os + +template = """ + + +

Links for vLLM

+ {wheel}
+ + +""" + +parser = argparse.ArgumentParser() +parser.add_argument("--wheel", help="The wheel path.", required=True) +args = parser.parse_args() + +filename = os.path.basename(args.wheel) + +with open("index.html", "w") as f: + print(f"Generated index.html for {args.wheel}") + # cloudfront requires escaping the '+' character + f.write( + template.format(wheel=filename, + wheel_html_escaped=filename.replace("+", "%2B"))) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 64ba1b32fb074..679abf1814aa5 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -1,5 +1,6 @@ steps: - label: "Wait for container to be ready" + key: wait-for-container-image agents: queue: A100 plugins: @@ -10,12 +11,11 @@ steps: command: - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - - wait - - label: "A100" # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: A100 + depends_on: wait-for-container-image plugins: - kubernetes: podSpec: @@ -49,6 +49,7 @@ steps: # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H200 + depends_on: wait-for-container-image plugins: - docker#v5.12.0: image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT @@ -65,15 +66,15 @@ steps: - VLLM_USAGE_SOURCE - HF_TOKEN - - block: "Run H100 Benchmark" - key: block-h100 - depends_on: ~ + #- block: "Run H100 Benchmark" + #key: block-h100 + #depends_on: ~ - label: "H100" # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H100 - depends_on: block-h100 + depends_on: wait-for-container-image plugins: - docker#v5.12.0: image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 2de6fceb0c3fe..51618a2955fb1 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -55,3 +55,18 @@ steps: password-env: DOCKERHUB_TOKEN env: DOCKER_BUILDKIT: "1" + + - block: "Build CPU release image" + key: block-cpu-release-image-build + depends_on: ~ + + - label: "Build and publish CPU release image" + depends_on: block-cpu-release-image-build + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION" + env: + DOCKER_BUILDKIT: "1" diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh new file mode 100644 index 0000000000000..4fc6d089cc666 --- /dev/null +++ b/.buildkite/run-gh200-test.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# This script build the GH200 docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Skip the new torch installation during build since we are using the specified version for arm64 in the Dockerfile +python3 use_existing_torch.py + +# Try building the docker image +DOCKER_BUILDKIT=1 docker build . \ + --target vllm-openai \ + --platform "linux/arm64" \ + -t gh200-test \ + --build-arg max_jobs=66 \ + --build-arg nvcc_threads=2 \ + --build-arg torch_cuda_arch_list="9.0+PTX" \ + --build-arg vllm_fa_cmake_gpu_arches="90-real" + +# Setup cleanup +remove_docker_container() { docker rm -f gh200-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and test offline inference +docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' + python3 examples/offline_inference.py +' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 97aae233db105..529daf54faecf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -106,14 +106,12 @@ steps: source_file_dependencies: - vllm/ commands: - - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py - - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -201,7 +199,7 @@ steps: - python3 offline_inference_classification.py - python3 offline_inference_embedding.py - python3 offline_inference_scoring.py - - python3 offline_profile.py --model facebook/opt-125m + - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min mirror_hardwares: [amd] @@ -224,8 +222,12 @@ steps: mirror_hardwares: [amd] source_file_dependencies: - vllm/model_executor/layers + - vllm/model_executor/guided_decoding - tests/test_logits_processor - command: pytest -v -s test_logits_processor.py + - tests/model_executor/test_guided_processors + commands: + - pytest -v -s test_logits_processor.py + - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 30min source_file_dependencies: @@ -329,8 +331,6 @@ steps: - vllm/ - tests/models commands: - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_registry.py - pytest -v -s models/test_initialization.py @@ -356,23 +356,25 @@ steps: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' -- label: Multi-Modal Models Test (Standard) # 28min +- label: Multi-Modal Models Test (Standard) # 40min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/decoder_only/audio_language - tests/models/decoder_only/vision_language - tests/models/embedding/vision_language + - tests/models/encoder_decoder/audio_language - tests/models/encoder_decoder/vision_language commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model + - pytest -v -s models/encoder_decoder/audio_language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model -- label: Multi-Modal Models Test (Extended) 1 # 1h16m +- label: Multi-Modal Models Test (Extended) 1 # 48m optional: true source_file_dependencies: - vllm/ @@ -465,11 +467,28 @@ steps: - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s distributed/test_distributed_oot.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py +- label: Plugin Tests (2 GPUs) # 40min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + fast_check: true + source_file_dependencies: + - vllm/plugins/ + - tests/plugins/ + commands: + # begin platform plugin tests, all the code in-between runs on dummy platform + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s plugins_tests/test_platform_plugins.py + - pip uninstall vllm_add_dummy_platform -y + # end platform plugin tests + # other tests continue here: + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" num_gpus: 4 diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh index 7345dd4e66b29..3c756659a715a 100644 --- a/.buildkite/upload-wheels.sh +++ b/.buildkite/upload-wheels.sh @@ -23,6 +23,8 @@ wheel="$new_wheel" version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2) echo "Version: $version" +normal_wheel="$wheel" # Save the original wheel filename + # If the version contains "dev", rename it to v1.0.0.dev for consistency if [[ $version == *dev* ]]; then suffix="${version##*.}" @@ -32,12 +34,38 @@ if [[ $version == *dev* ]]; then new_version="1.0.0.dev" fi new_wheel="${wheel/$version/$new_version}" - mv -- "$wheel" "$new_wheel" + # use cp to keep both files in the artifacts directory + cp -- "$wheel" "$new_wheel" wheel="$new_wheel" version="$new_version" fi # Upload the wheel to S3 +python3 .buildkite/generate_index.py --wheel "$normal_wheel" + +# generate index for this commit aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" +aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" + +if [[ $normal_wheel == *"cu118"* ]]; then + # if $normal_wheel matches cu118, do not upload the index.html + echo "Skipping index files for cu118 wheels" +else + # only upload index.html for cu12 wheels (default wheels) + aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" + aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" +fi + +# generate index for nightly aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" +aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" + +if [[ $normal_wheel == *"cu118"* ]]; then + # if $normal_wheel matches cu118, do not upload the index.html + echo "Skipping index files for cu118 wheels" +else + # only upload index.html for cu12 wheels (default wheels) + aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" +fi + aws s3 cp "$wheel" "s3://vllm-wheels/$version/" \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml similarity index 100% rename from .github/ISSUE_TEMPLATE/400-bug report.yml rename to .github/ISSUE_TEMPLATE/400-bug-report.yml diff --git a/.github/ISSUE_TEMPLATE/500-feature request.yml b/.github/ISSUE_TEMPLATE/500-feature-request.yml similarity index 100% rename from .github/ISSUE_TEMPLATE/500-feature request.yml rename to .github/ISSUE_TEMPLATE/500-feature-request.yml diff --git a/.github/ISSUE_TEMPLATE/600-new model.yml b/.github/ISSUE_TEMPLATE/600-new-model.yml similarity index 94% rename from .github/ISSUE_TEMPLATE/600-new model.yml rename to .github/ISSUE_TEMPLATE/600-new-model.yml index 794617a0cfdf6..713e76c1a5cec 100644 --- a/.github/ISSUE_TEMPLATE/600-new model.yml +++ b/.github/ISSUE_TEMPLATE/600-new-model.yml @@ -9,7 +9,7 @@ body: value: > #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). - #### We also highly recommend you read https://docs.vllm.ai/en/latest/models/adding_model.html first to understand how to add a new model. + #### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model. - type: textarea attributes: label: The model to consider. diff --git a/.github/ISSUE_TEMPLATE/700-performance discussion.yml b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml similarity index 100% rename from .github/ISSUE_TEMPLATE/700-performance discussion.yml rename to .github/ISSUE_TEMPLATE/700-performance-discussion.yml diff --git a/.github/ISSUE_TEMPLATE/800-misc discussion.yml b/.github/ISSUE_TEMPLATE/800-misc-discussion.yml similarity index 100% rename from .github/ISSUE_TEMPLATE/800-misc discussion.yml rename to .github/ISSUE_TEMPLATE/800-misc-discussion.yml diff --git a/.gitignore b/.gitignore index ceef6a5fba456..bb7e4d5b244a8 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,8 @@ instance/ docs/_build/ docs/source/getting_started/examples/*.rst !**/*.template.rst +docs/source/getting_started/examples/*.md +!**/*.template.md # PyBuilder .pybuilder/ diff --git a/CMakeLists.txt b/CMakeLists.txt index ca7314ba4049a..84194a2ff5116 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -240,7 +240,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -257,7 +257,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG v3.6.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -275,7 +275,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_compressor_entry.cu" + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -304,7 +307,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() - # # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") @@ -357,6 +359,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # + # 2:4 Sparse Kernels + + # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor + # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now). + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) + set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") + message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) + message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " + "if you intend on running FP8 sparse quantized models on Hopper.") + else() + message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + # # Machete kernels @@ -443,7 +470,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) @@ -583,7 +610,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb + GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/Dockerfile b/Dockerfile index 123703848749c..088314eb38dbe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ # to run the OpenAI compatible server. # Please update any changes made here to -# docs/source/dev/dockerfile/dockerfile.rst and +# docs/source/dev/dockerfile/dockerfile.md and # docs/source/assets/dev/dockerfile-stages-dependency.png ARG CUDA_VERSION=12.4.1 @@ -45,17 +45,21 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ WORKDIR /workspace # install build and runtime dependencies -COPY requirements-common.txt requirements-common.txt -COPY requirements-cuda.txt requirements-cuda.txt -COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-cuda.txt +# arm64 (GH200) build follows the practice of "use existing pytorch" build, +# we need to install torch and torchvision from the nightly builds first, +# pytorch will not appear as a vLLM dependency in all of the following steps +# after this step RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - python3 -m pip install -r requirements-cuda-arm64.txt; \ + python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \ fi +COPY requirements-common.txt requirements-common.txt +COPY requirements-cuda.txt requirements-cuda.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install -r requirements-cuda.txt + # cuda arch list used by torch # can be useful for both `dev` and `test` # explicitly set the list to avoid issues with torch 2.2 @@ -77,11 +81,6 @@ COPY requirements-build.txt requirements-build.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-build.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - python3 -m pip install -r requirements-cuda-arm64.txt; \ - fi - COPY . . ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ @@ -157,8 +156,6 @@ WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM -COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt - RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment @@ -166,7 +163,7 @@ RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ - && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \ + && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && add-apt-repository ppa:deadsnakes/ppa \ && apt-get update -y \ @@ -183,17 +180,20 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ # or future versions of triton. RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ +# arm64 (GH200) build follows the practice of "use existing pytorch" build, +# we need to install torch and torchvision from the nightly builds first, +# pytorch will not appear as a vLLM dependency in all of the following steps +# after this step +RUN --mount=type=cache,target=/root/.cache/pip \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \ + fi + # Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose -RUN --mount=type=cache,target=/root/.cache/pip \ - if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - pip uninstall -y torch && \ - python3 -m pip install -r requirements-cuda-arm64.txt; \ - fi - RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ @@ -234,17 +234,27 @@ RUN mv vllm test_docs/ #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### -# openai api server alternative -FROM vllm-base AS vllm-openai +# base openai image with additional requirements, for any subsequent openai-style images +FROM vllm-base AS vllm-openai-base # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ else \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ fi + ENV VLLM_USAGE_SOURCE production-docker-image +# define sagemaker first, so it is not default from `docker build` +FROM vllm-openai-base AS vllm-sagemaker + +COPY examples/sagemaker-entrypoint.sh . +RUN chmod +x sagemaker-entrypoint.sh +ENTRYPOINT ["./sagemaker-entrypoint.sh"] + +FROM vllm-openai-base AS vllm-openai + ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] #################### OPENAI API SERVER #################### diff --git a/Dockerfile.cpu b/Dockerfile.cpu index ebe226cf6d148..f163edc27cba8 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0 WORKDIR /workspace +COPY requirements-build.txt requirements-build.txt ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ pip install --upgrade pip && \ pip install -r requirements-build.txt @@ -37,9 +37,9 @@ FROM cpu-test-1 AS build WORKDIR /workspace/vllm +COPY requirements-common.txt requirements-common.txt +COPY requirements-cpu.txt requirements-cpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \ - --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \ pip install -v -r requirements-cpu.txt COPY . . diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 77162bc82de62..269139fe90f0b 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,6 +1,6 @@ # default base image # https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04" +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04" FROM $BASE_IMAGE @@ -22,9 +22,9 @@ WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas -RUN python3 -m pip install sentencepiece transformers==4.36.2 -U +RUN python3 -m pip install sentencepiece transformers==4.45.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U COPY . . ARG GIT_REPO_CHECK=0 diff --git a/README.md b/README.md index 93b71ddaccc61..f83c9d759b359 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ vLLM is flexible and easy to use with: vLLM seamlessly supports most popular open-source models on HuggingFace, including: - Transformer-like LLMs (e.g., Llama) -- Mixture-of-Expert LLMs (e.g., Mixtral) +- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3) - Embedding Models (e.g. E5-Mistral) - Multi-modal LLMs (e.g., LLaVA) diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py new file mode 100644 index 0000000000000..13477ef535e86 --- /dev/null +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -0,0 +1,184 @@ +""" +Offline benchmark to test the long document QA throughput. + +Example usage: + # This command run the vllm with 50GB CPU memory for offloading + # The workload samples 8 different prompts with a default input + # length of 20000 tokens, then replicates each prompt 2 times + # in random order. + python benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --repeat-count 2 + +Commandline arguments: + --num-documents: The number of documents to sample prompts from. + + --document-length: The length of each document in tokens. + (Optional, default: 20000) + + --output-len: The number of tokens to generate for each prompt. + (Optional, default: 10) + + --repeat-count: The number of times to repeat each prompt. + (Optional, default: 2) + + --repeat-mode: The mode to repeat prompts. The supported modes are: + - 'random': shuffle the prompts randomly. (Default) + - 'tile': the entire prompt list is repeated in sequence. (Potentially + lowest cache hit) + - 'interleave': each prompt is repeated consecutively before + moving to the next element. (Highest cache hit) + + --shuffle-seed: Random seed when the repeat mode is "random". + (Optional, default: 0) + +In the meantime, it also supports all the vLLM engine args to initialize the +LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more +details. +""" + +import dataclasses +import random +import time + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def test_long_document_qa(llm=None, sampling_params=None, prompts=None): + """ + Test long document QA with the given prompts and sampling parameters. + Print the time spent in processing all the prompts. + + Args: + llm: The language model used for generating responses. + sampling_params: Sampling parameter used to generate the response. + prompts: A list of prompt strings to be processed by the LLM. + """ + start_time = time.time() + llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + print(f"Time to execute all requests: {end_time - start_time:.4f} secs") + + +def repeat_prompts(prompts, repeat_count, mode: str): + """ + Repeat each prompt in the list for a specified number of times. + The order of prompts in the output list depends on the mode. + + Args: + prompts: A list of prompts to be repeated. + repeat_count: The number of times each prompt is repeated. + mode: The mode of repetition. Supported modes are: + - 'random': Shuffle the prompts randomly after repetition. + - 'tile': Repeat the entire prompt list in sequence. + Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. + - 'interleave': Repeat each prompt consecutively before moving to + the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. + + Returns: + A list of repeated prompts in the specified order. + + Raises: + ValueError: If an invalid mode is provided. + """ + print("Repeat mode: ", mode) + if mode == 'random': + repeated_prompts = prompts * repeat_count + random.shuffle(repeated_prompts) + return repeated_prompts + elif mode == 'tile': + return prompts * repeat_count + elif mode == 'interleave': + repeated_prompts = [] + for prompt in prompts: + repeated_prompts.extend([prompt] * repeat_count) + return repeated_prompts + else: + raise ValueError(f"Invalid mode: {mode}, only support " + "'random', 'tile', 'interleave'") + + +def main(args): + random.seed(args.shuffle_seed) + + # Prepare the prompts: + # we append the document id at the beginning to avoid any of the document + # being the prefix of other documents + prompts = [ + str(i) + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents) + ] + + prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) + + warmup_prompts = [ + "This is warm up request " + str(i) + \ + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents)] + + # Create the LLM engine + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**dataclasses.asdict(engine_args)) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + + print("------warm up------") + test_long_document_qa( + llm=llm, + prompts=warmup_prompts, + sampling_params=sampling_params, + ) + + print("------start generating------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description= + 'Benchmark the performance with or without automatic prefix caching.') + + parser.add_argument( + '--document-length', + type=int, + # Roughly the number of tokens for a system paper, + # excluding images + default=20000, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + + parser.add_argument('--num-documents', + type=int, + default=8, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + + parser.add_argument('--output-len', type=int, default=10) + + parser.add_argument('--repeat-count', + type=int, + default=2, + help='Number of times to repeat each prompt') + + parser.add_argument("--repeat-mode", + type=str, + default='random', + help='The mode to repeat prompts. The supported ' + 'modes are "random", "tile", and "interleave". ' + 'See repeat_prompts() in the source code for details.') + + parser.add_argument("--shuffle-seed", + type=int, + default=0, + help='Random seed when the repeat mode is "random"') + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1e5967bd9bf8b..c1b10b3cf8f58 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -4,7 +4,8 @@ import json import random import time -from typing import List, Optional +from functools import cache +from typing import Dict, List, Optional, Tuple import torch import uvloop @@ -17,8 +18,11 @@ from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.inputs import TextPrompt +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import BeamSearchParams +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.utils import FlexibleArgumentParser, merge_async_iterators @@ -28,15 +32,17 @@ class SampleRequest: Attributes: prompt: The input text prompt for the model. - multi_modal_data: Optional dictionary containing multi-modal data (e.g. - images). prompt_len: The length of the prompt in tokens. expected_output_len: The expected length of the output in tokens. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + lora_request: Optional LoRARequest specifying the LoRA to use. """ prompt: str prompt_len: int expected_output_len: int multi_modal_data: Optional[MultiModalDataDict] = None + lora_request: Optional[LoRARequest] = None def _get_prompt_for_image_model(question: str, *, model: str) -> str: @@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str: raise ValueError(f"Unsupported model {model}") +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +lora_tokenizer_cache: Dict[int, AnyTokenizer] = {} + + +def get_random_lora_request( + args: argparse.Namespace +) -> Tuple[LoRARequest, Optional[AnyTokenizer]]: + global lora_tokenizer_cache + lora_id = random.randint(1, args.max_loras) + lora_request = LoRARequest(lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(args.lora_path)) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + return lora_request, lora_tokenizer_cache[lora_id] + + def sample_requests(tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace) -> List[SampleRequest]: + dataset_path: str = args.dataset num_requests: int = args.num_prompts fixed_output_len: Optional[int] = args.output_len @@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, # Filter out sequences that are too long or too short filtered_dataset: List[SampleRequest] = [] - for data in dataset: + for data in tqdm(dataset, + total=len(filtered_dataset), + desc="sampling requests"): if len(filtered_dataset) == num_requests: break @@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, continue prompt = _get_prompt_for_image_model(question=prompt, model=model) + request_tokenizer = tokenizer + lora_request: Optional[LoRARequest] = None + if args.enable_lora: + lora_request, lora_tokenizer = get_random_lora_request(args) + if lora_tokenizer: + request_tokenizer = lora_tokenizer + # Tokenize the prompts and completions. - prompt_token_ids = tokenizer(prompt).input_ids - completion_token_ids = tokenizer(completion).input_ids + prompt_token_ids = request_tokenizer(prompt).input_ids + completion_token_ids = request_tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) output_len = len(completion_token_ids ) if fixed_output_len is None else fixed_output_len @@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, SampleRequest(prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - multi_modal_data=multi_modal_data)) + multi_modal_data=multi_modal_data, + lora_request=lora_request)) return filtered_dataset @@ -146,14 +184,21 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, )) + lora_requests: Optional[List[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] use_beam_search = False if not use_beam_search: start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) + llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) end = time.perf_counter() else: + assert lora_requests is None, "BeamSearch API does not support LoRA" prompts = [request.prompt for request in requests] # output_len should be the same for all requests. output_len = requests[0][2] @@ -185,6 +230,7 @@ async def run_vllm_async( # Add the requests to the engine. prompts: List[TextPrompt] = [] sampling_params: List[SamplingParams] = [] + lora_requests: List[Optional[LoRARequest]] = [] for request in requests: prompts.append( TextPrompt(prompt=request.prompt, @@ -197,11 +243,16 @@ async def run_vllm_async( ignore_eos=True, max_tokens=request.expected_output_len, )) + lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() - for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): - generator = llm.generate(prompt, sp, request_id=f"test{i}") + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: @@ -297,6 +348,14 @@ def main(args: argparse.Namespace): vocab_size = tokenizer.vocab_size requests = [] for _ in range(args.num_prompts): + + request_tokenizer = tokenizer + lora_request: Optional[LoRARequest] = None + if args.enable_lora: + lora_request, lora_tokenizer = get_random_lora_request(args) + if lora_tokenizer: + request_tokenizer = lora_tokenizer + # Synthesize a prompt with the given input length. candidate_ids = [ random.randint(0, vocab_size - 1) @@ -305,8 +364,8 @@ def main(args: argparse.Namespace): # As tokenizer may add additional tokens like BOS, we need to try # different lengths to get the desired input length. for _ in range(5): # Max attempts to correct - candidate_prompt = tokenizer.decode(candidate_ids) - tokenized_len = len(tokenizer.encode(candidate_prompt)) + candidate_prompt = request_tokenizer.decode(candidate_ids) + tokenized_len = len(request_tokenizer.encode(candidate_prompt)) if tokenized_len == args.input_len: break @@ -323,7 +382,8 @@ def main(args: argparse.Namespace): requests.append( SampleRequest(prompt=candidate_prompt, prompt_len=args.input_len, - expected_output_len=args.output_len)) + expected_output_len=args.output_len, + lora_request=lora_request)) else: requests = sample_requests(tokenizer, args) @@ -422,6 +482,14 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable decoupled async engine frontend.") + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: @@ -431,6 +499,8 @@ def main(args: argparse.Namespace): assert args.output_len is not None else: assert args.input_len is None + if args.enable_lora: + assert args.lora_path is not None if args.backend == "vllm": if args.hf_max_batch_size is not None: @@ -440,6 +510,9 @@ def main(args: argparse.Namespace): raise ValueError("HF max batch size is required for HF backend.") if args.quantization is not None: raise ValueError("Quantization is only for vLLM backend.") + if args.enable_lora is not None: + raise ValueError("LoRA benchmarking is only supported for vLLM" + " backend") elif args.backend == "mii": if args.dtype != "auto": raise ValueError("dtype must be auto for MII backend.") @@ -452,4 +525,7 @@ def main(args: argparse.Namespace): if args.tokenizer != args.model: raise ValueError("Tokenizer must be the same as the model for MII " "backend.") + if args.enable_lora is not None: + raise ValueError("LoRA benchmarking is only supported for vLLM" + " backend") main(args) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000..3d1c5e392f9e2 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,384 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, + k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 0000000000000..ef06fcd6604dd --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,96 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 63cf5d50cac75..d0353bc8cb42a 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -8,6 +8,7 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops @@ -17,31 +18,6 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - # bench def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, @@ -386,4 +362,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d6028627..d58fb0bf86374 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -40,4 +40,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index 2924ea4a49f54..94999630bae12 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -10,7 +10,8 @@ set -ex kill_gpu_processes() { # kill all processes on GPU. - pkill -f pt_main_thread + pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 sleep 10 # remove vllm config file @@ -54,7 +55,7 @@ benchmark() { CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --model $model \ --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ @@ -64,7 +65,7 @@ benchmark() { CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --model $model \ --port 8200 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ @@ -87,7 +88,7 @@ benchmark() { --port 8100 \ --save-result \ --result-dir $results_folder \ - --result-filename disagg_prefill_2xtp4.json \ + --result-filename disagg_prefill_tp1.json \ --request-rate "inf" @@ -105,7 +106,7 @@ benchmark() { --port 8200 \ --save-result \ --result-dir $results_folder \ - --result-filename disagg_prefill_2xtp4.json \ + --result-filename disagg_prefill_tp1_overhead.json \ --request-rate "$qps" kill_gpu_processes @@ -118,7 +119,7 @@ main() { (which jq) || (apt-get -y install jq) (which socat) || (apt-get -y install socat) - pip install quart httpx + pip install quart httpx datasets cd "$(dirname "$0")" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index d8d9e976dce76..eb5d891d0d4a5 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -1,13 +1,12 @@ #!/bin/bash -# Requirement: 8x H100 GPUs. +# Requirement: 2x GPUs. -# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV -# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests -# Resource: 8x H100 +# Model: meta-llama/Meta-Llama-3.1-8B-Instruct +# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests +# Resource: 2x GPU # Approaches: -# 1. Chunked prefill: 1 vllm instance with tp=8 # 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 # 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance # Prefilling instance: max_output_token=1 @@ -114,7 +113,6 @@ benchmark() { --request-rate "$qps" sleep 2 - } @@ -123,8 +121,9 @@ main() { (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get -y install jq) (which socat) || (apt-get -y install socat) + (which lsof) || (apt-get -y install lsof) - pip install quart httpx matplotlib aiohttp + pip install quart httpx matplotlib aiohttp datasets cd "$(dirname "$0")" diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py new file mode 100644 index 0000000000000..baa5de0fff1bd --- /dev/null +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -0,0 +1,262 @@ +import itertools +from typing import Optional, Tuple, Union + +import torch +import triton +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from torch import nn + +from vllm import _custom_ops as vllm_ops + + +class HuggingFaceRMSNorm(nn.Module): + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_vllm = rmsnorm_vllm( + x.clone(), weight, + residual.clone() if residual is not None else None) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"VLLM output={output_vllm}") + + if torch.allclose(output_naive, output_flashinfer, atol=1e-2, + rtol=1e-2) and torch.allclose( + output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list( + itertools.product(head_num_range, batch_size_range, seq_length_range)) + + +def get_benchmark(use_residual): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm"], + line_names=["HuggingFace", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name= + f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + args={}, + )) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_flashinfer( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument("--use-residual", + action="store_true", + help="Whether to use residual connection") + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual) + + # Get the benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 5cdd250c3f9cf..3569b3c88abcd 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -53,7 +53,7 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, torch::Tensor& k_scale, + const std::optional& alibi_slopes, torch::Tensor& k_scale, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { @@ -194,7 +194,7 @@ void paged_attention_v1( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, + const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index c0e6b7cfd67a0..bc543e713fe58 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -62,7 +62,7 @@ void paged_attention_v2_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, torch::Tensor& k_scale, + const std::optional& alibi_slopes, torch::Tensor& k_scale, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { @@ -213,7 +213,7 @@ void paged_attention_v2( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, + const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 0000000000000..ba9f40a230c8e --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,7 @@ +#include +#include + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} \ No newline at end of file diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index e21832ba7582f..ef5b14088c63b 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes) { + const std::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -459,7 +459,7 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, + int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes) { + int max_seq_len, const std::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -781,7 +781,7 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, + int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index d9aed657a3113..33b1637832888 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major const torch::Tensor& b, // [IC, OC], column-major const torch::Tensor& a_scales, // [1] or [M] const torch::Tensor& b_scales, // [1] or [OC] - const c10::optional& bias // [OC] + const std::optional& bias // [OC] ) { CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) // Checks for conformality @@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major const torch::Tensor& a_scales, // [1] or [M] const torch::Tensor& b_scales, // [1] or [OC] const torch::Tensor& azp_adj, // [OC] - const c10::optional& azp, // [1] or [M] - const c10::optional& bias // [OC] + const std::optional& azp, // [1] or [M] + const std::optional& bias // [OC] ) { CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) // Checks for conformality @@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] const torch::Tensor& scale, - c10::optional const& azp) { + std::optional const& azp) { CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); @@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] torch::Tensor& scale, // [..., 1] - c10::optional const& azp) { + std::optional const& azp) { CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 03beefbc6de7d..74e4d8189d403 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids); void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_scales, const torch::Tensor& b_scales, - const c10::optional& bias); + const std::optional& bias); void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_scales, const torch::Tensor& b_scales, const torch::Tensor& azp_adj, - const c10::optional& azp, - const c10::optional& bias); + const std::optional& azp, + const std::optional& bias); TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp new file mode 100644 index 0000000000000..3d2093ab94297 --- /dev/null +++ b/csrc/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp new file mode 100644 index 0000000000000..85e359aa57113 --- /dev/null +++ b/csrc/cutlass_extensions/common.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index c69e87999ae71..ef413e6dd75c5 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -1,3 +1,5 @@ +#pragma once + #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" /* @@ -66,7 +68,7 @@ struct ScaledEpilogueBase { // This overload handles the case where there might not be a tensor, in which // case a nullptr is passed and a constant (0) is used. template - static auto args_from_tensor(c10::optional const& tensor) { + static auto args_from_tensor(std::optional const& tensor) { static_assert(std::is_same_v>); using Arguments = typename Descriptor::Arguments; auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; @@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp static ArgumentType prepare_args(torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -299,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken torch::Tensor const& b_scales, torch::Tensor const& azp_adj, torch::Tensor const& azp, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79f..c590c66a66652 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,3 +1,5 @@ +#pragma once + #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" /* @@ -36,13 +38,13 @@ struct ScaledEpilogueBase { // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors @@ -65,7 +67,7 @@ struct ScaledEpilogueBase { // This overload handles the case where there might not be a tensor, in which // case a nullptr is passed and a constant (0) is used. template - static auto args_from_tensor(c10::optional const& tensor) { + static auto args_from_tensor(std::optional const& tensor) { using Arguments = typename Descriptor::Arguments; auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; static_assert(std::is_same_v> || @@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp static ArgumentType prepare_args(torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -297,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken torch::Tensor const& b_scales, torch::Tensor const& azp_adj, torch::Tensor const& azp, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index 2c78572521eec..a1ff933cce63f 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, template static inline auto maybe_make_cute_layout( - c10::optional const& tensor, + std::optional const& tensor, std::string_view name = "tensor") { using Layout = decltype(make_cute_layout(*tensor)); diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index a5beea1a35e49..b401736c9824b 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum): class MixedInputKernelScheduleType(enum.Enum): - TmaWarpSpecializedMixedInput = enum_auto() - TmaWarpSpecializedPingpongMixedInput = enum_auto() - TmaWarpSpecializedCooperativeMixedInput = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { @@ -68,11 +68,11 @@ class MixedInputKernelScheduleType(enum.Enum): MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore **{ - MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecialized: + "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: + "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: + "cutlass::gemm::KernelTmaWarpSpecializedCooperative", } } diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index dd1e6de2e0180..f0e5533bcae60 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, const at::Tensor x, const at::Tensor weight, const at::Tensor out, - const c10::optional& bias, + const std::optional& bias, bool silu_activation, int64_t pad_slot_id, - const c10::optional& query_start_loc = std::nullopt, - const c10::optional& cache_indices = std::nullopt, - const c10::optional& has_initial_state = std::nullopt) { + const std::optional& query_start_loc = std::nullopt, + const std::optional& cache_indices = std::nullopt, + const std::optional& has_initial_state = std::nullopt) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, - const c10::optional &bias_, - const c10::optional &conv_states, - const c10::optional &query_start_loc, - const c10::optional &cache_indices, - const c10::optional &has_initial_state, + const std::optional &bias_, + const std::optional &conv_states, + const std::optional &query_start_loc, + const std::optional &cache_indices, + const std::optional &has_initial_state, bool silu_activation, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early @@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, void causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, - const c10::optional &bias_, + const std::optional &bias_, bool silu_activation, - const c10::optional &cache_seqlens_, - const c10::optional &conv_state_indices_, + const std::optional &cache_seqlens_, + const std::optional &conv_state_indices_, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early int64_t pad_slot_id) { diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 71624696338d0..bd0a34119c82b 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const torch::Tensor out, const torch::Tensor z, const torch::Tensor out_z, - const c10::optional& D, - const c10::optional& delta_bias, + const std::optional& D, + const std::optional& delta_bias, const torch::Tensor ssm_states, bool has_z, bool delta_softplus, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, bool varlen, int64_t pad_slot_id) { @@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, + const std::optional &D_, + const std::optional &z_, + const std::optional &delta_bias_, bool delta_softplus, - const c10::optional &query_start_loc, - const c10::optional &cache_indices, - const c10::optional &has_initial_state, + const std::optional &query_start_loc, + const std::optional &cache_indices, + const std::optional &has_initial_state, const torch::Tensor &ssm_states, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index dd90c38d9a721..16fccae403338 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -112,6 +112,91 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, } } +// TODO(simon): this is temporarily adapted from +// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 +// we did this to unblock Deepseek V3 but there should be a better +// implementation to manage shared memory. +template +__global__ void moe_align_block_size_global_mem_kernel( + scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) { + tokens_cnts[index(num_experts, 0, eid)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, eid)] += + tokens_cnts[index(num_experts, i - 1, eid)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) { + for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) { + expert_ids[i / block_size] = eid; + } + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} + template __global__ void moe_sum_kernel( scalar_t* __restrict__ out, // [..., d] diff --git a/csrc/ops.h b/csrc/ops.h index 8ca912ff58897..e9cc8d2e215e2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,7 +33,7 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, + int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, @@ -45,7 +45,7 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, + int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, @@ -158,24 +158,35 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, - c10::optional const& bias); + std::optional const& bias); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& azp, - c10::optional const& bias); + std::optional const& azp, + std::optional const& bias); + +bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability); + +void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& e, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, + torch::Tensor& e, torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, - c10::optional const& azp); + std::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, - c10::optional const& azp); + std::optional const& azp); torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, @@ -192,34 +203,34 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, - c10::optional const& scale_ub); + std::optional const& scale_ub); void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, - const c10::optional& D_, - const c10::optional& z_, - const c10::optional& delta_bias_, + const std::optional& D_, + const std::optional& z_, + const std::optional& delta_bias_, bool delta_softplus, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, - const c10::optional& bias_, + const std::optional& bias_, bool silu_activation, - const c10::optional& cache_seqlens_, - const c10::optional& conv_state_indices_, + const std::optional& cache_seqlens_, + const std::optional& conv_state_indices_, int64_t pad_slot_id); void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const c10::optional& bias_, - const c10::optional& conv_states, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, + const std::optional& bias_, + const std::optional& conv_states, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, bool silu_activation, int64_t pad_slot_id); using fptr_t = int64_t; diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index e9987535bd3ea..e79785827189d 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] torch::Tensor const& scale, - c10::optional const& azp) { + std::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); @@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales, c10::optional const& azp) { + torch::Tensor& scales, std::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scales.is_contiguous()); diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp deleted file mode 100644 index bf04bb400790f..0000000000000 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include "cutlass/cutlass.h" -#include - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - cutlassGetStatusString(status)) \ - } - -inline uint32_t next_pow_2(uint32_t const num) { - if (num <= 1) return num; - return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} - -inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { - int max_shared_mem_per_block_opt_in = 0; - cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); - return max_shared_mem_per_block_opt_in; -} - diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index dbb72e8bbd3f5..865fef5aeea11 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, - c10::optional const& bias) { + std::optional const& bias) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (bias) { @@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& azp, - c10::optional const& bias) { + std::optional const& azp, + std::optional const& bias) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); @@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, - c10::optional const& bias) { + std::optional const& bias) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (bias) { @@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& azp, - c10::optional const& bias) { + std::optional const& azp, + std::optional const& bias) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); @@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, - c10::optional const& bias) { + std::optional const& bias) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (bias) { @@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& azp, - c10::optional const& bias) { + std::optional const& azp, + std::optional const& bias) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index d03242f44ab1d..f2fae4b66d651 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,15 +21,16 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; /* - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm80EVT, + Epilogues defined in, + csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp + must contain a public type named EVTCompute of type Sm80EVT, as well as a static prepare_args function that constructs an EVTCompute::Arguments struct. */ diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 33581a63d4c3d..e18d7d79e5b77 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,384 +1,18 @@ -// clang-format will break include orders -// clang-format off #include #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -#include + #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh" + #include "scaled_mm_c3x_sm90_int8_dispatch.cuh" -#include - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "common.hpp" -// clang-format on - -using namespace cute; + #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" using namespace vllm; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm90EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. */ -namespace { - -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); - #endif - } -}; -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; - - using StrideD = Stride, Int<0>>; - using ElementC = void; - using StrideC = StrideD; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - // clang-format off - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, 16, - ElementAB, cutlass::layout::ColumnMajor, 16, - ElementAcc, TileShape, ClusterShape, - Stages, - KernelSchedule>::CollectiveOp; - // clang-format on - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; - - struct GemmKernel : public KernelType {}; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M128 { - // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; - - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_default { - // For M > 128 and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M128 { - // For M in (64, 128] and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M64 { - // For M in (32, 64] and any N - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NBig { - // For M in [1, 32] and N >= 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NSmall { - // For M in [1, 32] and N < 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -} // namespace - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 - - if (mp2 <= 64) { - // m in [1, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - // m in (128, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(b.dtype() == torch::kInt8); - - using Cutlass3xGemmDefault = - typename sm90_int8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_int8_config_M128::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_int8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM32NBig = - typename sm90_int8_config_M32_NBig::Cutlass3xGemm; - using Cutlass3xGemmM32NSmall = - typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; - - uint32_t const n = out.size(1); - bool const is_small_n = n < 8192; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 - - if (mp2 <= 32) { - // m in [1, 32] - if (is_small_n) { - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } - } else if (mp2 <= 64) { - // m in (32, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - // m in (128, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - template