Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Marlin moe integration #266

Closed
wants to merge 48 commits into from
Closed

Marlin moe integration #266

wants to merge 48 commits into from

Conversation

ElizaWszola
Copy link

@ElizaWszola ElizaWszola commented May 24, 2024

Unit testing:

pytest tests/kernels/test_moe.py -k test_fused_marlin_moe

(requires to uncomment @pytest.mark.skip in test_moe.py).

End-to-end testing:
Run offline_inference.py with

// quantized moe with act order
llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ",
          revision="gptq-4bit-128g-actorder_True",
          enforce_eager=True)

and

// quantized moe without act order
llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ", enforce_eager=True)

return torch.stack(tensors, dim=0).to(dev)


def fused_marlin_moe(
Copy link
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function does not need to adhere to the exact same interface as fused_moe

This function will be called on the hotpath. It should receive INT4 weights and scales and just call the marlin moe kernel directly

qweights1 = []
scaless1 = []

for i in range(w1.shape[0]):
Copy link
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not be called on the hotpath.

Rather, the quantized weights should be an input to this function

Copy link
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comments in code

@robertgshaw2-redhat
Copy link
Collaborator

robertgshaw2-redhat commented May 24, 2024

So a couple things. vLLM is layed out in the following way

Models --> llama.py, which uses linear_layers like ColumnParallelLinear

Each layer has a LinearMethod which handles the representation of the weights and the forward pass
For instance, we have an Fp16 linear method, Marlin linear method, etc., etc.

Each LinearMethod exposes the following interface:
create_weights --> defines what the weights look like (e.g. dtype, whether there are scales, etc)
apply --> calls the kernels during the forward pass

So, we will eventually want to create a LinearMethod for MarlinMoE
create_weights --> will load the int4 weights from disk
apply --> passes the weights to your fused_moe_marlin kernel

As a result, the fused_moe_marlin kernel should recieve already quantized weights and just execute the computation. We should not be quantizing the weights inside of that function

For this PR, we should land the kernel + testing code for the kernel. we can work on adding the LinearMethod afterwards

@@ -477,3 +478,342 @@ def fused_moe(
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per my comment below, we will load the already compressed weights via create_weights.

So none of this will need to be called on the hotpath

As a result, all of this should be moved into testing utilities

vllm/_custom_ops.py Outdated Show resolved Hide resolved
csrc/moe/marlin_moe_ops.h Outdated Show resolved Hide resolved
vllm/_custom_ops.py Outdated Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to add a test to make sure this works on other GPUs as well (we do this in the cutlass unit tests, if you want to replicate that here)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean testing on different devices? (@pytest.mark.parametrize("device", CUDA_DEVICES))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing anything on cuda:1 results in memory erros (illegal access) in moe_align_block_size_kernel which I rely on, but didn't modify - should I look into it or is it ok to leave it for now?

# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is 16 a hardcoded block size?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related to Marlin format which is hardcoded

vllm/model_executor/layers/fused_moe/fused_moe.py Outdated Show resolved Hide resolved
Comment on lines 145 to 148
w1_s = self.experts[i].w1.get_parameter("scales").half()
w3_s = self.experts[i].w3.get_parameter("scales").half()
w2_qw = self.experts[i].w2.get_parameter("qweight").int()
w2_s = self.experts[i].w2.get_parameter("scales").half()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these guaranteed to be fp16?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From Pytorch documentation: self.half() is equivalent to self.to(torch.float16)
Scales are not necessarily fp16 when loaded.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how much of this file is copy-pasted from the original marlin code? Could we factor out common functions? It will make it much easier to review if we can see what the new code is

Copy link
Author

@ElizaWszola ElizaWszola Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is quite a bit of overlap, and many of changes boil down to adding one variable or an extra condition here and there. I don't really want to refactor into common functions until act_order is done, because there might be more of these tiny modifications (or is it better to do the refactor now?).

In any case, running a comparison of this file against csrc/quantization/gptq_marlin/gptq_marlin.cu helps seeing what changed.

Edit: fixed file name

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That’s fair for things that may be changed by act_reorder but any functions that are copied over unmodified should be factored out IMO

@ElizaWszola
Copy link
Author

Moved to the other repo

@ElizaWszola ElizaWszola closed this Aug 2, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants