This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit cda9a0f

Runs to completion, but produces garbage

ElizaWszola committed Jul 12, 2024
1 parent e405879
Showing 5 changed files with 271 additions and 107 deletions.
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
@@ -264,7 +264,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
output = torch.empty((num_experts, size_k, size_n),
output = torch.empty((num_experts, size_k // 16, size_n * 2),
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
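The corrected output shape matches the usual GPTQ-Marlin packing arithmetic: weights are laid out in 16-row tiles and, at 4 bits, eight quantized values fit in each int32, so a logical (size_k, size_n) matrix packs into (size_k // 16, size_n * 16 // 8) = (size_k // 16, size_n * 2) int32 words. A minimal sketch of that shape calculation, assuming a tile size of 16 and 4-bit weights (neither value is stated in this hunk):

# Hypothetical helper illustrating the packed-shape arithmetic assumed above;
# it is not part of the diff. marlin_tile_size=16 and num_bits=4 are assumptions.
def marlin_repack_shape(size_k: int, size_n: int,
                        num_bits: int = 4,
                        marlin_tile_size: int = 16) -> tuple:
    pack_factor = 32 // num_bits  # quantized values packed into one int32
    return (size_k // marlin_tile_size,
            size_n * marlin_tile_size // pack_factor)

# Example: a per-expert (4096, 14336) weight matrix repacks to (256, 28672),
# i.e. (size_k // 16, size_n * 2) under these assumptions.
assert marlin_repack_shape(4096, 14336) == (4096 // 16, 14336 * 2)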
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -737,6 +737,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
E = w1.shape[0]
N = w2.shape[1] * 16

print("hidden_states shape:", hidden_states.shape)
print("w1 shape:", w1.shape)
print("w2 shape:", w2.shape)
print("gating_output shape:", gating_output.shape)
print("g_idx1 shape:", g_idx1.shape)
print("g_idx2 shape:", g_idx2.shape)
print("w1_scale shape:", w1_scale.shape)
print("w2_scale shape:", w2_scale.shape)

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

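For reference, the N = w2.shape[1] * 16 context line above is consistent with the same tile layout: the repacked w2 stores its K dimension divided by 16, so multiplying by 16 recovers the logical intermediate size. This reading assumes the 16-row Marlin tile sketched earlier rather than anything stated in this hunk.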
143 changes: 115 additions & 28 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Optional
from typing import Optional, List

import torch

@@ -69,9 +69,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
w13_qweight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size // self.quant_config.pack_factor,
2 * intermediate_size,
# hidden_size * 4,
# intermediate_size * 8,
dtype=params_dtype),
# 2 * intermediate_size,
# hidden_size // self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
@@ -80,7 +80,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
w2_qweight = torch.nn.Parameter(torch.empty(num_experts,
intermediate_size // self.quant_config.pack_factor,
hidden_size,
dtype=params_dtype),
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
@@ -156,6 +156,8 @@ def apply(self,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:

# print("1", layer.w13_scales)

# TODO translate qweights into Marlin format
if layer.marlin_state == GPTQMarlinState.REPACK:
@@ -170,14 +172,53 @@ def replace_tensor(name, new_t):
getattr(layer, name).copy_(new_t)
del new_t

def get_scale_perms(num_bits: int):
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single

def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()

return s

def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
num_experts = s.shape[0]
output = torch.empty((num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n,
group_size, num_bits)
return output

# print("2", layer.w13_scales)

# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
w13_g_idx_sort_indices = torch.argsort(layer.w13_g_idx).to(torch.int)
w2_g_idx_sort_indices = torch.argsort(layer.w2_g_idx).to(torch.int)

w13_sorted_g_idx = layer.w13_g_idx[w13_g_idx_sort_indices]
w2_sorted_g_idx = layer.w2_g_idx[w2_g_idx_sort_indices]
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e])#.to(torch.int)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e])#.to(torch.int)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]

replace_tensor("w13_g_idx", w13_sorted_g_idx)
replace_tensor("w2_g_idx", w2_sorted_g_idx)
@@ -203,11 +244,20 @@ def replace_tensor(name, new_t):
requires_grad=False,
)

print(layer.w13_qweight.shape)
print(layer.w2_qweight.shape)
print(x.shape)
# print("3", layer.w13_scales)

print("*")
print("hidden:", x.shape)
print("w13 before:", layer.w13_qweight.shape)
print("w2 before:", layer.w2_qweight.shape)
print("w13 args:", layer.w13_qweight.shape[1]
* self.quant_config.pack_factor,
layer.w13_qweight.shape[2])
print("w2 args:", layer.w2_qweight.shape[1]
* self.quant_config.pack_factor,
layer.w2_qweight.shape[2])

print("weight type:", layer.w13_qweight.dtype)
# print("weight type:", layer.w13_qweight.dtype)

# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
@@ -218,19 +268,49 @@ def replace_tensor(name, new_t):
self.quant_config.weight_bits,
)
replace_tensor("w13_qweight", marlin_w13_qweight)
# marlin_w2_qweight = ops.gptq_marlin_moe_repack(
# layer.w2_qweight,
# layer.w2_g_idx_sort_indices,
# layer.w2_qweight.shape[1] * 8,
# layer.w2_qweight.shape[2] // 2,
# self.quant_config.weight_bits,
# )
# replace_tensor("w2_qweight", marlin_w2_qweight)
# TODO scales

print(layer.w13_qweight.shape)
print(layer.w2_qweight.shape)
print(x.shape)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.weight_bits,
)
replace_tensor("w2_qweight", marlin_w2_qweight)

print("w13 after:", marlin_w13_qweight.shape)
print("w2 after:", marlin_w2_qweight.shape)

print("w13 scales before:", layer.w13_scales.shape)
print("w2 scales before:", layer.w2_scales.shape)
print("w13 args:", x.shape[1], layer.w13_scales.shape[2])
print("w2 args:", layer.w2_scales.shape[1] * self.quant_config.pack_factor,
x.shape[1])

# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
layer.w13_scales,
x.shape[1],
layer.w13_scales.shape[2],
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("w13_scales", marlin_w13_scales)

marlin_w2_scales = marlin_moe_permute_scales(
layer.w2_scales,
layer.w2_scales.shape[1] * self.quant_config.pack_factor,
x.shape[1],
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("w2_scales", marlin_w2_scales)

print("w13 scales after:", marlin_w13_scales.shape)
print("w2 scales after:", marlin_w2_scales.shape)

print(x.shape)
print(layer.w13_qweight.shape)
print(layer.w2_qweight.shape)

return fused_marlin_moe(x,
layer.w13_qweight,
@@ -364,17 +444,24 @@ def weight_loader(self, param: torch.nn.Parameter,

if is_quantized:
if "_qweight" in weight_name or "_scales" in weight_name:
# if "_scales" in weight_name:
# print("scales:", loaded_weight)
if "w13" in weight_name:
shard_size = self.intermediate_size_per_partition
# print("shard size:", shard_size)
if shard_id == 0:
param_data[expert_id, :, :shard_size] = loaded_weight
# if "_scales" in weight_name:
# print("param:", param_data[expert_id, :, :shard_size])
elif shard_id == 1:
param_data[expert_id, :, shard_size:] = loaded_weight
# if "_scales" in weight_name:
# print("param:", param_data[expert_id, :, shard_size:])
else:
ValueError("wrong shard:", shard_id)
elif "w2" in weight_name:
param_data[expert_id][:] = loaded_weight
# if "_scales" in weight_name:
# print("param:", param_data[expert_id][:])
else:
ValueError("what is this?", weight_name)
elif "_g_idx" in weight_name:
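The scale handling above applies the single-matrix Marlin scale permutation expert by expert through the new marlin_moe_permute_scales helper. A minimal usage sketch on dummy tensors, assuming the helper (and the functions it calls) were lifted out of apply() into importable scope; the shapes, group size, and bit width below are illustrative assumptions:

import torch

# Illustrative sizes only; a real MoE layer would supply these.
num_experts, size_k, size_n, group_size = 8, 4096, 14336, 128

# Per-expert group scales in an (experts, size_k // group_size, size_n) layout.
w13_scales = torch.rand(num_experts, size_k // group_size, size_n)

permuted = marlin_moe_permute_scales(w13_scales, size_k, size_n,
                                     group_size, num_bits=4)

# The permutation only reorders elements inside each expert's scale matrix;
# the overall shape is unchanged.
assert permuted.shape == w13_scales.shape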
20 changes: 20 additions & 0 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -837,6 +837,13 @@ def replace_tensor(name, new_t):

layer_qweight13 = torch.cat((layer.qweight1, layer.qweight3), 1)

print("*")
print("hidden:", x.shape)
print("w13 before:", layer_qweight13.shape)
print("w2 before:", layer.qweight2.shape)
print("w13 args:", part_size_k, layer_qweight13.shape[1])
print("w2 args:", part_size_n, part_size_k)

# Repack weights
# marlin_qweight1 = ops.gptq_marlin_repack(
# layer.qweight1,
@@ -873,6 +880,9 @@ def replace_tensor(name, new_t):
)
replace_tensor("qweight13", marlin_qweight13)

print("w13 after:", marlin_qweight13.shape)
print("w2 after:", marlin_qweight2.shape)

# print("done repack", layer.get_parameter("qweight1").shape,
# layer.get_parameter("qweight2").shape,
# layer.get_parameter("qweight3").shape)
@@ -886,6 +896,11 @@ def replace_tensor(name, new_t):

layer_scales13 = torch.cat((layer.scales1, layer.scales3), 1)

print("w13 scales before:", layer_scales13.shape)
print("w2 scales before:", layer.scales2.shape)
print("w13 args:", part_size_k, layer_qweight13.shape[1])
print("w2 args:", layer.scales2.shape[0] * 8, layer.scales2.shape[1])

# marlin_scales1 = marlin_permute_scales(
# layer.scales1,
# scales_size_k,
@@ -919,6 +934,11 @@ def replace_tensor(name, new_t):
)
replace_tensor("scales13", marlin_scales13)

print("w13 scales after:", marlin_scales13.shape)
print("w2 scales after:", marlin_scales2.shape)

# raise ValueError("stop")

# else:
# print("do not repack")


2 comments on commit cda9a0f

@github-actions

smaller_is_better

Benchmark suite (Current: cda9a0f, Previous: 0981a60)
Common config: VLLM Serving - Dense, sparsity None, nr-qps-pair_ "300,1", dataset sharegpt, GPU NVIDIA H100 80GB HBM3 x 1, vllm_version 0.5.1, python_version 3.10.12, torch_version 2.3.0+cu121

mean_ttft_ms, facebook/opt-350m (max-model-len 2048): 41.969406278803945 ms
mean_tpot_ms, facebook/opt-350m (max-model-len 2048): 7.744400125667221 ms
mean_ttft_ms, meta-llama/Meta-Llama-3-8B-Instruct (max-model-len 4096): 30.434241127222776 ms
mean_tpot_ms, meta-llama/Meta-Llama-3-8B-Instruct (max-model-len 4096): 11.31761793943082 ms

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions

smaller_is_better

Benchmark suite (Current: cda9a0f, Previous: 0981a60)
Common config: VLLM Serving - Dense, sparsity None, nr-qps-pair_ "300,1", dataset sharegpt, GPU NVIDIA L4 x 1, vllm_version 0.5.1, python_version 3.10.12, torch_version 2.3.0+cu121

mean_ttft_ms, meta-llama/Meta-Llama-3-8B-Instruct (max-model-len 4096): Current 190.32799008667098 ms, Previous 183.78704185992925 ms, Ratio 1.04
mean_tpot_ms, meta-llama/Meta-Llama-3-8B-Instruct (max-model-len 4096): Current 87.12298690267654 ms, Previous 86.42179767341695 ms, Ratio 1.01
mean_ttft_ms, facebook/opt-350m (max-model-len 2048): Current 28.015013069999288 ms, Previous 26.223337443346583 ms, Ratio 1.07
mean_tpot_ms, facebook/opt-350m (max-model-len 2048): Current 7.139864341273219 ms, Previous 6.787514831128281 ms, Ratio 1.05

This comment was automatically generated by workflow using github-action-benchmark.
