[Kernel] add triton fused moe kernel for gptq/awq #12185

Open · wants to merge 13 commits into main from triton_fused_moe_int4

Conversation

@jinzhen-lin (Contributor) commented Jan 18, 2025

Currently, the only option for running MoE + gptq/awq models is the Marlin kernel, but a single marlin_gemm_moe call launches at least num_experts CUDA kernels, while the fused_moe triton kernel needs only one kernel launch. This makes the Marlin kernel significantly slower than the fused_moe triton kernel.

This PR adds support for the fused_moe triton kernel with gptq/awq.
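
To make the launch-count difference concrete, here is a schematic comparison in plain PyTorch (an illustration only, with random toy weights and equal token counts per expert; this is not the actual marlin or triton kernel code):

import torch

# Toy sizes: many experts, few tokens per expert (typical for bs=1 decoding).
E, M, K, N = 64, 4, 256, 256
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(E, M, K, device=device)
w = torch.randn(E, K, N, device=device)

# Marlin-style MoE: at least one GEMM launch per expert.
out_loop = torch.stack([x[e] @ w[e] for e in range(E)])

# Fused-style MoE: all experts handled by a single batched launch.
out_fused = torch.bmm(x, w)

assert torch.allclose(out_loop, out_fused, atol=1e-3, rtol=1e-3)

With so little work per expert, the E separate launches are dominated by launch overhead, which is what the benchmark below reflects.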

Generation speed of deepseek-v3-awq (8×A100-SXM4-80GB, bs=1, short prompt):

                 marlin moe kernel    triton fused moe kernel
    w/o #12222   5.4 tok/s            10.0 tok/s
    w/ #12222    11.1 tok/s           29.6 tok/s

Note:

  1. [Kernel] optimize moe_align_block_size for cuda graph and large num_experts (e.g. DeepSeek-V3) #12222 enables CUDA graph and adds shared-memory moe_align_block_size kernel support for deepseek-v3 (see the sketch after the launch command below).
  2. To enable this kernel:
python -m vllm.entrypoints.openai.api_server \
    --served-model-name model \
    --model cognitivecomputations/DeepSeek-V3-AWQ \
    --tensor-parallel-size 8 \
    --trust-remote-code \
    --max-model-len 24576 \
    --dtype half \
    --max-num-seqs 16 \
    --gpu-memory-utilization 0.96 \
    --quantization moe_quant_int
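
Regarding note 1: for readers unfamiliar with moe_align_block_size, here is a rough pure-PyTorch sketch of what it conceptually produces (an illustration only; the real CUDA kernel, and the shared-memory version from #12222, may differ in details such as the padding sentinel):

import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int,
                             num_experts: int):
    # Group the (token, expert) assignments by expert and pad every expert's
    # group to a multiple of block_size, so that each block processed by the
    # fused kernel belongs to exactly one expert.
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # sentinel marking padded slots
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        ids = (flat == e).nonzero(as_tuple=True)[0].tolist()
        padded = ((len(ids) + block_size - 1) // block_size) * block_size
        sorted_ids += ids + [pad_id] * (padded - len(ids))
        expert_ids += [e] * (padded // block_size)
    return (torch.tensor(sorted_ids), torch.tensor(expert_ids),
            torch.tensor([len(sorted_ids)]))

Each block_size-sized slice of the first output is then handled by one program of the fused triton kernel, which looks up its expert id in the second output; #12222 makes this bookkeeping faster and CUDA-graph friendly.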


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@jinzhen-lin jinzhen-lin force-pushed the triton_fused_moe_int4 branch from 91b41c6 to 87e191f on January 18, 2025 12:08
@jinzhen-lin jinzhen-lin force-pushed the triton_fused_moe_int4 branch from 21c1d8d to 99f23f2 on January 18, 2025 12:14
@mgoin mgoin self-requested a review January 18, 2025 21:47
@jinzhen-lin jinzhen-lin force-pushed the triton_fused_moe_int4 branch from 55102d9 to 15ae02b on January 19, 2025 06:44
@casper-hansen (Contributor)

@mgoin @robertgshaw2-redhat Could we expedite this PR + #12036 (not sure if #12204 is needed too or has overlap) now that DeepSeek has released their full lineup?

@jinzhen-lin (Contributor, Author)

@mgoin @robertgshaw2-redhat Could we expedite this PR + #12036 (not sure if #12204 is needed too or has overlap) now that DeepSeek has released their full lineup?

I created a new PR with a better moe_align_block_size just now; you can take a look at it: #12222

@casper-hansen (Contributor)

I think this PR could be closed in favor of #12222. Thanks for your work @jinzhen-lin

@jinzhen-lin (Contributor, Author)

I think this PR could be closed in favor of #12222. Thanks for your work @jinzhen-lin

#12222 is an optimization over #12036 or #12204; it can be combined with this PR to get better performance.

@mgoin (Member) commented Jan 20, 2025

Thank you for the work! We will take a look now

top_k: tl.constexpr,
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w8a16: tl.constexpr,
Member:

I think all of these should be renamed to use_int4_w4a16

Comment on lines 19 to 20
class MoeQuantIntConfig(QuantizationConfig):
    """Config class for Int8 experts quantization."""
Member:

Update this comment

Is there any more specific name we could use for this method? I also feel that --quantization moe_quant_int is not clear. Maybe you could change to --quantization moe_wNa16 and MoeWNA16Config? Open to other names

Contributor Author:

moe_wNa16 is a better name; I will change it.

Comment on lines 101 to 116
    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        elif isinstance(layer, LinearBase):
            if self.linear_quant_method == "gptq":
                gptq_config = GPTQMarlinConfig.from_config(self.full_config)
                return GPTQMarlinLinearMethod(gptq_config)
            elif self.linear_quant_method == "awq":
                awq_config = AWQMarlinConfig.from_config(self.full_config)
                return AWQMarlinLinearMethod(awq_config)
            else:
                raise ValueError("moe_quant_int only support gptq and awq.")
        elif isinstance(layer, FusedMoE):
            return MoeQuantIntMethod(self)
        return None
Member:

This is an interesting hack - I wonder if we could just enable the MoeQuantIntMethod as a condition inside of the other quantization methods rather than duplicating them here in this config

Contributor Author:

I considered this before, but in the end I created a new quantization method. The reasons are:

  1. This quantization method can be combined with all gptq/awq quantization methods; should we add it to every quantization method that supports gptq/awq, or just gptq-marlin/awq-marlin?
  2. This quantization method and the triton kernel use a different weight format; it is merely compatible with gptq/awq and accepts gptq/awq weights.
  3. It keeps the code clearer and easier to maintain (less duplication).

Comment on lines 276 to 301
    def convert_awq_tensor(tensor, tensor_type):
        size0 = tensor.size(0)
        tensor = tensor.view(torch.uint8)
        shifter = torch.tensor([0, 4],
                               dtype=torch.uint8,
                               device=tensor.device)
        tensor = (tensor[:, :, None] >> shifter) & 0xF
        tensor = tensor.view(-1, 8)[:, [0, 4, 1, 5, 2, 6, 3, 7]].view(size0, -1)
        tensor = tensor.T.contiguous()
        if tensor_type == "qweight":
            tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
        elif tensor_type == "qzeros":
            tensor = tensor[1::2, :] * 16 + tensor[::2, :]
        return tensor

    def convert_gptq_int4_qzeros(tensor):
        tensor = tensor.view(torch.uint8)
        shifter = torch.tensor([0, 4],
                               dtype=torch.uint8,
                               device=tensor.device)
        tensor = (tensor[:, :, None] >> shifter) & 0xF
        tensor = tensor + 1
        tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
        return tensor
Member:

Would be nice to have a short description of each transformation

Contributor Author:

Sure, I will add descriptions later.
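
As a quick sanity check of the nibble reordering in convert_awq_tensor, here is a toy example showing that indexing the unpacked nibbles with [0, 4, 1, 5, 2, 6, 3, 7] undoes AWQ's in-word ordering (it assumes AWQ stores the eight int4 values of an int32 in the order [0, 2, 4, 6, 1, 3, 5, 7]; illustration only, not part of the PR):

import torch

vals = list(range(8))                    # logical int4 values 0..7
awq_order = [0, 2, 4, 6, 1, 3, 5, 7]     # assumed AWQ in-word storage order
packed = torch.tensor([sum(vals[c] << (4 * i) for i, c in enumerate(awq_order))],
                      dtype=torch.int32)

# Same unpacking steps as convert_awq_tensor: split each int32 into bytes,
# each byte into low/high nibbles, then undo the interleave.
shifter = torch.tensor([0, 4], dtype=torch.uint8)
nibbles = (packed.view(torch.uint8)[:, None] >> shifter) & 0xF
recovered = nibbles.flatten().view(-1, 8)[:, [0, 4, 1, 5, 2, 6, 3, 7]]
assert recovered.flatten().tolist() == vals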

- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- use_int4_w8a16 (bool): If True, use matmul of int4 weight and bf16/fp16
Member:

Suggested change
- use_int4_w8a16 (bool): If True, use matmul of int4 weight and bf16/fp16
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16

Comment on lines +22 to +23
@triton.jit
def fused_moe_kernel_gptq_awq(
Collaborator:

There's quite a bit of code duplication between this and fused_moe_kernel - Not necessarily a blocker for this PR but IMO we should refactor and unify these kernels

Contributor Author:

At the beginning I tried to modify fused_moe_kernel, but found that this made the original code very complex (with many conditional branches) and hard to read, so in the end I created a new function. I'm not sure what the best approach is.
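
For context on that trade-off, this is roughly what compile-time flags look like in Triton: branches on tl.constexpr arguments are resolved at compile time, so one source can cover several variants, at the cost of readability. A toy kernel only, not the actual fused_moe code:

import torch
import triton
import triton.language as tl

@triton.jit
def toy_kernel(x_ptr, y_ptr, n, HAS_ZP: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if HAS_ZP:  # compile-time branch, removed entirely when HAS_ZP is False
        x = x - 1
    tl.store(y_ptr + offs, x, mask=mask)

x = torch.arange(8, device="cuda", dtype=torch.int32)
y = torch.empty_like(x)
toy_kernel[(1,)](x, y, x.numel(), HAS_ZP=True, BLOCK=8)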

@mgoin (Member) commented Jan 20, 2025

Considering that this is allowing for "another option" to run quantized moe models, maybe we should consider writing a documentation page specifically for moe quantization.

I think the best case for this kernel to be used more broadly would be to have a heuristic on the number of experts or some configuration to decide whether to use the triton or marlin kernel
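
For example, such a heuristic could be as simple as the sketch below (purely illustrative; the function name, threshold, and override flag are hypothetical and not something decided in this PR):

from typing import Optional

def select_moe_kernel(num_experts: int, force: Optional[str] = None) -> str:
    # Hypothetical dispatch rule: the marlin path launches work per expert,
    # so its overhead grows with num_experts; prefer the fused triton path
    # for many-expert models such as DeepSeek-V3 (256 routed experts).
    if force in ("marlin", "triton"):
        return force
    return "triton" if num_experts >= 32 else "marlin"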

@jinzhen-lin (Contributor, Author)

Considering that this is allowing for "another option" to run quantized moe models, maybe we should consider writing a documentation page specifically for moe quantization.

I think the best case for this kernel to be used more broadly would be to have a heuristic on the number of experts or some configuration to decide whether to use the triton or marlin kernel

I just tested with a small MoE model (https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4), and the triton kernel seems much faster than the marlin kernel there too. Besides, the marlin kernel seems to generate wrong results for Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4.

Test result on A100 * 1:

marlin kernel:

$time curl -X POST http://127.0.0.1:8000/v1/chat/completions     -H 'Content-Type: application/json'     -d '{ "model": "model", "temperature": 0.0, "messages": [ { "role": "user", "content": "write a very long article" } ], "stream": false, "max_tokens": 512, "min_tokens": 512}'
{"id":"chatcmpl-99d4b90d4ee14c15a82bd278b3cbcfd1","object":"chat.completion","created":1737439221,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":"数,数,数,数,数数,数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数数","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":24,"total_tokens":536,"completion_tokens":512,"prompt_tokens_details":null},"prompt_logprobs":null}
real    0m6.618s
user    0m0.002s
sys     0m0.003s

triton kernel:

$time curl -X POST http://127.0.0.1:8000/v1/chat/completions     -H 'Content-Type: application/json'     -d '{ "model": "model", "temperature": 0.0, "messages": [ { "role": "user", "content": "write a very long article" } ], "stream": false, "max_tokens": 512, "min_tokens": 512}'
{"id":"chatcmpl-d2f6f9d66c4b40c0b7dad3e1c9fc3d74","object":"chat.completion","created":1737439295,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":"The Benefits of Regular Exercise: A Comprehensive Guide\n\nIntroduction\n\nRegular exercise is an essential component of a healthy lifestyle. It not only helps maintain a healthy weight, but also has numerous other benefits for both physical and mental health. In this article, we will explore the various benefits of regular exercise, including improved cardiovascular health, increased strength and endurance, better mental health, and a reduced risk of chronic diseases. We will also discuss the different types of exercises and how to incorporate them into your routine for maximum effectiveness.\n\nImproved Cardiovascular Health\n\nRegular exercise strengthens the heart and improves its efficiency, reducing the risk of heart disease. Engaging in activities like brisk walking, running, cycling, or swimming can help lower blood pressure, cholesterol levels, and triglyceride levels. This, in turn, reduces the risk of heart attack and stroke. Additionally, regular exercise can also help maintain a healthy weight, which further reduces the strain on the heart.\n\nIncreased Strength and Endurance\n\nExercise helps build muscle strength and endurance, which is crucial for maintaining independence and mobility as we age. Regular strength training, such as weightlifting or bodyweight exercises, can help increase muscle mass, improve bone density, and enhance overall physical performance. This, in turn, can lead to an increased sense of well-being and confidence.\n\nBetter Mental Health\n\nExercise has been shown to have a positive impact on mental health. Physical activity releases endorphins, which are natural mood-boosting chemicals in the brain. Regular exercise can help reduce symptoms of depression and anxiety, improve self-esteem, and increase overall happiness. Additionally, engaging in activities like yoga or meditation can help reduce stress and promote relaxation.\n\nReduced Risk of Chronic Diseases\n\nRegular exercise can significantly reduce the risk of developing chronic diseases such as type 2 diabetes, certain types of cancer, and osteoporosis. Physical activity helps regulate blood sugar levels, which is particularly beneficial for those with diabetes. Exercise also helps maintain a healthy weight, which reduces the risk of developing these diseases. Furthermore, regular physical activity can improve bone density, reducing the risk of osteoporosis.\n\nIncorporating Exercise into Your Routine\n\nTo maximize the benefits of exercise, it's essential to incorporate a variety of activities into your routine. This can include:\n\n1. Cardiovascular exercises: Activities like running, cycling, or swimming can help improve cardiovascular health and burn calories.\n2. Strength training: Incorporating weightlifting or bodyweight exercises can help build muscle mass, increase bone density, and improve overall physical performance.\n3. Flexibility and balance exercises","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":24,"total_tokens":536,"completion_tokens":512,"prompt_tokens_details":null},"prompt_logprobs":null}
real    0m4.016s
user    0m0.000s
sys     0m0.005s

Maybe we should make the triton kernel the default MoE gptq/awq kernel? But I am not sure how to do this: gptq-marlin-moe is part of the gptq-marlin quantization method, so if I change the MoE kernel of the gptq-marlin method, users cannot use gptq-marlin-moe anymore. Is that okay?
