diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py
new file mode 100644
index 0000000000000..8350e2705141e
--- /dev/null
+++ b/.buildkite/generate_index.py
@@ -0,0 +1,24 @@
+import argparse
+import os
+
+template = """
+
+
+ Links for vLLM
+ {wheel}
+
+
+"""
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--wheel", help="The wheel path.", required=True)
+args = parser.parse_args()
+
+filename = os.path.basename(args.wheel)
+
+with open("index.html", "w") as f:
+ print(f"Generated index.html for {args.wheel}")
+ # cloudfront requires escaping the '+' character
+ f.write(
+ template.format(wheel=filename,
+ wheel_html_escaped=filename.replace("+", "%2B")))
diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
index 64ba1b32fb074..679abf1814aa5 100644
--- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
+++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
@@ -1,5 +1,6 @@
steps:
- label: "Wait for container to be ready"
+ key: wait-for-container-image
agents:
queue: A100
plugins:
@@ -10,12 +11,11 @@ steps:
command:
- sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh
- - wait
-
- label: "A100"
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: A100
+ depends_on: wait-for-container-image
plugins:
- kubernetes:
podSpec:
@@ -49,6 +49,7 @@ steps:
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: H200
+ depends_on: wait-for-container-image
plugins:
- docker#v5.12.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
@@ -65,15 +66,15 @@ steps:
- VLLM_USAGE_SOURCE
- HF_TOKEN
- - block: "Run H100 Benchmark"
- key: block-h100
- depends_on: ~
+ #- block: "Run H100 Benchmark"
+ #key: block-h100
+ #depends_on: ~
- label: "H100"
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: H100
- depends_on: block-h100
+ depends_on: wait-for-container-image
plugins:
- docker#v5.12.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index 2de6fceb0c3fe..51618a2955fb1 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -55,3 +55,18 @@ steps:
password-env: DOCKERHUB_TOKEN
env:
DOCKER_BUILDKIT: "1"
+
+ - block: "Build CPU release image"
+ key: block-cpu-release-image-build
+ depends_on: ~
+
+ - label: "Build and publish CPU release image"
+ depends_on: block-cpu-release-image-build
+ agents:
+ queue: cpu_queue_postmerge
+ commands:
+ - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ."
+ - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION"
+ env:
+ DOCKER_BUILDKIT: "1"
diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh
new file mode 100644
index 0000000000000..4fc6d089cc666
--- /dev/null
+++ b/.buildkite/run-gh200-test.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# This script builds the GH200 docker image and runs offline inference inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -ex
+
+# Skip the new torch installation during build since we are using the specified version for arm64 in the Dockerfile
+python3 use_existing_torch.py
+
+# Try building the docker image
+DOCKER_BUILDKIT=1 docker build . \
+ --target vllm-openai \
+ --platform "linux/arm64" \
+ -t gh200-test \
+ --build-arg max_jobs=66 \
+ --build-arg nvcc_threads=2 \
+ --build-arg torch_cuda_arch_list="9.0+PTX" \
+ --build-arg vllm_fa_cmake_gpu_arches="90-real"
+
+# Setup cleanup
+remove_docker_container() { docker rm -f gh200-test || true; }
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Run the image and test offline inference
+docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
+ python3 examples/offline_inference.py
+'
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 97aae233db105..529daf54faecf 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -106,14 +106,12 @@ steps:
source_file_dependencies:
- vllm/
commands:
- - pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
- - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
@@ -201,7 +199,7 @@ steps:
- python3 offline_inference_classification.py
- python3 offline_inference_embedding.py
- python3 offline_inference_scoring.py
- - python3 offline_profile.py --model facebook/opt-125m
+ - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2
- label: Prefix Caching Test # 9min
mirror_hardwares: [amd]
@@ -224,8 +222,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
+ - vllm/model_executor/guided_decoding
- tests/test_logits_processor
- command: pytest -v -s test_logits_processor.py
+ - tests/model_executor/test_guided_processors
+ commands:
+ - pytest -v -s test_logits_processor.py
+ - pytest -v -s model_executor/test_guided_processors.py
- label: Speculative decoding tests # 30min
source_file_dependencies:
@@ -329,8 +331,6 @@ steps:
- vllm/
- tests/models
commands:
- - pip install -e ./plugins/vllm_add_dummy_model
- - pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py
@@ -356,23 +356,25 @@ steps:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'
-- label: Multi-Modal Models Test (Standard) # 28min
+- label: Multi-Modal Models Test (Standard) # 40min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/models/decoder_only/audio_language
- tests/models/decoder_only/vision_language
- tests/models/embedding/vision_language
+ - tests/models/encoder_decoder/audio_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
+ - pytest -v -s models/encoder_decoder/audio_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model
-- label: Multi-Modal Models Test (Extended) 1 # 1h16m
+- label: Multi-Modal Models Test (Extended) 1 # 48m
optional: true
source_file_dependencies:
- vllm/
@@ -465,11 +467,28 @@ steps:
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- - pip install -e ./plugins/vllm_add_dummy_model
- - pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
+- label: Plugin Tests (2 GPUs) # 40min
+ working_dir: "/vllm-workspace/tests"
+ num_gpus: 2
+ fast_check: true
+ source_file_dependencies:
+ - vllm/plugins/
+ - tests/plugins/
+ commands:
+ # begin platform plugin tests, all the code in-between runs on dummy platform
+ - pip install -e ./plugins/vllm_add_dummy_platform
+ - pytest -v -s plugins_tests/test_platform_plugins.py
+ - pip uninstall vllm_add_dummy_platform -y
+ # end platform plugin tests
+ # other tests continue here:
+ - pip install -e ./plugins/vllm_add_dummy_model
+ - pytest -v -s distributed/test_distributed_oot.py
+ - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
+ - pytest -v -s models/test_oot_registration.py # it needs a clean process
+
- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh
index 7345dd4e66b29..3c756659a715a 100644
--- a/.buildkite/upload-wheels.sh
+++ b/.buildkite/upload-wheels.sh
@@ -23,6 +23,8 @@ wheel="$new_wheel"
version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2)
echo "Version: $version"
+normal_wheel="$wheel" # Save the original wheel filename
+
# If the version contains "dev", rename it to v1.0.0.dev for consistency
if [[ $version == *dev* ]]; then
suffix="${version##*.}"
@@ -32,12 +34,38 @@ if [[ $version == *dev* ]]; then
new_version="1.0.0.dev"
fi
new_wheel="${wheel/$version/$new_version}"
- mv -- "$wheel" "$new_wheel"
+ # use cp to keep both files in the artifacts directory
+ cp -- "$wheel" "$new_wheel"
wheel="$new_wheel"
version="$new_version"
fi
# Upload the wheel to S3
+python3 .buildkite/generate_index.py --wheel "$normal_wheel"
+
+# generate index for this commit
aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
+aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
+
+if [[ $normal_wheel == *"cu118"* ]]; then
+ # if $normal_wheel matches cu118, do not upload the index.html
+ echo "Skipping index files for cu118 wheels"
+else
+ # only upload index.html for cu12 wheels (default wheels)
+ aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html"
+ aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html"
+fi
+
+# generate index for nightly
aws s3 cp "$wheel" "s3://vllm-wheels/nightly/"
+aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/"
+
+if [[ $normal_wheel == *"cu118"* ]]; then
+ # if $normal_wheel matches cu118, do not upload the index.html
+ echo "Skipping index files for cu118 wheels"
+else
+ # only upload index.html for cu12 wheels (default wheels)
+ aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html"
+fi
+
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml
similarity index 100%
rename from .github/ISSUE_TEMPLATE/400-bug report.yml
rename to .github/ISSUE_TEMPLATE/400-bug-report.yml
diff --git a/.github/ISSUE_TEMPLATE/500-feature request.yml b/.github/ISSUE_TEMPLATE/500-feature-request.yml
similarity index 100%
rename from .github/ISSUE_TEMPLATE/500-feature request.yml
rename to .github/ISSUE_TEMPLATE/500-feature-request.yml
diff --git a/.github/ISSUE_TEMPLATE/600-new model.yml b/.github/ISSUE_TEMPLATE/600-new-model.yml
similarity index 94%
rename from .github/ISSUE_TEMPLATE/600-new model.yml
rename to .github/ISSUE_TEMPLATE/600-new-model.yml
index 794617a0cfdf6..713e76c1a5cec 100644
--- a/.github/ISSUE_TEMPLATE/600-new model.yml
+++ b/.github/ISSUE_TEMPLATE/600-new-model.yml
@@ -9,7 +9,7 @@ body:
value: >
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
- #### We also highly recommend you read https://docs.vllm.ai/en/latest/models/adding_model.html first to understand how to add a new model.
+ #### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model.
- type: textarea
attributes:
label: The model to consider.
diff --git a/.github/ISSUE_TEMPLATE/700-performance discussion.yml b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml
similarity index 100%
rename from .github/ISSUE_TEMPLATE/700-performance discussion.yml
rename to .github/ISSUE_TEMPLATE/700-performance-discussion.yml
diff --git a/.github/ISSUE_TEMPLATE/800-misc discussion.yml b/.github/ISSUE_TEMPLATE/800-misc-discussion.yml
similarity index 100%
rename from .github/ISSUE_TEMPLATE/800-misc discussion.yml
rename to .github/ISSUE_TEMPLATE/800-misc-discussion.yml
diff --git a/.gitignore b/.gitignore
index ceef6a5fba456..bb7e4d5b244a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -81,6 +81,8 @@ instance/
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst
+docs/source/getting_started/examples/*.md
+!**/*.template.md
# PyBuilder
.pybuilder/
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ca7314ba4049a..84194a2ff5116 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -240,7 +240,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
- set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
+ set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -257,7 +257,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
- GIT_TAG v3.5.1
+ GIT_TAG v3.6.0
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
@@ -275,7 +275,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
- "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
+ "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+ "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
+ "csrc/sparse/cutlass/sparse_compressor_entry.cu"
+ "csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@@ -304,7 +307,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
- #
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
@@ -357,6 +359,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
+ #
+ # 2:4 Sparse Kernels
+
+ # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
+ # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+ set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
+ "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+ set_gencode_flags_for_srcs(
+ SRCS "${SRCS}"
+ CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+ list(APPEND VLLM_EXT_SRC "${SRCS}")
+ list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
+ message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+ else()
+ if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+ message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
+ "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
+ "if you intend on running FP8 sparse quantized models on Hopper.")
+ else()
+ message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
+ "in CUDA target architectures")
+ endif()
+ endif()
+
#
# Machete kernels
@@ -443,7 +470,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
- INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
+ INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
@@ -583,7 +610,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
- GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
+ GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/Dockerfile b/Dockerfile
index 123703848749c..088314eb38dbe 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@
# to run the OpenAI compatible server.
# Please update any changes made here to
-# docs/source/dev/dockerfile/dockerfile.rst and
+# docs/source/dev/dockerfile/dockerfile.md and
# docs/source/assets/dev/dockerfile-stages-dependency.png
ARG CUDA_VERSION=12.4.1
@@ -45,17 +45,21 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
WORKDIR /workspace
# install build and runtime dependencies
-COPY requirements-common.txt requirements-common.txt
-COPY requirements-cuda.txt requirements-cuda.txt
-COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
- python3 -m pip install -r requirements-cuda.txt
+# The arm64 (GH200) build follows the "use existing PyTorch" practice:
+# install torch and torchvision from the nightly builds first, so that
+# PyTorch does not appear as a vLLM dependency in any of the steps
+# that follow.
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- python3 -m pip install -r requirements-cuda-arm64.txt; \
+ python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
fi
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cuda.txt requirements-cuda.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+ python3 -m pip install -r requirements-cuda.txt
+
# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
@@ -77,11 +81,6 @@ COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
- if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- python3 -m pip install -r requirements-cuda-arm64.txt; \
- fi
-
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
@@ -157,8 +156,6 @@ WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM
-COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt
-
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
@@ -166,7 +163,7 @@ RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
- && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
+ && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
@@ -183,17 +180,20 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
+# The arm64 (GH200) build follows the "use existing PyTorch" practice:
+# install torch and torchvision from the nightly builds first, so that
+# PyTorch does not appear as a vLLM dependency in any of the steps
+# that follow.
+RUN --mount=type=cache,target=/root/.cache/pip \
+ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
+ python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
+ fi
+
# Install vllm wheel first, so that torch etc will be installed.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose
-RUN --mount=type=cache,target=/root/.cache/pip \
- if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- pip uninstall -y torch && \
- python3 -m pip install -r requirements-cuda-arm64.txt; \
- fi
-
RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
@@ -234,17 +234,27 @@ RUN mv vllm test_docs/
#################### TEST IMAGE ####################
#################### OPENAI API SERVER ####################
-# openai api server alternative
-FROM vllm-base AS vllm-openai
+# base openai image with additional requirements, for any subsequent openai-style images
+FROM vllm-base AS vllm-openai-base
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \
+ pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
else \
- pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \
+ pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
fi
+
ENV VLLM_USAGE_SOURCE production-docker-image
+# define the sagemaker target first, so it is not the default target of `docker build`
+FROM vllm-openai-base AS vllm-sagemaker
+
+COPY examples/sagemaker-entrypoint.sh .
+RUN chmod +x sagemaker-entrypoint.sh
+ENTRYPOINT ["./sagemaker-entrypoint.sh"]
+
+FROM vllm-openai-base AS vllm-openai
+
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
#################### OPENAI API SERVER ####################
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
index ebe226cf6d148..f163edc27cba8 100644
--- a/Dockerfile.cpu
+++ b/Dockerfile.cpu
@@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0
WORKDIR /workspace
+COPY requirements-build.txt requirements-build.txt
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
RUN --mount=type=cache,target=/root/.cache/pip \
- --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
pip install --upgrade pip && \
pip install -r requirements-build.txt
@@ -37,9 +37,9 @@ FROM cpu-test-1 AS build
WORKDIR /workspace/vllm
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
- --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
- --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
pip install -v -r requirements-cpu.txt
COPY . .
diff --git a/Dockerfile.neuron b/Dockerfile.neuron
index 77162bc82de62..269139fe90f0b 100644
--- a/Dockerfile.neuron
+++ b/Dockerfile.neuron
@@ -1,6 +1,6 @@
# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
-ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04"
+ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04"
FROM $BASE_IMAGE
@@ -22,9 +22,9 @@ WORKDIR ${APP_MOUNT}/vllm
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
-RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
+RUN python3 -m pip install sentencepiece transformers==4.45.2 -U
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
-RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
+RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
COPY . .
ARG GIT_REPO_CHECK=0
diff --git a/README.md b/README.md
index 93b71ddaccc61..f83c9d759b359 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ vLLM is flexible and easy to use with:
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
-- Mixture-of-Expert LLMs (e.g., Mixtral)
+- Mixture-of-Experts LLMs (e.g., Mixtral, DeepSeek-V2 and V3)
- Embedding Models (e.g. E5-Mistral)
- Multi-modal LLMs (e.g., LLaVA)
diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py
new file mode 100644
index 0000000000000..13477ef535e86
--- /dev/null
+++ b/benchmarks/benchmark_long_document_qa_throughput.py
@@ -0,0 +1,184 @@
+"""
+Offline benchmark to test the long document QA throughput.
+
+Example usage:
+ # This command runs vLLM with 50GB of CPU memory for offloading
+ # The workload samples 8 different prompts with a default input
+ # length of 20000 tokens, then replicates each prompt 2 times
+ # in random order.
+ python benchmark_long_document_qa_throughput.py \
+ --model meta-llama/Llama-2-7b-chat-hf \
+ --enable-prefix-caching \
+ --num-documents 8 \
+ --repeat-count 2
+
+Commandline arguments:
+ --num-documents: The number of documents to sample prompts from.
+
+ --document-length: The length of each document in tokens.
+ (Optional, default: 20000)
+
+ --output-len: The number of tokens to generate for each prompt.
+ (Optional, default: 10)
+
+ --repeat-count: The number of times to repeat each prompt.
+ (Optional, default: 2)
+
+ --repeat-mode: The mode to repeat prompts. The supported modes are:
+ - 'random': shuffle the prompts randomly. (Default)
+ - 'tile': the entire prompt list is repeated in sequence. (Potentially
+ lowest cache hit)
+ - 'interleave': each prompt is repeated consecutively before
+ moving to the next element. (Highest cache hit)
+
+ --shuffle-seed: Random seed when the repeat mode is "random".
+ (Optional, default: 0)
+
+In addition, the script accepts all vLLM engine arguments used to initialize
+the LLM engine. Refer to `vllm.engine.arg_utils.EngineArgs` for more
+details.
+"""
+
+import dataclasses
+import random
+import time
+
+from vllm import LLM, SamplingParams
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
+ """
+ Test long document QA with the given prompts and sampling parameters.
+ Print the time spent in processing all the prompts.
+
+ Args:
+ llm: The language model used for generating responses.
+ sampling_params: Sampling parameter used to generate the response.
+ prompts: A list of prompt strings to be processed by the LLM.
+ """
+ start_time = time.time()
+ llm.generate(prompts, sampling_params=sampling_params)
+ end_time = time.time()
+ print(f"Time to execute all requests: {end_time - start_time:.4f} secs")
+
+
+def repeat_prompts(prompts, repeat_count, mode: str):
+ """
+ Repeat each prompt in the list for a specified number of times.
+ The order of prompts in the output list depends on the mode.
+
+ Args:
+ prompts: A list of prompts to be repeated.
+ repeat_count: The number of times each prompt is repeated.
+ mode: The mode of repetition. Supported modes are:
+ - 'random': Shuffle the prompts randomly after repetition.
+ - 'tile': Repeat the entire prompt list in sequence.
+ Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
+ - 'interleave': Repeat each prompt consecutively before moving to
+ the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
+
+ Returns:
+ A list of repeated prompts in the specified order.
+
+ Raises:
+ ValueError: If an invalid mode is provided.
+ """
+ print("Repeat mode: ", mode)
+ if mode == 'random':
+ repeated_prompts = prompts * repeat_count
+ random.shuffle(repeated_prompts)
+ return repeated_prompts
+ elif mode == 'tile':
+ return prompts * repeat_count
+ elif mode == 'interleave':
+ repeated_prompts = []
+ for prompt in prompts:
+ repeated_prompts.extend([prompt] * repeat_count)
+ return repeated_prompts
+ else:
+ raise ValueError(f"Invalid mode: {mode}, only support "
+ "'random', 'tile', 'interleave'")
+
+
+def main(args):
+ random.seed(args.shuffle_seed)
+
+ # Prepare the prompts:
+ # we prepend the document id so that no document is a prefix of
+ # another document
+ prompts = [
+ str(i) + ' '.join(['hi'] * args.document_length)
+ for i in range(args.num_documents)
+ ]
+
+ prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
+
+ warmup_prompts = [
+ "This is warm up request " + str(i) + \
+ ' '.join(['hi'] * args.document_length)
+ for i in range(args.num_documents)]
+
+ # Create the LLM engine
+ engine_args = EngineArgs.from_cli_args(args)
+ llm = LLM(**dataclasses.asdict(engine_args))
+ sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
+
+ print("------warm up------")
+ test_long_document_qa(
+ llm=llm,
+ prompts=warmup_prompts,
+ sampling_params=sampling_params,
+ )
+
+ print("------start generating------")
+ test_long_document_qa(
+ llm=llm,
+ prompts=prompts,
+ sampling_params=sampling_params,
+ )
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(
+ description=
+ 'Benchmark the performance with or without automatic prefix caching.')
+
+ parser.add_argument(
+ '--document-length',
+ type=int,
+ # Roughly the number of tokens for a system paper,
+ # excluding images
+ default=20000,
+ help='Length of each document in tokens.')
+
+ parser.add_argument('--num-documents',
+ type=int,
+ default=8,
+ help='Number of documents to sample prompts from.')
+
+ parser.add_argument('--output-len', type=int, default=10)
+
+ parser.add_argument('--repeat-count',
+ type=int,
+ default=2,
+ help='Number of times to repeat each prompt')
+
+ parser.add_argument("--repeat-mode",
+ type=str,
+ default='random',
+ help='The mode to repeat prompts. The supported '
+ 'modes are "random", "tile", and "interleave". '
+ 'See repeat_prompts() in the source code for details.')
+
+ parser.add_argument("--shuffle-seed",
+ type=int,
+ default=0,
+ help='Random seed when the repeat mode is "random"')
+
+ parser = EngineArgs.add_cli_args(parser)
+ args = parser.parse_args()
+ main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 1e5967bd9bf8b..c1b10b3cf8f58 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -4,7 +4,8 @@
import json
import random
import time
-from typing import List, Optional
+from functools import cache
+from typing import Dict, List, Optional, Tuple
import torch
import uvloop
@@ -17,8 +18,11 @@
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
+from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@@ -28,15 +32,17 @@ class SampleRequest:
Attributes:
prompt: The input text prompt for the model.
- multi_modal_data: Optional dictionary containing multi-modal data (e.g.
- images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
+ multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+ images).
+ lora_request: Optional LoRARequest specifying the LoRA to use.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None
+ lora_request: Optional[LoRARequest] = None
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
raise ValueError(f"Unsupported model {model}")
+@cache
+def lora_path_on_disk(lora_path: str) -> str:
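+ """Resolve a LoRA adapter reference (local path or Hugging Face model
+ identifier) to an absolute on-disk path, caching the result."""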
+ return get_adapter_absolute_path(lora_path)
+
+
+lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
+
+
+def get_random_lora_request(
+ args: argparse.Namespace
+) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
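+ """Pick a random LoRA id in [1, args.max_loras], build a LoRARequest for
+ args.lora_path under that id, and return it together with the adapter's
+ tokenizer (cached per id; may be None if no adapter tokenizer is found).
+ """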
+ global lora_tokenizer_cache
+ lora_id = random.randint(1, args.max_loras)
+ lora_request = LoRARequest(lora_name=str(lora_id),
+ lora_int_id=lora_id,
+ lora_path=lora_path_on_disk(args.lora_path))
+ if lora_id not in lora_tokenizer_cache:
+ lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
+ return lora_request, lora_tokenizer_cache[lora_id]
+
+
def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
+
dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
- for data in dataset:
+ for data in tqdm(dataset,
+ total=num_requests,
+ desc="sampling requests"):
if len(filtered_dataset) == num_requests:
break
@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)
+ request_tokenizer = tokenizer
+ lora_request: Optional[LoRARequest] = None
+ if args.enable_lora:
+ lora_request, lora_tokenizer = get_random_lora_request(args)
+ if lora_tokenizer:
+ request_tokenizer = lora_tokenizer
+
# Tokenize the prompts and completions.
- prompt_token_ids = tokenizer(prompt).input_ids
- completion_token_ids = tokenizer(completion).input_ids
+ prompt_token_ids = request_tokenizer(prompt).input_ids
+ completion_token_ids = request_tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
- multi_modal_data=multi_modal_data))
+ multi_modal_data=multi_modal_data,
+ lora_request=lora_request))
return filtered_dataset
@@ -146,14 +184,21 @@ def run_vllm(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
+ lora_requests: Optional[List[LoRARequest]] = None
+ if engine_args.enable_lora:
+ lora_requests = [request.lora_request for request in requests]
use_beam_search = False
if not use_beam_search:
start = time.perf_counter()
- llm.generate(prompts, sampling_params, use_tqdm=True)
+ llm.generate(prompts,
+ sampling_params,
+ lora_request=lora_requests,
+ use_tqdm=True)
end = time.perf_counter()
else:
+ assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
@@ -185,6 +230,7 @@ async def run_vllm_async(
# Add the requests to the engine.
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
+ lora_requests: List[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
@@ -197,11 +243,16 @@ async def run_vllm_async(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
+ lora_requests.append(request.lora_request)
generators = []
start = time.perf_counter()
- for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
- generator = llm.generate(prompt, sp, request_id=f"test{i}")
+ for i, (prompt, sp,
+ lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
+ generator = llm.generate(prompt,
+ sp,
+ lora_request=lr,
+ request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
vocab_size = tokenizer.vocab_size
requests = []
for _ in range(args.num_prompts):
+
+ request_tokenizer = tokenizer
+ lora_request: Optional[LoRARequest] = None
+ if args.enable_lora:
+ lora_request, lora_tokenizer = get_random_lora_request(args)
+ if lora_tokenizer:
+ request_tokenizer = lora_tokenizer
+
# Synthesize a prompt with the given input length.
candidate_ids = [
random.randint(0, vocab_size - 1)
@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for _ in range(5): # Max attempts to correct
- candidate_prompt = tokenizer.decode(candidate_ids)
- tokenized_len = len(tokenizer.encode(candidate_prompt))
+ candidate_prompt = request_tokenizer.decode(candidate_ids)
+ tokenized_len = len(request_tokenizer.encode(candidate_prompt))
if tokenized_len == args.input_len:
break
@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
requests.append(
SampleRequest(prompt=candidate_prompt,
prompt_len=args.input_len,
- expected_output_len=args.output_len))
+ expected_output_len=args.output_len,
+ lora_request=lora_request))
else:
requests = sample_requests(tokenizer, args)
@@ -422,6 +482,14 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
+ # LoRA
+ parser.add_argument(
+ "--lora-path",
+ type=str,
+ default=None,
+ help="Path to the lora adapters to use. This can be an absolute path, "
+ "a relative path, or a Hugging Face model identifier.")
+
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
@@ -431,6 +499,8 @@ def main(args: argparse.Namespace):
assert args.output_len is not None
else:
assert args.input_len is None
+ if args.enable_lora:
+ assert args.lora_path is not None
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
@@ -440,6 +510,9 @@ def main(args: argparse.Namespace):
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
+ if args.enable_lora:
+ raise ValueError("LoRA benchmarking is only supported for vLLM"
+ " backend")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
@@ -452,4 +525,7 @@ def main(args: argparse.Namespace):
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
+ if args.enable_lora:
+ raise ValueError("LoRA benchmarking is only supported for vLLM"
+ " backend")
main(args)
diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
new file mode 100644
index 0000000000000..3d1c5e392f9e2
--- /dev/null
+++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -0,0 +1,384 @@
+import argparse
+import copy
+import itertools
+import pickle as pkl
+import time
+from typing import Callable, Iterable, List, Tuple
+
+import torch
+import torch.utils.benchmark as TBenchmark
+from torch.utils.benchmark import Measurement as TMeasurement
+from utils import make_rand_sparse_tensors
+from weight_shapes import WEIGHT_SHAPES
+
+from vllm import _custom_ops as ops
+from vllm.utils import FlexibleArgumentParser
+
+DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
+DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
+DEFAULT_TP_SIZES = [1]
+
+
+# bench
+def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
+ **kwargs) -> TMeasurement:
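+ """Time fn(*args, **kwargs) with torch.utils.benchmark and return the
+ resulting Measurement, labelled so it can be aggregated by Compare."""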
+ min_run_time = 1
+
+ globals = {
+ "args": args,
+ "kwargs": kwargs,
+ "fn": fn,
+ }
+ return TBenchmark.Timer(
+ stmt="fn(*args, **kwargs)",
+ globals=globals,
+ label=label,
+ sub_label=sub_label,
+ description=description,
+ ).blocked_autorange(min_run_time=min_run_time)
+
+
+def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+ sub_label: str) -> Iterable[TMeasurement]:
+ assert dtype == torch.int8
+ b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
+ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+
+ out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+ torch.bfloat16)
+ out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
+
+ if not torch.allclose(out, out_ref):
+ print("Incorrect results")
+ print(out)
+ print(out_ref)
+ else:
+ print("Correct results")
+
+ timers = []
+ # pytorch impl - bfloat16
+ timers.append(
+ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
+ torch.mm, a.to(dtype=torch.bfloat16),
+ b.to(dtype=torch.bfloat16)))
+
+ # pytorch impl - float16
+ timers.append(
+ bench_fn(label, sub_label,
+ "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
+ a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
+
+ # cutlass impl
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
+ ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
+ torch.bfloat16))
+
+ # cutlass with bias
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
+ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
+ bias))
+
+ # cutlass sparse impl
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16))
+
+ # cutlass sparse with bias
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16, bias))
+
+ return timers
+
+
+def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+ sub_label: str) -> Iterable[TMeasurement]:
+ assert dtype == torch.float8_e4m3fn
+ b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
+ k)
+ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+
+ out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+ torch.bfloat16)
+ out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
+
+ if not torch.allclose(out, out_ref):
+ print("Incorrect results")
+ print(out)
+ print(out_ref)
+ else:
+ print("Correct results")
+
+ timers = []
+
+ # pytorch impl w. bf16
+ timers.append(
+ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
+ torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
+ b.to(dtype=torch.bfloat16, device="cuda")))
+
+ # pytorch impl: bf16 output, without fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_bf16_scaled_mm",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.bfloat16))
+
+ # pytorch impl: bf16 output, with fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.bfloat16,
+ use_fast_accum=True))
+
+ # pytorch impl: fp16 output, without fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_fp16_scaled_mm",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.float16))
+
+ # pytorch impl: fp16 output, with fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.float16,
+ use_fast_accum=True))
+
+ # cutlass impl: bf16 output
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
+ ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
+ torch.bfloat16))
+
+ # cutlass sparse impl: bf16 output
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16))
+
+ # cutlass sparse impl: fp16 output
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.float16))
+
+ # cutlass sparse impl: bf16 output, with bias
+ timers.append(
+ bench_fn(label, sub_label,
+ "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16, bias))
+
+ # cutlass sparse impl: fp16 output, with bias
+ timers.append(
+ bench_fn(label, sub_label,
+ "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.float16, bias.to(dtype=torch.float16)))
+
+ return timers
+
+
+def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+ sub_label: str) -> Iterable[TMeasurement]:
+ if dtype == torch.int8:
+ return bench_int8(dtype, m, k, n, label, sub_label)
+ if dtype == torch.float8_e4m3fn:
+ return bench_fp8(dtype, m, k, n, label, sub_label)
+ raise ValueError("unsupported type")
+
+
+# runner
+def print_timers(timers: Iterable[TMeasurement]):
+ compare = TBenchmark.Compare(timers)
+ compare.print()
+
+
+def run(dtype: torch.dtype,
+ MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+ results = []
+ for m, k, n in MKNs:
+ timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
+ f"MKN=({m}x{k}x{n})")
+ print_timers(timers)
+ results.extend(timers)
+
+ return results
+
+
+# output makers
+def make_output(data: Iterable[TMeasurement],
+ MKNs: Iterable[Tuple[int, int, int]],
+ base_description: str,
+ timestamp=None):
+ print(f"== All Results {base_description} ====")
+ print_timers(data)
+
+ # pickle all the results
+ timestamp = int(time.time()) if timestamp is None else timestamp
+ with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
+ pkl.dump(data, f)
+
+
+# argparse runners
+
+
+def run_square_bench(args):
+ dim_sizes = list(
+ range(args.dim_start, args.dim_end + 1, args.dim_increment))
+ MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
+ data = run(args.dtype, MKNs)
+
+ make_output(data, MKNs, f"square_bench-{args.dtype}")
+
+
+def run_range_bench(args):
+ dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
+ n = len(dim_sizes)
+ Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
+ Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
+ Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
+ MKNs = list(zip(Ms, Ks, Ns))
+ data = run(args.dtype, MKNs)
+
+ make_output(data, MKNs, f"range_bench-{args.dtype}")
+
+
+def run_model_bench(args):
+ print("Benchmarking models:")
+ for i, model in enumerate(args.models):
+ print(f"[{i}] {model}")
+
+ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
+ KNs = []
+ for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
+ KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+ KNs.append(KN)
+ return KNs
+
+ model_bench_data = []
+ models_tps = list(itertools.product(args.models, args.tp_sizes))
+ for model, tp_size in models_tps:
+ Ms = args.batch_sizes
+ KNs = model_shapes(model, tp_size)
+ MKNs = []
+ for m in Ms:
+ for k, n in KNs:
+ MKNs.append((m, k, n))
+
+ data = run(args.dtype, MKNs)
+ model_bench_data.append(data)
+
+ # Print all results
+ for data, model_tp in zip(model_bench_data, models_tps):
+ model, tp_size = model_tp
+ print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
+ print_timers(data)
+
+ timestamp = int(time.time())
+
+ all_data = []
+ for d in model_bench_data:
+ all_data.extend(d)
+ # pickle all data
+ with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
+ pkl.dump(all_data, f)
+
+
+if __name__ == '__main__':
+
+ def to_torch_dtype(dt):
+ if dt == "int8":
+ return torch.int8
+ if dt == "fp8":
+ return torch.float8_e4m3fn
+ raise ValueError("unsupported dtype")
+
+ parser = FlexibleArgumentParser(
+ description="""
+Benchmark Cutlass GEMM.
+
+ To run square GEMMs:
+ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
+
+ To run constant N and K and sweep M:
+ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
+
+ To run dimensions from a model:
+ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
+
+ Output:
+ - a .pkl file containing a list of raw torch.utils.benchmark.Measurement objects for the PyTorch and CUTLASS implementations of the various GEMMs.
+ """, # noqa: E501
+ formatter_class=argparse.RawTextHelpFormatter)
+
+ parser.add_argument("--dtype",
+ type=to_torch_dtype,
+ required=True,
+ help="Available options are ['int8', 'fp8']")
+ subparsers = parser.add_subparsers(dest="cmd")
+
+ square_parser = subparsers.add_parser("square_bench")
+ square_parser.add_argument("--dim-start", type=int, required=True)
+ square_parser.add_argument("--dim-end", type=int, required=True)
+ square_parser.add_argument("--dim-increment", type=int, required=True)
+ square_parser.set_defaults(func=run_square_bench)
+
+ range_parser = subparsers.add_parser("range_bench")
+ range_parser.add_argument("--dim-start", type=int, required=True)
+ range_parser.add_argument("--dim-end", type=int, required=True)
+ range_parser.add_argument("--dim-increment", type=int, required=True)
+ range_parser.add_argument("--m-constant", type=int, default=None)
+ range_parser.add_argument("--n-constant", type=int, default=None)
+ range_parser.add_argument("--k-constant", type=int, default=None)
+ range_parser.set_defaults(func=run_range_bench)
+
+ model_parser = subparsers.add_parser("model_bench")
+ model_parser.add_argument("--models",
+ nargs="+",
+ type=str,
+ default=DEFAULT_MODELS,
+ choices=WEIGHT_SHAPES.keys())
+ model_parser.add_argument("--tp-sizes",
+ nargs="+",
+ type=int,
+ default=DEFAULT_TP_SIZES)
+ model_parser.add_argument("--batch-sizes",
+ nargs="+",
+ type=int,
+ default=DEFAULT_BATCH_SIZES)
+ model_parser.set_defaults(func=run_model_bench)
+
+ args = parser.parse_args()
+ args.func(args)
diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py
new file mode 100644
index 0000000000000..ef06fcd6604dd
--- /dev/null
+++ b/benchmarks/cutlass_benchmarks/utils.py
@@ -0,0 +1,96 @@
+# Cutlass bench utils
+from typing import Iterable, Tuple
+
+import torch
+
+import vllm._custom_ops as ops
+
+
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
+ finfo = torch.finfo(torch.float8_e4m3fn)
+ return torch.round(tensor.clamp(
+ min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+
+
+def to_int8(tensor: torch.Tensor) -> torch.Tensor:
+ return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
+
+
+def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
+ return tensor.to(dtype=torch.bfloat16)
+
+
+def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
+ return tensor.to(dtype=torch.float16)
+
+
+def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
+ k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ a = torch.randn((m, k), device='cuda') * 5
+ b = torch.randn((n, k), device='cuda').t() * 5
+
+ if dtype == torch.int8:
+ return to_int8(a), to_int8(b)
+ if dtype == torch.float8_e4m3fn:
+ return to_fp8(a), to_fp8(b)
+
+ raise ValueError("unsupported dtype")
+
+
+def prune_to_2_4(tensor):
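+ """Prune a tensor to a 2:4 structured-sparsity pattern: within every
+ contiguous group of 4 values, keep the 2 largest-magnitude entries and
+ zero the rest. This is the layout consumed by the CUTLASS sparse
+ (cutlass_scaled_sparse_mm) kernels benchmarked here."""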
+ # Reshape tensor to [N, 4] where N is number of groups of 4
+ original_shape = tensor.shape
+ reshaped = tensor.reshape(-1, 4)
+
+ # Get indices of top 2 absolute values in each group of 4
+ _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
+
+ # Create binary mask
+ mask = torch.zeros_like(reshaped)
+ mask.scatter_(dim=1,
+ index=indices,
+ src=torch.ones_like(indices, dtype=mask.dtype))
+
+ # Apply mask and reshape back
+ pruned = reshaped * mask
+
+ # Turn all -0.0 to 0.0
+ pruned[pruned == -0.0] = 0.0
+
+ return pruned.reshape(original_shape)
+
+
+def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, k: int) -> \
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
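+ """Create a random dense A and a 2:4-pruned B in the requested dtype,
+ compress B with ops.cutlass_sparse_compress, and return
+ (B_compressed, metadata, A, B)."""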
+ a = torch.randn((m, k), device='cuda') * 5
+ b = torch.randn((n, k), device='cuda').t() * 5
+
+ b = prune_to_2_4(b.t()).t()
+
+ if dtype == torch.int8:
+ a, b = to_int8(a), to_int8(b)
+ elif dtype == torch.float8_e4m3fn:
+ a, b = to_fp8(a), to_fp8(b)
+ elif dtype == torch.float16:
+ a, b = to_fp16(a), to_fp16(b)
+ elif dtype == torch.bfloat16:
+ a, b = to_bf16(a), to_bf16(b)
+ else:
+ raise ValueError("unsupported dtype")
+
+ b_compressed, e = ops.cutlass_sparse_compress(b.t())
+
+ # Compressed B, Metadata, Original A, B
+ return b_compressed, e, a, b
+
+
+def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
+ m: int, n: int, k: int) -> \
+ Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor],
+ Iterable[torch.Tensor], Iterable[torch.Tensor]]:
+ ABs = []
+ for _ in range(num_tensors):
+ b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
+ if b_comp is not None:
+ ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
+ BComps, Es, As, Bs = zip(*ABs)
+ return list(BComps), list(Es), list(As), list(Bs)
diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
index 63cf5d50cac75..d0353bc8cb42a 100644
--- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -8,6 +8,7 @@
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
+from utils import make_rand_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
@@ -17,31 +18,6 @@
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
-# helpers
-
-
-def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
- finfo = torch.finfo(torch.float8_e4m3fn)
- return torch.round(tensor.clamp(
- min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
-
-
-def to_int8(tensor: torch.Tensor) -> torch.Tensor:
- return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
-
-
-def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
- k: int) -> Tuple[torch.Tensor, torch.Tensor]:
- a = torch.randn((m, k), device='cuda') * 5
- b = torch.randn((n, k), device='cuda').t() * 5
-
- if dtype == torch.int8:
- return to_int8(a), to_int8(b)
- if dtype == torch.float8_e4m3fn:
- return to_fp8(a), to_fp8(b)
-
- raise ValueError("unsupported dtype")
-
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
@@ -386,4 +362,4 @@ def to_torch_dtype(dt):
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
- args.func(args)
+ args.func(args)
\ No newline at end of file
diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py
index 25ec9d6028627..d58fb0bf86374 100644
--- a/benchmarks/cutlass_benchmarks/weight_shapes.py
+++ b/benchmarks/cutlass_benchmarks/weight_shapes.py
@@ -40,4 +40,4 @@
([8192, 57344], 1),
([28672, 8192], 0),
],
-}
+}
\ No newline at end of file
diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
index 2924ea4a49f54..94999630bae12 100644
--- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
+++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
@@ -10,7 +10,8 @@ set -ex
kill_gpu_processes() {
# kill all processes on GPU.
- pkill -f pt_main_thread
+ pgrep pt_main_thread | xargs -r kill -9
+ pgrep python3 | xargs -r kill -9
sleep 10
# remove vllm config file
@@ -54,7 +55,7 @@ benchmark() {
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
- --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+ --model $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
@@ -64,7 +65,7 @@ benchmark() {
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
- --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+ --model $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
@@ -87,7 +88,7 @@ benchmark() {
--port 8100 \
--save-result \
--result-dir $results_folder \
- --result-filename disagg_prefill_2xtp4.json \
+ --result-filename disagg_prefill_tp1.json \
--request-rate "inf"
@@ -105,7 +106,7 @@ benchmark() {
--port 8200 \
--save-result \
--result-dir $results_folder \
- --result-filename disagg_prefill_2xtp4.json \
+ --result-filename disagg_prefill_tp1_overhead.json \
--request-rate "$qps"
kill_gpu_processes
@@ -118,7 +119,7 @@ main() {
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
- pip install quart httpx
+ pip install quart httpx datasets
cd "$(dirname "$0")"
diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
index d8d9e976dce76..eb5d891d0d4a5 100644
--- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
+++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
@@ -1,13 +1,12 @@
#!/bin/bash
-# Requirement: 8x H100 GPUs.
+# Requirement: 2x GPUs.
-# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV
-# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests
-# Resource: 8x H100
+# Model: meta-llama/Meta-Llama-3.1-8B-Instruct
+# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests
+# Resource: 2x GPU
# Approaches:
-# 1. Chunked prefill: 1 vllm instance with tp=8
# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
# Prefilling instance: max_output_token=1
@@ -114,7 +113,6 @@ benchmark() {
--request-rate "$qps"
sleep 2
-
}
@@ -123,8 +121,9 @@ main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
+ (which lsof) || (apt-get -y install lsof)
- pip install quart httpx matplotlib aiohttp
+ pip install quart httpx matplotlib aiohttp datasets
cd "$(dirname "$0")"
diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py
new file mode 100644
index 0000000000000..baa5de0fff1bd
--- /dev/null
+++ b/benchmarks/kernels/benchmark_rmsnorm.py
@@ -0,0 +1,262 @@
+import itertools
+from typing import Optional, Tuple, Union
+
+import torch
+import triton
+from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from torch import nn
+
+from vllm import _custom_ops as vllm_ops
+
+
+class HuggingFaceRMSNorm(nn.Module):
+
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ orig_dtype = x.dtype
+ x = x.to(torch.float32)
+ if residual is not None:
+ x = x + residual.to(torch.float32)
+ residual = x.to(orig_dtype)
+
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
+ x = x * torch.rsqrt(variance + self.variance_epsilon)
+ x = x.to(orig_dtype) * self.weight
+ if residual is None:
+ return x
+ else:
+ return x, residual
+
+
+def rmsnorm_naive(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ eps: float = 1e-6,
+):
+ naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
+ naive_norm.weight = nn.Parameter(weight)
+ naive_norm = naive_norm.to(x.device)
+
+ orig_shape = x.shape
+ x = x.view(-1, x.shape[-1])
+ if residual is not None:
+ residual = residual.view(-1, residual.shape[-1])
+
+ output = naive_norm(x, residual)
+
+ if isinstance(output, tuple):
+ output = (output[0].view(orig_shape), output[1].view(orig_shape))
+ else:
+ output = output.view(orig_shape)
+ return output
+
+
+def rmsnorm_flashinfer(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ eps: float = 1e-6,
+):
+ orig_shape = x.shape
+ x = x.view(-1, x.shape[-1])
+ if residual is not None:
+ residual = residual.view(-1, residual.shape[-1])
+
+ if residual is not None:
+ fused_add_rmsnorm(x, residual, weight, eps)
+ output = (x, residual)
+ else:
+ output = rmsnorm(x, weight, eps)
+
+ if isinstance(output, tuple):
+ output = (output[0].view(orig_shape), output[1].view(orig_shape))
+ else:
+ output = output.view(orig_shape)
+ return output
+
+
+def rmsnorm_vllm(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ eps: float = 1e-6,
+):
+ orig_shape = x.shape
+ x = x.view(-1, x.shape[-1])
+ if residual is not None:
+ residual = residual.view(-1, residual.shape[-1])
+
+ if residual is not None:
+ vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
+ output = (x, residual)
+ else:
+ out = torch.empty_like(x)
+ vllm_ops.rms_norm(out, x, weight, eps)
+ output = out
+
+ if isinstance(output, tuple):
+ output = (output[0].view(orig_shape), output[1].view(orig_shape))
+ else:
+ output = output.view(orig_shape)
+ return output
+
+
+def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
+ dtype = torch.bfloat16
+ x = torch.randn(batch_size,
+ seq_len,
+ hidden_size,
+ dtype=dtype,
+ device="cuda")
+ weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
+ residual = torch.randn_like(x) if use_residual else None
+
+ output_naive = rmsnorm_naive(
+ x.clone(), weight,
+ residual.clone() if residual is not None else None)
+ output_flashinfer = rmsnorm_flashinfer(
+ x.clone(), weight,
+ residual.clone() if residual is not None else None)
+ output_vllm = rmsnorm_vllm(
+ x.clone(), weight,
+ residual.clone() if residual is not None else None)
+
+ if use_residual:
+ output_naive = output_naive[0]
+ output_flashinfer = output_flashinfer[0]
+ output_vllm = output_vllm[0]
+
+ print(f"Naive output={output_naive}")
+ print(f"FlashInfer output={output_flashinfer}")
+ print(f"VLLM output={output_vllm}")
+
+ if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
+ rtol=1e-2) and torch.allclose(
+ output_naive, output_vllm, atol=1e-2, rtol=1e-2):
+ print("✅ All implementations match")
+ else:
+ print("❌ Implementations differ")
+
+
+batch_size_range = [2**i for i in range(0, 7, 2)]
+seq_length_range = [2**i for i in range(6, 11, 1)]
+head_num_range = [32, 48]
+configs = list(
+ itertools.product(head_num_range, batch_size_range, seq_length_range))
+
+
+def get_benchmark(use_residual):
+
+ @triton.testing.perf_report(
+ triton.testing.Benchmark(
+ x_names=["head_num", "batch_size", "seq_len"],
+ x_vals=[list(_) for _ in configs],
+ line_arg="provider",
+ line_vals=["huggingface", "flashinfer", "vllm"],
+ line_names=["HuggingFace", "FlashInfer", "vLLM"],
+ styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+ ylabel="us",
+ plot_name=
+ f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
+ args={},
+ ))
+ def benchmark(head_num, batch_size, seq_len, provider):
+ dtype = torch.bfloat16
+ hidden_size = head_num * 128 # assuming head_dim = 128
+
+ x = torch.randn(batch_size,
+ seq_len,
+ hidden_size,
+ dtype=dtype,
+ device="cuda")
+ weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
+ residual = torch.randn_like(x) if use_residual else None
+
+ quantiles = [0.5, 0.2, 0.8]
+
+ if provider == "huggingface":
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ lambda: rmsnorm_naive(
+ x.clone(),
+ weight,
+ residual.clone() if residual is not None else None,
+ ),
+ quantiles=quantiles,
+ )
+ elif provider == "flashinfer":
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ lambda: rmsnorm_flashinfer(
+ x.clone(),
+ weight,
+ residual.clone() if residual is not None else None,
+ ),
+ quantiles=quantiles,
+ )
+ else:
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ lambda: rmsnorm_vllm(
+ x.clone(),
+ weight,
+ residual.clone() if residual is not None else None,
+ ),
+ quantiles=quantiles,
+ )
+
+ return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+ return benchmark
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=4,
+ help="Batch size",
+ )
+ parser.add_argument(
+ "--seq-len",
+ type=int,
+ default=128,
+ help="Sequence length",
+ )
+ parser.add_argument(
+ "--hidden-size",
+ type=int,
+ default=4096,
+ help="Hidden size (2nd dimension) of the sequence",
+ )
+ parser.add_argument("--use-residual",
+ action="store_true",
+ help="Whether to use residual connection")
+ parser.add_argument(
+ "--save-path",
+ type=str,
+ default="./configs/rmsnorm/",
+ help="Path to save rmsnorm benchmark results",
+ )
+
+ args = parser.parse_args()
+
+ # Run correctness test
+ calculate_diff(batch_size=args.batch_size,
+ seq_len=args.seq_len,
+ hidden_size=args.hidden_size,
+ use_residual=args.use_residual)
+
+ # Get the benchmark function with proper use_residual setting
+ benchmark = get_benchmark(args.use_residual)
+ # Run performance benchmark
+ benchmark.run(print_data=True, save_path=args.save_path)
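
Two notes on the benchmark above: the FlashInfer and vLLM fused-add paths update `x` and `residual` in place (hence the `.clone()` calls before every invocation), and all three implementations compute the same quantity. A compact reference for the no-residual case, runnable on CPU:

# Compact reference for the RMSNorm computed above (no residual case):
#   y = x / sqrt(mean(x**2, dim=-1) + eps) * weight
import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * inv_rms).to(x.dtype) * weight

x = torch.randn(4, 128, 4096, dtype=torch.bfloat16)
w = torch.ones(4096, dtype=torch.bfloat16)
ref = rmsnorm_reference(x, w)
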
diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu
index 5cdd250c3f9cf..3569b3c88abcd 100644
--- a/csrc/attention/paged_attention_v1.cu
+++ b/csrc/attention/paged_attention_v1.cu
@@ -53,7 +53,7 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
- const c10::optional& alibi_slopes, torch::Tensor& k_scale,
+ const std::optional& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
@@ -194,7 +194,7 @@ void paged_attention_v1(
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
- const c10::optional& alibi_slopes,
+ const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu
index c0e6b7cfd67a0..bc543e713fe58 100644
--- a/csrc/attention/paged_attention_v2.cu
+++ b/csrc/attention/paged_attention_v2.cu
@@ -62,7 +62,7 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
- const c10::optional& alibi_slopes, torch::Tensor& k_scale,
+ const std::optional& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
@@ -213,7 +213,7 @@ void paged_attention_v2(
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
- const c10::optional& alibi_slopes,
+ const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
new file mode 100644
index 0000000000000..ba9f40a230c8e
--- /dev/null
+++ b/csrc/core/math.hpp
@@ -0,0 +1,7 @@
+#include <climits>
+#include <cstdint>
+
+inline uint32_t next_pow_2(uint32_t const num) {
+ if (num <= 1) return num;
+ return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
\ No newline at end of file
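
`next_pow_2` rounds a 32-bit integer up to the next power of two with a count-leading-zeros trick; a Python equivalent and a quick self-check (illustrative only):

# Python equivalent of csrc/core/math.hpp::next_pow_2 (illustrative).
def next_pow_2(num: int) -> int:
    # Values <= 1 are returned unchanged, matching the C++ helper.
    if num <= 1:
        return num
    return 1 << (num - 1).bit_length()

assert [next_pow_2(n) for n in (0, 1, 2, 3, 64, 65, 100)] == [0, 1, 2, 4, 64, 128, 128]
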
diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp
index e21832ba7582f..ef5b14088c63b 100644
--- a/csrc/cpu/attention.cpp
+++ b/csrc/cpu/attention.cpp
@@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
- const c10::optional& alibi_slopes) {
+ const std::optional& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -459,7 +459,7 @@ void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
- int64_t max_seq_len, const c10::optional& alibi_slopes,
+ int64_t max_seq_len, const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
@@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
- int max_seq_len, const c10::optional& alibi_slopes) {
+ int max_seq_len, const std::optional& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -781,7 +781,7 @@ void paged_attention_v2(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
- int64_t max_seq_len, const c10::optional& alibi_slopes,
+ int64_t max_seq_len, const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index d9aed657a3113..33b1637832888 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& b, // [IC, OC], column-major
const torch::Tensor& a_scales, // [1] or [M]
const torch::Tensor& b_scales, // [1] or [OC]
- const c10::optional& bias // [OC]
+ const std::optional& bias // [OC]
) {
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
// Checks for conformality
@@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a_scales, // [1] or [M]
const torch::Tensor& b_scales, // [1] or [OC]
const torch::Tensor& azp_adj, // [OC]
- const c10::optional& azp, // [1] or [M]
- const c10::optional& bias // [OC]
+ const std::optional& azp, // [1] or [M]
+ const std::optional& bias // [OC]
) {
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
// Checks for conformality
@@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const torch::Tensor& input, // [..., hidden_size]
const torch::Tensor& scale,
- c10::optional const& azp) {
+ std::optional const& azp) {
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
@@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
const torch::Tensor& input, // [..., hidden_size]
torch::Tensor& scale, // [..., 1]
- c10::optional const& azp) {
+ std::optional const& azp) {
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 03beefbc6de7d..74e4d8189d403 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids);
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b, const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
- const c10::optional& bias);
+ const std::optional& bias);
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b, const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const torch::Tensor& azp_adj,
- const c10::optional& azp,
- const c10::optional& bias);
+ const std::optional& azp,
+ const std::optional& bias);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp
new file mode 100644
index 0000000000000..3d2093ab94297
--- /dev/null
+++ b/csrc/cutlass_extensions/common.cpp
@@ -0,0 +1,11 @@
+#include "cutlass_extensions/common.hpp"
+
+int32_t get_sm_version_num() {
+ int32_t major_capability, minor_capability;
+ cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+ 0);
+ cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+ 0);
+ int32_t version_num = major_capability * 10 + minor_capability;
+ return version_num;
+}
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
new file mode 100644
index 0000000000000..85e359aa57113
--- /dev/null
+++ b/csrc/cutlass_extensions/common.hpp
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include
+#include "cuda_runtime.h"
+#include
+
+/**
+ * Helper function for checking CUTLASS errors
+ */
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TORCH_CHECK(error == cutlass::Status::kSuccess, \
+ cutlassGetStatusString(error)); \
+ }
+
+/**
+ * Panic wrapper for unwinding CUDA runtime errors
+ */
+#define CUDA_CHECK(status) \
+ { \
+ cudaError_t error = status; \
+ TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
+ }
+
+inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
+ int max_shared_mem_per_block_opt_in = 0;
+ cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin,
+ device);
+ return max_shared_mem_per_block_opt_in;
+}
+
+int32_t get_sm_version_num();
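
`get_sm_version_num` folds the compute capability of device 0 into a two-digit number (e.g. 90 for Hopper). The same value can be derived on the Python side, assuming a visible CUDA device; a small sketch:

# Deriving the same SM version number from Python (illustrative).
import torch

def sm_version_num(device_index: int = 0) -> int:
    major, minor = torch.cuda.get_device_capability(device_index)
    return major * 10 + minor

# e.g. 80 on A100, 89 on L4/L40, 90 on H100.
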
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
index c69e87999ae71..ef413e6dd75c5 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
@@ -66,7 +68,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
- static auto args_from_tensor(c10::optional const& tensor) {
+ static auto args_from_tensor(std::optional const& tensor) {
static_assert(std::is_same_v>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr;
@@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
- c10::optional const& bias) {
+ std::optional const& bias) {
auto a_args = SUPER::template args_from_tensor(a_scales);
auto b_args = SUPER::template args_from_tensor(b_scales);
auto bias_args = SUPER::template args_from_tensor(bias);
@@ -299,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
- c10::optional const& bias) {
+ std::optional const& bias) {
auto a_args = SUPER::template args_from_tensor(a_scales);
auto b_args = SUPER::template args_from_tensor(b_scales);
auto bias_args = SUPER::template args_from_tensor(bias);
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
index 95764ecddc79f..c590c66a66652 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
// Don't want to support nullptr by default
template
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>;
// Don't want to support nullptr by default
template
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
@@ -65,7 +67,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
- static auto args_from_tensor(c10::optional const& tensor) {
+ static auto args_from_tensor(std::optional const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v> ||
@@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
- c10::optional const& bias) {
+ std::optional const& bias) {
auto a_args = SUPER::template args_from_tensor(a_scales);
auto b_args = SUPER::template args_from_tensor(b_scales);
auto bias_args = SUPER::template args_from_tensor(bias);
@@ -297,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
- c10::optional const& bias) {
+ std::optional const& bias) {
auto a_args = SUPER::template args_from_tensor(a_scales);
auto b_args = SUPER::template args_from_tensor(b_scales);
auto bias_args = SUPER::template args_from_tensor(bias);
diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp
index 2c78572521eec..a1ff933cce63f 100644
--- a/csrc/cutlass_extensions/torch_utils.hpp
+++ b/csrc/cutlass_extensions/torch_utils.hpp
@@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
template
static inline auto maybe_make_cute_layout(
- c10::optional const& tensor,
+ std::optional const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout(*tensor));
diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py
index a5beea1a35e49..b401736c9824b 100644
--- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py
@@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
class MixedInputKernelScheduleType(enum.Enum):
- TmaWarpSpecializedMixedInput = enum_auto()
- TmaWarpSpecializedPingpongMixedInput = enum_auto()
- TmaWarpSpecializedCooperativeMixedInput = enum_auto()
+ TmaWarpSpecialized = enum_auto()
+ TmaWarpSpecializedPingpong = enum_auto()
+ TmaWarpSpecializedCooperative = enum_auto()
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
@@ -68,11 +68,11 @@ class MixedInputKernelScheduleType(enum.Enum):
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
- MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
- "cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
- MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
- "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
- MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
- "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
+ MixedInputKernelScheduleType.TmaWarpSpecialized:
+ "cutlass::gemm::KernelTmaWarpSpecialized",
+ MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
+ "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
+ MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
+ "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}
}
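
The renamed schedule enum members map one-to-one onto CUTLASS kernel-schedule tag strings via the dictionary above. A self-contained sketch of that enum-to-tag pattern and how a generator would consume it (names shortened; not the generator's actual code):

# Self-contained illustration of the enum -> CUTLASS tag mapping used above.
import enum
from enum import auto

class KernelSchedule(enum.Enum):
    TmaWarpSpecialized = auto()
    TmaWarpSpecializedPingpong = auto()
    TmaWarpSpecializedCooperative = auto()

KERNEL_SCHEDULE_TAG = {
    KernelSchedule.TmaWarpSpecialized:
        "cutlass::gemm::KernelTmaWarpSpecialized",
    KernelSchedule.TmaWarpSpecializedPingpong:
        "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
    KernelSchedule.TmaWarpSpecializedCooperative:
        "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}

# A generator splices the tag into emitted C++ source, roughly like:
tag = KERNEL_SCHEDULE_TAG[KernelSchedule.TmaWarpSpecializedCooperative]
print(f"using KernelSchedule = {tag};")
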
diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
index dd1e6de2e0180..f0e5533bcae60 100644
--- a/csrc/mamba/causal_conv1d/causal_conv1d.cu
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
const at::Tensor x,
const at::Tensor weight,
const at::Tensor out,
- const c10::optional& bias,
+ const std::optional& bias,
bool silu_activation,
int64_t pad_slot_id,
- const c10::optional& query_start_loc = std::nullopt,
- const c10::optional& cache_indices = std::nullopt,
- const c10::optional& has_initial_state = std::nullopt) {
+ const std::optional& query_start_loc = std::nullopt,
+ const std::optional& cache_indices = std::nullopt,
+ const std::optional& has_initial_state = std::nullopt) {
// Reset the parameters
memset(¶ms, 0, sizeof(params));
@@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
- const c10::optional &bias_,
- const c10::optional &conv_states,
- const c10::optional &query_start_loc,
- const c10::optional &cache_indices,
- const c10::optional &has_initial_state,
+ const std::optional &bias_,
+ const std::optional &conv_states,
+ const std::optional &query_start_loc,
+ const std::optional &cache_indices,
+ const std::optional &has_initial_state,
bool silu_activation,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
@@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
void causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
- const c10::optional &bias_,
+ const std::optional &bias_,
bool silu_activation,
- const c10::optional &cache_seqlens_,
- const c10::optional &conv_state_indices_,
+ const std::optional &cache_seqlens_,
+ const std::optional &conv_state_indices_,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index 71624696338d0..bd0a34119c82b 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
const torch::Tensor out,
const torch::Tensor z,
const torch::Tensor out_z,
- const c10::optional& D,
- const c10::optional& delta_bias,
+ const std::optional& D,
+ const std::optional& delta_bias,
const torch::Tensor ssm_states,
bool has_z,
bool delta_softplus,
- const c10::optional& query_start_loc,
- const c10::optional& cache_indices,
- const c10::optional& has_initial_state,
+ const std::optional& query_start_loc,
+ const std::optional& cache_indices,
+ const std::optional& has_initial_state,
bool varlen,
int64_t pad_slot_id) {
@@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
- const c10::optional &D_,
- const c10::optional &z_,
- const c10::optional &delta_bias_,
+ const std::optional &D_,
+ const std::optional &z_,
+ const std::optional &delta_bias_,
bool delta_softplus,
- const c10::optional &query_start_loc,
- const c10::optional &cache_indices,
- const c10::optional &has_initial_state,
+ const std::optional &query_start_loc,
+ const std::optional &cache_indices,
+ const std::optional &has_initial_state,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index dd90c38d9a721..16fccae403338 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -112,6 +112,91 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
}
+// TODO(simon): this is temporarily adapted from
+// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
+// We did this to unblock DeepSeek V3, but there should be a better
+// implementation that manages shared memory.
+template
+__global__ void moe_align_block_size_global_mem_kernel(
+ scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+ int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
+ int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
+ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+ const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+ for (int i = 0; i < num_experts; ++i) {
+ tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
+ }
+
+ /**
+ * In the first step we compute token_cnts[thread_index + 1][expert_index],
+ * which counts how many tokens in the token shard of thread_index are
+ * assigned to expert expert_index.
+ */
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+ ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
+ }
+
+ __syncthreads();
+
+ // For each expert we accumulate the token counts from the different threads.
+ for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+ tokens_cnts[index(num_experts, 0, eid)] = 0;
+ for (int i = 1; i <= blockDim.x; ++i) {
+ tokens_cnts[index(num_experts, i, eid)] +=
+ tokens_cnts[index(num_experts, i - 1, eid)];
+ }
+ }
+
+ __syncthreads();
+
+ // We accumulate the token counts of all experts in thread 0.
+ if (threadIdx.x == 0) {
+ cumsum[0] = 0;
+ for (int i = 1; i <= num_experts; ++i) {
+ cumsum[i] = cumsum[i - 1] +
+ CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
+ block_size) *
+ block_size;
+ }
+ *total_tokens_post_pad = cumsum[num_experts];
+ }
+
+ __syncthreads();
+
+ /**
+ * For each expert, each thread processes the tokens of the corresponding
+ * blocks and stores the corresponding expert_id for each block.
+ */
+ for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+ for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
+ expert_ids[i / block_size] = eid;
+ }
+ }
+
+ /**
+ * Each thread processes a token shard, calculating the index of each token
+ * after sorting by expert number. Given the example topk_ids =
+   * [0,1,2,1,2,3,0,3,4] and block_size = 4, the output would be [0, 6, *,
+   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
+   * padding value (preset in Python).
+ */
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+ int32_t expert_id = topk_ids[i];
+ /** The cumsum[expert_id] stores the starting index of the tokens that the
+ * expert with expert_id needs to process, and
+ * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
+ * processed by the expert with expert_id within the current thread's token
+ * shard.
+ */
+ int32_t rank_post_pad =
+ tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
+ cumsum[expert_id];
+ sorted_token_ids[rank_post_pad] = i;
+ ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
+ }
+}
+
template
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
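
For intuition, a single-threaded Python reference of the bookkeeping the kernel above performs (block-padded per-expert offsets, per-block expert ids, and the scatter of token indices); illustrative only, with `pad_value` standing in for the padding preset on the Python side:

# Single-threaded reference for moe_align_block_size (illustrative only;
# the CUDA kernel above parallelizes the same bookkeeping across threads).
import math
from typing import List, Tuple

def moe_align_block_size_ref(topk_ids: List[int], num_experts: int,
                             block_size: int,
                             pad_value: int) -> Tuple[List[int], List[int], int]:
    counts = [0] * num_experts
    for e in topk_ids:
        counts[e] += 1

    # Pad each expert's token count up to a multiple of block_size.
    cumsum = [0] * (num_experts + 1)
    for e in range(num_experts):
        cumsum[e + 1] = cumsum[e] + math.ceil(counts[e] / block_size) * block_size
    total_tokens_post_pad = cumsum[num_experts]

    # One expert id per block of sorted tokens.
    expert_ids = [0] * (total_tokens_post_pad // block_size)
    for e in range(num_experts):
        for pos in range(cumsum[e], cumsum[e + 1], block_size):
            expert_ids[pos // block_size] = e

    # Scatter token indices into their expert's padded region.
    sorted_token_ids = [pad_value] * total_tokens_post_pad
    offsets = list(cumsum[:num_experts])
    for i, e in enumerate(topk_ids):
        sorted_token_ids[offsets[e]] = i
        offsets[e] += 1
    return sorted_token_ids, expert_ids, total_tokens_post_pad

# Matches the worked example in the kernel comment above:
ids, experts, total = moe_align_block_size_ref([0, 1, 2, 1, 2, 3, 0, 3, 4],
                                               num_experts=5, block_size=4,
                                               pad_value=-1)
assert ids[:2] == [0, 6] and ids[4:6] == [1, 3] and total == 20
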
diff --git a/csrc/ops.h b/csrc/ops.h
index 8ca912ff58897..e9cc8d2e215e2 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -33,7 +33,7 @@ void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
- int64_t max_seq_len, const c10::optional& alibi_slopes,
+ int64_t max_seq_len, const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
@@ -45,7 +45,7 @@ void paged_attention_v2(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
- int64_t max_seq_len, const c10::optional& alibi_slopes,
+ int64_t max_seq_len, const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
@@ -158,24 +158,35 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
- c10::optional const& bias);
+ std::optional const& bias);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
- c10::optional const& azp,
- c10::optional const& bias);
+ std::optional const& azp,
+ std::optional const& bias);
+
+bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
+
+void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
+ torch::Tensor const& b, torch::Tensor const& e,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ std::optional const& bias);
+
+bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
+ torch::Tensor& e, torch::Tensor const& a);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
- c10::optional const& azp);
+ std::optional const& azp);
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales,
- c10::optional const& azp);
+ std::optional const& azp);
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
@@ -192,34 +203,34 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
- c10::optional const& scale_ub);
+ std::optional const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,
- const c10::optional& D_,
- const c10::optional& z_,
- const c10::optional& delta_bias_,
+ const std::optional& D_,
+ const std::optional& z_,
+ const std::optional& delta_bias_,
bool delta_softplus,
- const c10::optional& query_start_loc,
- const c10::optional& cache_indices,
- const c10::optional& has_initial_state,
+ const std::optional& query_start_loc,
+ const std::optional& cache_indices,
+ const std::optional& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id);
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
- const c10::optional& bias_,
+ const std::optional& bias_,
bool silu_activation,
- const c10::optional& cache_seqlens_,
- const c10::optional& conv_state_indices_,
+ const std::optional& cache_seqlens_,
+ const std::optional& conv_state_indices_,
int64_t pad_slot_id);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
- const c10::optional& bias_,
- const c10::optional& conv_states,
- const c10::optional& query_start_loc,
- const c10::optional& cache_indices,
- const c10::optional& has_initial_state,
+ const std::optional& bias_,
+ const std::optional& conv_states,
+ const std::optional& query_start_loc,
+ const std::optional& cache_indices,
+ const std::optional& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
using fptr_t = int64_t;
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index e9987535bd3ea..e79785827189d 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale,
- c10::optional const& azp) {
+ std::optional const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
@@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
- torch::Tensor& scales, c10::optional const& azp) {
+ torch::Tensor& scales, std::optional const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scales.is_contiguous());
diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
deleted file mode 100644
index bf04bb400790f..0000000000000
--- a/csrc/quantization/cutlass_w8a8/common.hpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include "cutlass/cutlass.h"
-#include
-
-/**
- * Helper function for checking CUTLASS errors
- */
-#define CUTLASS_CHECK(status) \
- { \
- TORCH_CHECK(status == cutlass::Status::kSuccess, \
- cutlassGetStatusString(status)) \
- }
-
-inline uint32_t next_pow_2(uint32_t const num) {
- if (num <= 1) return num;
- return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
-inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
- int max_shared_mem_per_block_opt_in = 0;
- cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
- cudaDevAttrMaxSharedMemoryPerBlockOptin,
- device);
- return max_shared_mem_per_block_opt_in;
-}
-
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index dbb72e8bbd3f5..865fef5aeea11 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
- c10::optional const& bias) {
+ std::optional const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
- c10::optional const& azp,
- c10::optional const& bias) {
+ std::optional const& azp,
+ std::optional const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
- c10::optional const& bias) {
+ std::optional const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
- c10::optional const& azp,
- c10::optional const& bias) {
+ std::optional const& azp,
+ std::optional const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
- c10::optional const& bias) {
+ std::optional const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
- c10::optional const& azp,
- c10::optional const& bias) {
+ std::optional const& azp,
+ std::optional const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index d03242f44ab1d..f2fae4b66d651 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -21,15 +21,16 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
-#include "common.hpp"
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;
/*
- Epilogue functions can be defined to post-process the output before it is
- written to GPU memory.
- Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+  Epilogues defined in
+ csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+ must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index 33581a63d4c3d..e18d7d79e5b77 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -1,384 +1,18 @@
-// clang-format will break include orders
-// clang-format off
#include
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
-#include
+ #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
+ #include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
-#include
-
-#include
-#include
-#include
-
-#include "cutlass/cutlass.h"
-
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "common.hpp"
-// clang-format on
-
-using namespace cute;
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
-
- Epilogue functions can be defined to post-process the output before it is
- written to GPU memory.
- Epilogues must contain a public type named EVTCompute of type Sm90EVT,
- as well as a static prepare_args function that constructs an
- EVTCompute::Arguments struct.
*/
-namespace {
-
-// A wrapper for the GEMM kernel that is used to guard against compilation on
-// architectures that will never use the kernel. The purpose of this is to
-// reduce the size of the compiled binary.
-// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
-// into code that will be executed on the device where it is defined.
-template
-struct enable_sm90_or_later : Kernel {
- template
- CUTLASS_DEVICE void operator()(Args&&... args) {
- #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
- Kernel::operator()(std::forward(args)...);
- #endif
- }
-};
-template typename Epilogue_,
- typename TileShape, typename ClusterShape, typename KernelSchedule,
- typename EpilogueSchedule>
-struct cutlass_3x_gemm {
- using ElementAB = ElementAB_;
- using ElementD = ElementD_;
- using ElementAcc =
- typename std::conditional, int32_t,
- float>::type;
-
- using EpilogueDescriptor =
- cutlass::epilogue::collective::detail::EpilogueDescriptor<
- TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
- ElementD, EpilogueSchedule>;
-
- using Epilogue = Epilogue_;
-
- using StrideD = Stride, Int<0>>;
- using ElementC = void;
- using StrideC = StrideD;
-
- using EVTCompute = typename Epilogue::EVTCompute;
-
- using CollectiveEpilogue =
- typename cutlass::epilogue::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
- ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
- ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
- EpilogueSchedule, EVTCompute>::CollectiveOp;
-
- static constexpr size_t CEStorageSize =
- sizeof(typename CollectiveEpilogue::SharedStorage);
- using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
- static_cast(CEStorageSize)>;
-
- // clang-format off
- using CollectiveMainloop =
- typename cutlass::gemm::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
- ElementAB, cutlass::layout::RowMajor, 16,
- ElementAB, cutlass::layout::ColumnMajor, 16,
- ElementAcc, TileShape, ClusterShape,
- Stages,
- KernelSchedule>::CollectiveOp;
- // clang-format on
-
- using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue,
- cutlass::gemm::PersistentScheduler>>;
-
- struct GemmKernel : public KernelType {};
-};
-
-template
-void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_params) {
- using ElementAB = typename Gemm::ElementAB;
- using ElementD = typename Gemm::ElementD;
-
- int32_t m = a.size(0);
- int32_t n = b.size(1);
- int32_t k = a.size(1);
-
- int64_t lda = a.stride(0);
- int64_t ldb = b.stride(1);
- int64_t ldc = out.stride(0);
-
- using StrideA = Stride, int64_t>;
- using StrideB = Stride, int64_t>;
- using StrideC = typename Gemm::StrideC;
-
- StrideA a_stride{lda, Int<1>{}, 0};
- StrideB b_stride{ldb, Int<1>{}, 0};
- StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
-
- using GemmKernel = typename Gemm::GemmKernel;
- typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
-
- auto a_ptr = static_cast(a.data_ptr());
- auto b_ptr = static_cast(b.data_ptr());
- typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
- b_stride};
-
- auto c_ptr = static_cast(out.data_ptr());
- typename GemmKernel::EpilogueArguments epilogue_args{
- Gemm::Epilogue::prepare_args(
- std::forward(epilogue_params)...),
- c_ptr, c_stride, c_ptr, c_stride};
-
- typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
- prob_shape, mainloop_args, epilogue_args};
-
- // Launch the CUTLASS GEMM kernel.
- using GemmOp = cutlass::gemm::device::GemmUniversalAdapter;
- GemmOp gemm_op;
- CUTLASS_CHECK(gemm_op.can_implement(args));
-
- size_t workspace_size = gemm_op.get_workspace_size(args);
- auto const workspace_options =
- torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
- auto workspace = torch::empty(workspace_size, workspace_options);
-
- auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
-
- cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
- CUTLASS_CHECK(status);
-}
-
-template typename Epilogue>
-struct sm90_fp8_config_default {
- // M in (128, inf)
- static_assert(std::is_same());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_fp8_config_M128 {
- // M in (64, 128]
- static_assert(std::is_same());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_fp8_config_M64 {
- // M in [1, 64]
- static_assert(std::is_same());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _128>;
- using ClusterShape = Shape<_1, _8, _1>;
-
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_default {
- // For M > 128 and any N
- static_assert(std::is_same());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M128 {
- // For M in (64, 128] and any N
- static_assert(std::is_same());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M64 {
- // For M in (32, 64] and any N
- static_assert(std::is_same());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M32_NBig {
- // For M in [1, 32] and N >= 8192
- static_assert(std::is_same());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _256>;
- using ClusterShape = Shape<_1, _4, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M32_NSmall {
- // For M in [1, 32] and N < 8192
- static_assert(std::is_same());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _8, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-} // namespace
-
-template typename Epilogue,
- typename... EpilogueArgs>
-void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
-
- using Cutlass3xGemmDefault =
- typename sm90_fp8_config_default::Cutlass3xGemm;
- using Cutlass3xGemmM64 =
- typename sm90_fp8_config_M64::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm90_fp8_config_M128::Cutlass3xGemm;
-
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast(64), next_pow_2(m)); // next power of 2
-
- if (mp2 <= 64) {
- // m in [1, 64]
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- } else {
- // m in (128, inf)
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- }
-}
-
-template typename Epilogue,
- typename... EpilogueArgs>
-void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same());
- TORCH_CHECK(a.dtype() == torch::kInt8);
- TORCH_CHECK(b.dtype() == torch::kInt8);
-
- using Cutlass3xGemmDefault =
- typename sm90_int8_config_default