Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix fully sharded LoRAs with Mixtral #11390

Merged
merged 1 commit into from
Dec 22, 2024

Conversation

n1hility
Copy link
Contributor

@n1hility n1hility commented Dec 21, 2024

Fixes a regression introduced by #9008 , which leads to an assertion error when --fully-sharded-loras is enabled with an adaptor that includes a gate target:

vllm/lora/models.py:352: in __init__
    self._create_lora_modules()
vllm/lora/models.py:506: in _create_lora_modules
    self.register_module(module_name, new_module)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <vllm.lora.models.LRUCacheLoRAModelManager object at 0x7ff5ea843e10>, module_name = 'model.layers.0.block_sparse_moe.gate'
module = ReplicatedLinear(in_features=4096, output_features=8, bias=False)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
>       assert isinstance(module, BaseLayerWithLoRA)
E       AssertionError

This occurs because Mixtral includes a ReplicatedLinear layer for the MoE gate, and #7081 marked it as @_not_fully_sharded_can_replace. Since these are per-GPU and the implementation looks safe I assume this was just an unintentional copy of the decorator.

This PR removes it so that ReplicatedLinearWithLora will replace ReplicatedLinear regardless of whether fully shared LoRAs are enabled, and updates the test scenario to cover both values.

Let me know if I am missing anything.

Thanks!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Copy link
Collaborator

@jeejeelee jeejeelee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your fixing. Currently Mixtral is not being tested on CI. I assume you have already tested it locally.

@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 21, 2024
@n1hility
Copy link
Contributor Author

Thanks for your review @jeejeelee! I ran the test LoRA Mixtral tests locally (before and after change) as well as test interaction with a booted vllm instance using Mixtral with LoRA adapters. I can look for and run additional Mixtral tests in the vllm testsuite locally if these aren’t always run.

@n1hility
Copy link
Contributor Author

Thanks for your review @jeejeelee! I ran the test LoRA Mixtral tests locally (before and after change) as well as test interaction with a booted vllm instance using Mixtral with LoRA adapters. I can look for and run additional Mixtral tests in the vllm testsuite locally if these aren’t always run.

@jeejeelee These all passed too ( I ran the other Mixtral tests outside of tests/lora (in addition to the ones in tests/lora))

It looks like the reported CI failures in entrypoints-test are also failing on main so those are unrelated.

    - Changes ReplicatedLinearWithLoRA to always apply regardless of
      the fully sharded LoRA setting, since in both cases the layer
      needs to be replicated
    - Updates the existing mixtral all modeuls test to test both values
      of fully_sharded_loras (which includes a ReplicatedLayer [gate])

Signed-off-by: Jason Greene <[email protected]>
@n1hility n1hility force-pushed the fix-sharded-loras-mixtral branch from d08cbf3 to b13f3b9 Compare December 22, 2024 06:57
@jeejeelee jeejeelee merged commit f1d1bf6 into vllm-project:main Dec 22, 2024
55 checks passed
BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024
joennlae pushed a commit to 44ai-labs/vllm that referenced this pull request Jan 19, 2025
joennlae pushed a commit to 44ai-labs/vllm that referenced this pull request Jan 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants