From c048f22b46a227a99990d7ca9198075aafccc738 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 11 Nov 2024 06:43:42 +0000
Subject: [PATCH] Delete unused libentry

Signed-off-by: Jee Jee Li
---
 tests/lora/test_punica_sizes.py     |  73 ++++--------
 tests/lora/test_punica_variation.py |  73 ++++--------
 vllm/lora/ops/sgmv_expand.py        |   3 -
 vllm/lora/ops/sgmv_expand_slice.py  |   3 -
 vllm/lora/ops/sgmv_shrink.py        |   3 -
 vllm/triton_utils/__init__.py       |   3 +-
 vllm/triton_utils/libentry.py       | 167 ----------------------------
 7 files changed, 49 insertions(+), 276 deletions(-)
 delete mode 100644 vllm/triton_utils/libentry.py

diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py
index e756544d96e98..66b5f82bbb97d 100644
--- a/tests/lora/test_punica_sizes.py
+++ b/tests/lora/test_punica_sizes.py
@@ -4,8 +4,6 @@
 whether the corresponding Triton kernel can run normally when tensor
 parallelism is set to [1, 2, 4, 8, 16, 32, 64].
 """
-from unittest.mock import patch
-
 import pytest
 import torch
 
@@ -16,7 +14,6 @@
 from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
 from vllm.platforms import current_platform
-from vllm.triton_utils.libentry import LibEntry
 
 from .utils import (generate_data, generate_data_for_expand_nslices,
                     ref_torch_groupgemm)
@@ -235,9 +232,6 @@ def test_punica_bgmv(
     seed: int,
     device: str,
 ):
-    from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
-    from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
-
     torch.set_default_device(device)
     current_platform.seed_everything(seed)
 
@@ -262,33 +256,21 @@ def test_punica_bgmv(
         device,
     )
     if op_type == "shrink":
-        # The current _bgmv_shrink_kernel does not require the libentry
-        # decoration. The purpose of adding this patch is to test the
-        # correctness of libentry.
-        with patch(
-                "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
-                LibEntry(_bgmv_shrink_kernel),
-        ):
-            bgmv_shrink(
-                inputs_tensor,
-                lora_weights,
-                our_out_tensor,
-                indices,
-                scaling,
-            )
+        bgmv_shrink(
+            inputs_tensor,
+            lora_weights,
+            our_out_tensor,
+            indices,
+            scaling,
+        )
     else:
-        # ditto
-        with patch(
-                "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
-                LibEntry(_bgmv_expand_kernel),
-        ):
-            bgmv_expand(
-                inputs_tensor,
-                lora_weights,
-                our_out_tensor,
-                indices,
-                add_inputs=True,
-            )
+        bgmv_expand(
+            inputs_tensor,
+            lora_weights,
+            our_out_tensor,
+            indices,
+            add_inputs=True,
+        )
     ref_torch_groupgemm(
         ref_out_tensor,
         inputs_tensor,
@@ -324,7 +306,6 @@ def test_punica_expand_nslices(
     seed: int,
     device: str,
 ):
-    from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
     torch.set_default_device(device)
     current_platform.seed_everything(seed)
 
@@ -374,22 +355,16 @@ def test_punica_expand_nslices(
                 add_inputs=True,
             )
         else:
-            # The current _bgmv_expand_slice_kernel does not require the
-            # libentry decoration. The purpose of adding this patch is to test
-            # the correctness of libentry.
-            with patch(
-                    "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
-                    LibEntry(_bgmv_expand_slice_kernel),
-            ):
-                bgmv_expand_slice(
-                    inputs_tensor,
-                    lora_weights,
-                    our_outputs,
-                    indices,
-                    slice_offset,
-                    slice_size=hidden_size,
-                    add_inputs=True,
-                )
+
+            bgmv_expand_slice(
+                inputs_tensor,
+                lora_weights,
+                our_outputs,
+                indices,
+                slice_offset,
+                slice_size=hidden_size,
+                add_inputs=True,
+            )
         ref_torch_groupgemm(
             ref_outputs[:, slice_offset:slice_offset + hidden_size],
             inputs_tensor,
diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py
index dc0edeb10ef46..52b82f25d23e1 100644
--- a/tests/lora/test_punica_variation.py
+++ b/tests/lora/test_punica_variation.py
@@ -3,8 +3,6 @@
 under different conditions, including various batches, numbers of LoRA
 , and maximum ranks.
 """
-from unittest.mock import patch
-
 import pytest
 import torch
 
@@ -15,7 +13,6 @@
 from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
 from vllm.platforms import current_platform
-from vllm.triton_utils.libentry import LibEntry
 
 from .utils import (generate_data, generate_data_for_expand_nslices,
                     ref_torch_groupgemm)
@@ -150,8 +147,6 @@ def test_punica_bgmv(
     seed: int,
     device: str,
 ):
-    from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
-    from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
     torch.set_default_device(device)
     current_platform.seed_everything(seed)
 
@@ -177,33 +172,22 @@ def test_punica_bgmv(
         device,
     )
     if op_type == "shrink":
-        # The current _bgmv_shrink_kernel does not require the libentry
-        # decoration. The purpose of adding this patch is to test the
-        # correctness of libentry.
-        with patch(
-                "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
-                LibEntry(_bgmv_shrink_kernel),
-        ):
-            bgmv_shrink(
-                inputs_tensor,
-                lora_weights,
-                our_out_tensor,
-                indices,
-                scaling,
-            )
+        bgmv_shrink(
+            inputs_tensor,
+            lora_weights,
+            our_out_tensor,
+            indices,
+            scaling,
+        )
     else:
-        # ditto
-        with patch(
-                "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
-                LibEntry(_bgmv_expand_kernel),
-        ):
-            bgmv_expand(
-                inputs_tensor,
-                lora_weights,
-                our_out_tensor,
-                indices,
-                add_inputs=True,
-            )
+
+        bgmv_expand(
+            inputs_tensor,
+            lora_weights,
+            our_out_tensor,
+            indices,
+            add_inputs=True,
+        )
     ref_torch_groupgemm(
         ref_out_tensor,
         inputs_tensor,
@@ -239,8 +223,6 @@ def test_punica_expand_nslices(
     seed: int,
     device: str,
 ):
-    from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
-
     torch.set_default_device(device)
     current_platform.seed_everything(seed)
 
@@ -289,22 +271,15 @@ def test_punica_expand_nslices(
                 add_inputs=True,
             )
         else:
-            # The current _bgmv_expand_slice_kernel does not require the
-            # libentry decoration. The purpose of adding this patch is to test
-            # the correctness of libentry.
-            with patch(
-                    "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
-                    LibEntry(_bgmv_expand_slice_kernel),
-            ):
-                bgmv_expand_slice(
-                    inputs_tensor,
-                    lora_weights,
-                    our_outputs,
-                    indices,
-                    slice_offset,
-                    slice_size=hidden_size,
-                    add_inputs=True,
-                )
+            bgmv_expand_slice(
+                inputs_tensor,
+                lora_weights,
+                our_outputs,
+                indices,
+                slice_offset,
+                slice_size=hidden_size,
+                add_inputs=True,
+            )
         ref_torch_groupgemm(
             ref_outputs[:, slice_offset:slice_offset + hidden_size],
             inputs_tensor,
diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py
index adb3ab5b46b87..4910cb4061298 100644
--- a/vllm/lora/ops/sgmv_expand.py
+++ b/vllm/lora/ops/sgmv_expand.py
@@ -9,10 +9,7 @@
 import triton
 import triton.language as tl
 
-from vllm.triton_utils import libentry
-
 
-@libentry()
 @triton.jit
 def _sgmv_expand_kernel(
     input_ptr,
diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py
index efa234520ab87..844f5cec39e93 100644
--- a/vllm/lora/ops/sgmv_expand_slice.py
+++ b/vllm/lora/ops/sgmv_expand_slice.py
@@ -9,10 +9,7 @@
 import triton
 import triton.language as tl
 
-from vllm.triton_utils import libentry
-
 
-@libentry()
 @triton.jit
 def _sgmv_expand_slice_kernel(
     input_ptr,
diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py
index c003f3dc0ce9e..b4d893047b06b 100644
--- a/vllm/lora/ops/sgmv_shrink.py
+++ b/vllm/lora/ops/sgmv_shrink.py
@@ -9,10 +9,7 @@
 import triton
 import triton.language as tl
 
-from vllm.triton_utils import libentry
-
 
-@libentry()
 @triton.jit
 def _sgmv_shrink_kernel(
     input_ptr,
diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py
index 3f57c22e1f2e4..568185383aa5c 100644
--- a/vllm/triton_utils/__init__.py
+++ b/vllm/triton_utils/__init__.py
@@ -6,6 +6,5 @@
 
     from vllm.triton_utils.custom_cache_manager import (
         maybe_set_triton_cache_manager)
-    from vllm.triton_utils.libentry import libentry
 
-    __all__ += ["maybe_set_triton_cache_manager", "libentry"]
+    __all__ += ["maybe_set_triton_cache_manager"]
diff --git a/vllm/triton_utils/libentry.py b/vllm/triton_utils/libentry.py
deleted file mode 100644
index 4335c7adfc13b..0000000000000
--- a/vllm/triton_utils/libentry.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copied From https://github.com/FlagOpen/FlagGems
-
-import inspect
-
-import triton
-
-
-class LibEntry(triton.KernelInterface):
-
-    def __init__(
-        self,
-        fn,
-    ):
-        self.fn = fn
-        self.arg_names = fn.arg_names
-        self.divisibility = 16
-        self.kernel_cache = dict()
-        fn = self.fn
-        while not isinstance(fn, triton.runtime.JITFunction):
-            fn = fn.fn
-        self.jit_function: triton.runtime.JITFunction = fn
-        self.specialize_indices = [
-            p.num for p in self.jit_function.params
-            if not p.is_constexpr and not p.do_not_specialize
-        ]
-        self.do_not_specialize_indices = [
-            p.num for p in self.jit_function.params
-            if not p.is_constexpr and p.do_not_specialize
-        ]
-
-    def key(self, spec_args, dns_args, const_args):
-        spec_key = [(arg.dtype, arg.data_ptr() %
-                     self.divisibility == 0) if hasattr(arg, "data_ptr") else
-                    (type(arg), arg) for arg in spec_args]
-        dns_key = [
-            arg.dtype if hasattr(
-                arg, "data_ptr") else type(arg) if not isinstance(arg, int)
-            else "i32" if arg >= -(2**31) and arg <= 2**31 -
-            1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
-            for arg in dns_args
-        ]
-        # const args passed by position
-        return tuple(spec_key + dns_key + const_args)
-
-    def run(self, *args, **kwargs):
-        grid = kwargs["grid"]
-        # collect all the arguments
-        spec_args = []  # specialize arguments
-        dns_args = []  # do not specialize arguments
-        const_args = []  # constexpr arguments
-        k_args = []  # kernel arguments
-        for i, arg in enumerate(args):
-            if i in self.specialize_indices:
-                k_args.append(arg)
-                spec_args.append(arg)
-            elif i in self.do_not_specialize_indices:
-                k_args.append(arg)
-                dns_args.append(arg)
-            else:
-                const_args.append(arg)
-        for p in self.jit_function.params[len(args):]:
-            if p.name in kwargs:
-                val = kwargs[p.name]
-            elif p.default is inspect._empty:
-                continue
-            else:
-                val = p.default
-
-            if p.is_constexpr:
-                const_args.append(val)
-            elif p.do_not_specialize:
-                dns_args.append(val)
-                k_args.append(val)
-            else:
-                spec_args.append(val)
-                k_args.append(val)
-
-        entry_key = self.key(spec_args, dns_args, const_args)
-
-        if entry_key not in self.kernel_cache:
-            # compile the kernel also completes the related computations
-            kernel = self.fn.run(*args, **kwargs)
-            fn = self.fn
-            # collect constexpr arguments for grid computation
-            constexprs = {}
-            while not isinstance(fn, triton.runtime.JITFunction):
-                if isinstance(fn, triton.runtime.Autotuner):
-                    config = fn.best_config
-                    constexprs["num_warps"] = config.num_warps
-                    constexprs["num_stages"] = config.num_stages
-                    constexprs["num_ctas"] = config.num_ctas
-                    constexprs = {**constexprs, **config.kwargs}
-                elif isinstance(fn, triton.runtime.Heuristics):
-                    for v, heur in fn.values.items():
-                        constexprs[v] = heur({
-                            **dict(zip(fn.arg_names, args)),
-                            **kwargs,
-                            **constexprs,
-                        })
-                else:
-                    raise RuntimeError("Invalid Runtime Function")
-                fn = fn.fn
-            # In vLLM, certain kernels like fused_moe_kernel get the
-            # best_config(as kwargs) from a configuration json file, rather
-            # than using Autotuner & Heuristics. Therefore, all their constexprs
-            # (tl.constexpr) are assigned values through the following loop.
-            for p in self.jit_function.params:
-                if p.is_constexpr and p.name not in constexprs:
-                    constexprs[p.name] = p.default  #default=inspect._empty
-            self.kernel_cache[entry_key] = (kernel, constexprs)
-        else:
-            # load kernel from cache directly
-            kernel, constexprs = self.kernel_cache[entry_key]
-
-        if callable(grid):
-            # collect all arguments to the grid fn,ie:
-            # 1. args,
-            # 2. kwargs,
-            # 3. all all other captured arguments in CompiledKernel from
-            # Autotunner & Heuristics when kwargs & captured args conflict,
-            # captured args have higher priority
-            # 4. We must filter out captured args with default value firstly
-            constexprs = {
-                k: v
-                for k, v in constexprs.items() if v is not inspect._empty
-            }
-            meta = {
-                **dict(zip(self.arg_names, args)),
-                **kwargs,
-                **constexprs,
-            }
-            grid = grid(meta)
-        if isinstance(grid, tuple):
-            grid = grid + (1, 1)
-        elif isinstance(grid, list):
-            grid = grid + [1, 1]
-        kernel[grid[0:3]](*k_args)
-        # maintaining the same return type as the JITFunction.run
-        return kernel
-
-
-def libentry():
-    """
-    Decorator for triton library entries.
-    Motivation:
-        The runtime overhead of Triton kernels is the reason for the lower
-        performance of small kernels, particularly evident with smaller models.
-        Using this decorator can reduce Triton runtime overhead.
-    How:
-        The `run` function of JITFunction needs to accomplish:
-            - Parameter binding using inspect
-            - KernelArg type wrapping
-            - Cache key calculation
-        When dealing with small size, these steps can become bottlenecks in
-        Triton runtime. Libentry simplifies these steps to reduce runtime
-        overhead, thereby improving the runtime expenses of small kernels.
-    NOTE:
-        When Triton is upgraded to version 3.0.0, libentry can be removed,
-        see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
-
-
-    """
-
-    def decorator(fn):
-        return LibEntry(fn)
-
-    return decorator
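
Note for reviewers (illustrative, not part of the patch): with LibEntry gone, the sgmv/bgmv
kernels above are plain @triton.jit functions, and every launch goes straight through
Triton's own JITFunction.run rather than LibEntry's per-signature kernel cache; per the NOTE
in the deleted module, the wrapper became removable once Triton reached 3.0.0. The snippet
below is a minimal, self-contained sketch of that direct launch pattern. The toy kernel and
helper names are made up for illustration and do not appear in the vLLM code base.

    import torch
    import triton
    import triton.language as tl


    @triton.jit  # previously the vLLM kernels stacked @libentry() on top of @triton.jit
    def _double_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
        # Each program instance handles one BLOCK-sized chunk of the input.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x * 2, mask=mask)


    def double(x: torch.Tensor) -> torch.Tensor:
        # The kernel is launched directly; no LibEntry wrapper or extra
        # caching layer sits between the caller and JITFunction.run.
        y = torch.empty_like(x)
        n = x.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]), )
        _double_kernel[grid](x, y, n, BLOCK=1024)
        return y

The same shape applies to bgmv_shrink/bgmv_expand in the tests above: they now call the
undecorated kernels directly, which is why the unittest.mock.patch scaffolding that wrapped
them in LibEntry could be dropped as well.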