This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit: lots of debugging
ElizaWszola committed Jun 24, 2024
1 parent c62bc7f commit e65c195
Showing 9 changed files with 569 additions and 81 deletions.
26 changes: 24 additions & 2 deletions csrc/moe/marlin_moe_ops.cu
@@ -643,13 +643,21 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
scale(frag_b0, frag_s[k % 2][j], 0);
}

// if (threadIdx.x == 0 && blockIdx.x == 0 && expert_idx == 0 && j == 0) {
// printf("dequant: %f %f %f %f\n", __high2float(frag_b0[0]), __low2float(frag_b0[0]), __high2float(frag_b0[1]), __low2float(frag_b0[1]));
// }

FragB frag_b1 = dequant(b_quant_shift);

// Apply scale to frag_b1
if constexpr (group_blocks != -1) {
scale(frag_b1, frag_s[k % 2][j], 1);
}

// if (threadIdx.x == 0 && blockIdx.x == 0 && expert_idx == 0 && j == 0) {
// printf("dequant: %f %f %f %f\n", __high2float(frag_b1[0]), __low2float(frag_b1[0]), __high2float(frag_b1[1]), __low2float(frag_b1[1]));
// }

#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
@@ -1200,10 +1208,15 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
int tot_m = prob_m;

long* expert_offsets_ptr = (long*)expert_offsets;
// for (int expert_idx = 0; expert_idx < num_experts + 1; ++expert_idx) {
// printf("%ld ", expert_offsets_ptr[expert_idx]);
// }
// printf("\n");

// printf("run loop for %d %d %d and topk: %d\n", prob_m, prob_n, prob_k, topk);

for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
// TODO bring 1 back to num_experts
for (int expert_idx = 0; expert_idx < 1; ++expert_idx) {
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx;
int4* C_ptr = (int4*)C;
@@ -1279,7 +1292,7 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
int dev = a.get_device();

auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({size_m, topk, size_n}, options);
torch::Tensor c = torch::zeros({size_m, topk, size_n}, options);

// thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
int thread_k = -1;
@@ -1309,6 +1322,15 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc

long* eoff_f = (long*)(expert_offsets.data_ptr());

// printf("a %d\n", a.is_cuda());
// printf("b_q_weights %d\n", b_q_weights.is_cuda());
// printf("c %d\n", c.is_cuda());
// printf("sorted_ids %d\n", sorted_ids.is_cuda());
// printf("topk_weights %d\n", topk_weights.is_cuda());
// printf("b_scales %d\n", b_scales.is_cuda());
// printf("expert_offsets %d\n", expert_offsets.is_cuda());
// printf("workspace %d\n", workspace.is_cuda());

// printf("offf: %ld %ld %ld\n", eoff_f[0], eoff_f[1], eoff_f[2]);

marlin_moe::marlin_mm_moe_f16i4(a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(),
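The kernel-side edits above (zero-initializing the output tensor, temporarily clamping the expert loop to a single expert, and the thread-gated dequant printouts) all serve to localize a numerical mismatch. On the Python side, a small helper along the following lines (a sketch of mine, not part of this commit) can pinpoint the worst-offending element when comparing a kernel output against a reference:

import torch

def report_worst_mismatch(out: torch.Tensor, ref: torch.Tensor) -> None:
    # Locate the element with the largest absolute error between a kernel
    # output and its reference implementation.
    diff = (out.float() - ref.float()).abs()
    flat = int(torch.argmax(diff))
    print("worst abs diff:", diff.reshape(-1)[flat].item(),
          "at flat index", flat,
          "kernel:", out.reshape(-1)[flat].item(),
          "reference:", ref.reshape(-1)[flat].item())

Called as report_worst_mismatch(marlin_output, triton_output), this narrows the search to a single element before reaching for kernel-side printf.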
97 changes: 51 additions & 46 deletions tests/kernels/test_moe.py
@@ -15,6 +15,7 @@
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe, fused_marlin_moe, single_marlin_moe
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.model_executor.models.mixtral_quant import MixtralMoE as MixtralMoEQuant

if should_skip_test_group(group_name="TEST_KERNELS"):
pytest.skip("TEST_KERNELS=DISABLE, skipping kernels test group",
@@ -29,38 +30,12 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# for idx in range(4):
# topk_ids[idx * 2] = 0
# topk_ids[idx * 2 + 1] = 1
print("topk:", topk_ids)
for i in range(w1.shape[0]):
mask = topk_ids == i
print("mask:", mask)
if mask.sum():
# print("silu and mul:", (a[mask] @ w1[i].transpose(0, 1)).size(),
# "->", SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)).size(),
# "->", (SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)).size(),)
simul = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1))
# print("out shape:", out[mask].size(), "simul shape:", simul.size())
print("simul torch:", simul)
print("w2 t:", w2[i].transpose(0, 1))
out[mask] = simul @ w2[i].transpose(0, 1)
# for i in range(1):
# mask = topk_ids == i
# # print("mask:", mask, mask.size())
# # print("silu and mul:", (a[mask] @ w1[i].transpose(0, 1)).size(),
# # "->", SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)).size(),
# # "->", (SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)).size(),)
# inter1 = a[mask] @ w1[i].transpose(0, 1)
# print("intermediate 1 torch:", inter1, inter1.size())
# simul = SiluAndMul()(inter1)
# # print("out shape:", out[mask].size(), "simul shape:", simul.size())
# print("simul torch:", simul, simul.size())
# print("w2 t:", w2[i].transpose(0, 1))
# out[mask] = simul @ w2[i].transpose(0, 1)
# return out.view(B, -1, w2.shape[1]).sum(dim=1)
# return out#.view(B, -1, w2.shape[1])
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
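
The routing step above flattens the per-token top-k expert ids so that each (token, slot) pair contributes one entry to the expert masks. A tiny worked example with made-up scores (not taken from the test):

import torch

score = torch.tensor([[0.1, 0.7, 0.2],
                      [0.6, 0.1, 0.3]])       # 2 tokens, 3 experts
topk_weight, topk_ids = torch.topk(score, 2)  # top-2 experts per token
print(topk_ids)                               # tensor([[1, 2], [0, 2]])
print(topk_ids.view(-1))                      # tensor([1, 2, 0, 2])
print(topk_ids.view(-1) == 0)                 # mask for expert 0: tensor([False, False,  True, False])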

@@ -108,10 +83,10 @@ def test_fused_moe(


# UPSTREAM SYNC: breaks NM automation.
@pytest.mark.skip("C compiler not installed in NM automation. "
"This codepath follows a triton pathway, which "
"JITs using clang or gcc. Since neither are installed "
"in our test instances, we need to skip this for now.")
# @pytest.mark.skip("C compiler not installed in NM automation. "
# "This codepath follows a triton pathway, which "
# "JITs using clang or gcc. Since neither are installed "
# "in our test instances, we need to skip this for now.")
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
@@ -415,21 +390,10 @@ def test_fused_marlin_moe(
scales2 = stack_and_dev(scaless2)

score = torch.randn((m, e), device='cuda', dtype=dtype)
# score = torch.ones((m, e), device='cuda', dtype=dtype)
triton_output = fused_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2), score, topk, renormalize=False)
# triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
marlin_output = fused_marlin_moe(m, n, k, e, a, qweight1, qweight2, score, topk,
marlin_output = fused_marlin_moe(a, qweight1, qweight2, score, topk,
renormalize=False, w1_scale=scales1, w2_scale=scales2)

# print("marlin out:", marlin_output)
# print("triton out:", triton_output)
# print(marlin_output.size())
# print(triton_output.size())

# print(compute_max_diff(marlin_output, triton_output))
# assert(True)

# assert(compute_max_diff(marlin_output, triton_output) < 100)
assert(compute_max_diff(marlin_output, triton_output) < 4e-2)
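
compute_max_diff is defined elsewhere in this test module and does not appear in the diff; a relative-error helper consistent with the 4e-2 and 1e-2 tolerances used here could look like this (an assumption about its definition, not the file's actual code):

import torch

def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor:
    # Mean absolute deviation of the kernel output, normalized by the
    # mean magnitude of the reference output.
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))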

# UPSTREAM SYNC: breaks NM automation.
@@ -473,9 +437,6 @@ def test_single_marlin_moe(
# w[ii][jj][kk*gran:(kk+1)*gran] = sav_w * inc
# inc += 0.01

# print ("w size:", w.size())
# print("w:", w)

w_refs = []
qweights = []
scaless = []
Expand Down Expand Up @@ -504,7 +465,7 @@ def test_single_marlin_moe(
# print(w.size(), "->", qweight.size())

score = torch.randn((m, e), device='cuda', dtype=dtype)
marlin_output = single_marlin_moe(m, n, k, e, a, qweight, scales, score, topk, renormalize=False)
marlin_output = single_marlin_moe(a, qweight, scales, score, topk, renormalize=False)
torch_output = torch_moe_small(a, w_ref.transpose(1, 2), score, topk)

# print(marlin_output.size(), torch_output.size())
@@ -527,3 +488,47 @@ def test_single_marlin_moe(
# print(mm[0], mm[0] // n, mm[0] // n * n, "+", mm[0] % n, mm[1].item(), tt[1].item())

assert(compute_max_diff(marlin_output, torch_output) < 1e-2)

from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.distributed.parallel_state import initialize_model_parallel, init_distributed_environment
import os

@torch.inference_mode()
def test_forward():
m = 8
n = 256
k = 128
# e = 2
# topk = 2
group_size = -1
dtype = torch.float16

init_distributed_environment(1, 0, "tcp://192.168.198.114:60519", 0, "nccl")
initialize_model_parallel()

a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w = torch.randn((n, k), device='cuda', dtype=dtype) / 10

quant_config = GPTQMarlinConfig(4, group_size, False, True)
layer = ReplicatedLinear(k, n, bias=False, quant_config=quant_config, params_dtype=dtype)
print(layer.qweight.data)
# layer.weight = w

# making w:
# GPTQMarlinLinearMethod.create_weights(self, self.input_size,
# [self.output_size], self.input_size,
# self.output_size, self.params_dtype)

# config = MixtralConfig()
# hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
# vllm_moe = MixtralMoEQuant(
# config, quant_config
# ).cuda()

# for i in range(1):
# w1 = vllm_moe.experts[i].w1.weight.data
# w3 = vllm_moe.experts[i].w3.weight.data

# out = layer(a)
# print(out)
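
If the commented-out forward pass at the end of test_forward were enabled, a minimal way to exercise the layer is sketched below. The (output, bias) tuple returned by ReplicatedLinear and the process_weights_after_loading call (so the GPTQ-Marlin method repacks the freshly created weights before the matmul) are assumptions about the vLLM version on this branch:

# Sketch only; reuses `layer` and `a` built in test_forward above.
layer.quant_method.process_weights_after_loading(layer)  # assumed: repacks qweight for Marlin
out, _ = layer(a)  # vLLM linear layers return (output, optional bias)
print(out.shape)   # expected torch.Size([8, 256]), i.e. (m, n)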
1 change: 1 addition & 0 deletions vllm/_custom_ops.py
@@ -224,6 +224,7 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
# print("after call:", b_q_weight.shape, size_k, size_n)
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)

5 changes: 5 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -86,6 +86,11 @@ def init_distributed_environment(
local_rank: int = -1,
backend: str = "nccl",
):
print("IDE distributed")
print("world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)

logger.debug(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
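The logger.debug call above relies on the logging module's lazy %-style interpolation; plain print does not interpolate extra arguments, so explicit %-formatting is needed when mirroring the same message in a print. A standalone illustration (not part of the commit):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

world_size, rank = 1, 0
logger.debug("world_size=%d rank=%d", world_size, rank)  # formatted lazily by logging
print("world_size=%d rank=%d" % (world_size, rank))       # print needs explicit formatting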
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,11 +1,12 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, fused_marlin_moe, single_marlin_moe, get_config_file_name)
fused_experts, fused_moe, fused_topk, fused_marlin_moe, fused_marlin_moe_2, single_marlin_moe, get_config_file_name)

__all__ = [
"fused_moe",
"fused_topk",
"fused_experts",
"fused_marlin_moe",
"fused_marlin_moe_2",
"single_marlin_moe",
"get_config_file_name",
]
(The remaining 4 changed files are not shown.)

4 comments on commit e65c195

@github-actions

bigger_is_better

Benchmark suite: VLLM Engine throughput - synthetic (model NousResearch/Llama-2-7b-chat-hf, max_model_len 4096, input-len 256, output-len 128, num-prompts 1000, all available GPUs)
Environment: NVIDIA L4 x 1, vLLM 0.5.0, Python 3.11.4 (GCC 11.3.0), torch 2.3.0+cu121
Current: e65c195 / Previous: 9b2e107
request_throughput: 2.450919848219197 prompts/s
token_throughput: 941.1532217161717 tokens/s

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions

bigger_is_better

Benchmark suite: VLLM Engine throughput - synthetic (model NousResearch/Llama-2-7b-chat-hf, max_model_len 4096, input-len 256, output-len 128, num-prompts 1000, all available GPUs)
Environment: NVIDIA L4 x 1, vLLM 0.5.0, Python 3.10.12 (GCC 11.3.0), torch 2.3.0+cu121
Current: e65c195 / Previous: 9b2e107
request_throughput: 2.455020668301758 prompts/s (previous: 2.477025972502708 prompts/s, ratio 1.01)
token_throughput: 942.7279366278751 tokens/s (previous: 951.1779734410399 tokens/s, ratio 1.01)

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions

bigger_is_better

Benchmark suite: VLLM Engine throughput - synthetic (model NousResearch/Llama-2-7b-chat-hf, max_model_len 4096, input-len 256, output-len 128, num-prompts 1000, all available GPUs)
Environment: NVIDIA L4 x 1, vLLM 0.5.0, Python 3.8.17 (GCC 11.3.0), torch 2.3.0+cu121
Current: e65c195 / Previous: 9b2e107
request_throughput: 2.482192182644343 prompts/s
token_throughput: 953.1617981354278 tokens/s

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions

bigger_is_better

Benchmark suite: VLLM Engine throughput - synthetic (model NousResearch/Llama-2-7b-chat-hf, max_model_len 4096, input-len 256, output-len 128, num-prompts 1000, all available GPUs)
Environment: NVIDIA L4 x 1, vLLM 0.5.0, Python 3.9.17 (GCC 11.3.0), torch 2.3.0+cu121
Current: e65c195 / Previous: 9b2e107
request_throughput: 2.4267567056555888 prompts/s
token_throughput: 931.8745749717461 tokens/s

This comment was automatically generated by workflow using github-action-benchmark.
