From 3fdd4fc78d7cdb6f495298ea51d7de8ef593a9b8 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Tue, 24 Sep 2024 19:21:12 -0500 Subject: [PATCH] Hook the einsum kernel and the __getitem__ operation into their implementations --- .../sharktank/kernels/einsum_2args_q4.py | 79 +++++++------------ .../kernels/templates/einsum_2args_q4.mlir | 5 +- sharktank/sharktank/ops/custom_impls.py | 13 +++ sharktank/sharktank/ops/default_impls.py | 34 +++++++- sharktank/sharktank/ops/signatures.py | 57 +++++++++++++ sharktank/sharktank/types/tensors.py | 8 +- 6 files changed, 140 insertions(+), 56 deletions(-) diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py index 75d76b8cd..0f8e63e98 100644 --- a/sharktank/sharktank/kernels/einsum_2args_q4.py +++ b/sharktank/sharktank/kernels/einsum_2args_q4.py @@ -27,30 +27,28 @@ def einsum_util(einsum_str): lmap[es_out[i]] = i count = len(es_out) for c in es_set: - if c not in lmap: + if c not in lmap: imap[count] = c lmap[c] = count count += 1 - assert count == len(es_set) - + assert(count == len(es_set)) + in0_idx = [lmap[i] for i in es_in0] in1_idx = [lmap[i] for i in es_in1] out_idx = [lmap[i] for i in es_out] - - input_idx_str = ", ".join(["d" + str(i) for i in range(size)]) - in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx]) - in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx]) - out_idx_str = ", ".join(["d" + str(i) for i in out_idx]) - - iterators = ", ".join( - ['"parallel"' if i in out_idx else '"reduction"' for i in range(size)] - ) - + + input_idx_str = ", ".join(["d"+str(i) for i in range(size)]) + in0_idx_str = ", ".join(["d"+str(i) for i in in0_idx]) + in1_idx_str = ", ".join(["d"+str(i) for i in in1_idx]) + out_idx_str = ", ".join(["d"+str(i) for i in out_idx]) + + iterators = ", ".join(["\"parallel\"" if i in out_idx else "\"reduction\"" for i in range(size)]) + affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>" affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>" affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>" - + indexing_maps = f"""{affine_map_in0}, {affine_map_in1}, {affine_map_out} @@ -69,13 +67,7 @@ def einsum_util(einsum_str): printf("invalid einsum string") exit(1) out_dyn_dim_size_str = out_dyn_dim_size_str[:-1] - return ( - (in0_idx, in1_idx, out_idx), - iterators, - indexing_maps, - out_dyn_dim_size_str, - ) - + return (in0_idx, in1_idx, out_idx), iterators, (affine_map_in0, affine_map_in1, affine_map_out), indexing_maps, out_dyn_dim_size_str @CustomOp.register(library=LIBRARY) class einsum_2args_q4(CustomOp): @@ -92,9 +84,7 @@ class einsum_2args_q4(CustomOp): will be specialized for all values of N, K and LHS dtype. """ - signature = ( - "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)" - ) + signature = "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)" def select(self, ksel: KernelSelection): a_desc = ksel.arg_tensor(0) # Shape [b, ] m, k @@ -117,20 +107,23 @@ def select(self, ksel: KernelSelection): # d arg *d_dims, d_group0, d_one = d_desc.t.shape torch._check( - d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims), + d_group0 == qs_group0 + and d_one == 1 + and len(d_dims) == len(qs_dims), lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})", ) # m arg *m_dims, m_group0, m_one = m_desc.t.shape torch._check( - m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims), + m_desc.t.dtype == d_desc.t.dtype + and len(m_dims) == len(qs_dims), lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})", ) - # einsum_str torch._check( - einsum_str.count(",") == 1 and einsum_str.count("->") == 1, + einsum_str.count(",") == 1 + and einsum_str.count("->") == 1, lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')", ) @@ -139,14 +132,10 @@ def select(self, ksel: KernelSelection): es_set = set(es_out) shp = qs_desc.t.shape - print(shp) b_dims = list(shp[:-2]) + [shp[-2] * block_size] - print(b_dims) torch._check( len(es_in0) == len(a_desc.t.shape) - and len(es_in1) - == len(qs_desc.t.shape) - - 1, # The quantized shape is larger until the blocks are collapsed + and len(es_in1) == len(qs_desc.t.shape) - 1, # The quantized shape is larger until the blocks are collapsed lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})", ) torch._check( @@ -203,28 +192,19 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): qs = kb.arg_value(2) qs_tensor_type = RankedTensorType(qs.type) einsum_str = ksel.arg_descs[4].v - # einsum_str = "mek,menk->men" + #einsum_str = "mek,menk->men" es_in, es_out = einsum_str.split("->") es_in0, es_in1 = es_in.split(",") es_name = "_".join([es_in0, es_in1, es_out]) - - ( - (es_0, es_1, es_2), - einsum_iterators, - einsum_indexing_maps, - oddss, - ) = einsum_util(einsum_str) - + + (es_0, es_1, es_2), einsum_iterators, _, einsum_indexing_maps, oddss = einsum_util(einsum_str) + rank1 = len(es_1) - dequant_iterators = ", ".join( - ['"parallel"' for i in range(rank1 + 1)] - ) # rank + 1 because of the group dimensions - input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)]) - broadcast_idx_str = ", ".join( - ["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)] - ) + dequant_iterators = ", ".join(["\"parallel\"" for i in range(rank1 + 1)]) # rank + 1 because of the group dimensions + input_idx_str = ", ".join(["d"+str(i) for i in range(rank1 + 1)]) + broadcast_idx_str = ", ".join(["d"+str(i) if i != rank1 else "0" for i in range(rank1 + 1)]) affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>" affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>" dequant_indexing_maps = f"""{affine_map_broadcast}, @@ -262,5 +242,4 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): c_size=len(es_out), out_dyn_dim_size_str=oddss, ) - print(target_function) kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir index 9c5a49fa8..9003b48b7 100644 --- a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir +++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir @@ -25,6 +25,7 @@ module { util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) -> !c_tensor_type { + %debug = tensor.empty() : tensor<1xf32> %zero = arith.constant 0.0: !accum_type // todo: loop {% for i in range(a_size) %} @@ -41,9 +42,9 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( {% endfor %} %bs = arith.constant {{bs}} : index %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index - + //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type - %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b_unblocked_dim{{"}"}} + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} // Dequantize. // todo: loop diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py index fe6ae27b1..d1df15cbc 100644 --- a/sharktank/sharktank/ops/custom_impls.py +++ b/sharktank/sharktank/ops/custom_impls.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from ..kernels import ( + einsum_2args_q4, mmt_block_scaled_offset_q4_unsigned, mmt_block_scaled_q8, mmtfp, @@ -44,6 +45,18 @@ # return mmtfp(lhs, rhs) +# Einsum + + +@einsum_2args.override(Tensor, QuantizedTensor) +def einsum_2args_QuantizedTensor(input0, input1, einsum_str): + unpacked = input1.unpack() + layout = input1.layout_type + if not isinstance(unpacked, BlockScaledI4Layout): + return NotImplemented + return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str) + + # Quantized Matmul diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 0ab6053c2..46b1bd9e5 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from numbers import Number -from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor +from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor, PlanarQuantizedTensor, BlockScaledI4Layout from ..types.tensors import unbox_tensor, AnyTensor from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType from .signatures import * @@ -62,6 +62,11 @@ def conv2d_default( conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default) conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default) +# Einsum +@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor)) +def einsum_2args(x, y, einsum_str): + return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y)) + # Elementwise @elementwise.override(Tensor) def elementwise_unary(operator, x): @@ -133,6 +138,33 @@ def equal_default(a, b) -> bool: return torch.equal(unbox_tensor(a), unbox_tensor(b)) +@get_index.override( + AllOfType(Tensor, PrimitiveTensor) +) +def get_index_default(tensor, key: slice): + return unbox_tensor(tensor).__get_item__(key) + + +@get_index.override(QuantizedTensor) +def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice): + unpacked = tensor.unpack() + if isinstance(unpacked, BlockScaledI4Layout): + mul = 2 + else: + return NotImplemented + new_d = unpacked._d[key] + new_qs = unpacked._qs[key] + if unpacked.m is not None: + new_m = unpacked.m[key] + dims = new_qs.shape + dims = dims[:-2] + (dims[-2] * dims[-1] * mul,) + layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m) + return PlanarQuantizedTensor(shape=dims, layout=layout) + + +#get_index.override(PlanarQuantizedTensor, slice)(get_index_QuantizedTensor) + + @gemm.override(AllOfType(Tensor, InferenceTensor)) def gemm( a: AnyTensor, diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index d39aba71c..c9d2b2d1e 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -21,9 +21,11 @@ "all_reduce", "cat", "conv2d", + "einsum_2args", "elementwise", "embedding_lookup", "equal", + "get_index", "gemm", "group_norm_affine", "layer_norm", @@ -151,6 +153,35 @@ def _conv2d_trampoline( d.fail(tensors) +@overridable +def einsum_2args( + input0: AnyTensor, + input1: AnyTensor, + einsum_str: str, + *, + accum_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Executes a given Einstein summation notation string on the provided tensors. + + Equivalent to: + ``` + y = torch.einsum(einsum_str, input0, input1) + ``` + """ + raise NotImplementedError + + +@einsum_2args.trampoline +def _einsum_trampoline(d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str): + tensors = (input0, input1) + for override in d.find_overrides(tensors): + result = override(input0, input1, einsum_str) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def elementwise(operator, *args: AnyTensor) -> AnyTensor: """Applies an elementwise operator against arguments.""" @@ -232,6 +263,32 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor): d.fail(tensors) +@overridable +def get_index( + tensor: AnyTensor, + key: slice, +) -> torch.Tensor: + """Indexes the tensor using the key. + + Equivalent to: + ``` + out = tensor[key] + ``` + """ + raise NotImplementedError + + +@get_index.trampoline +def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice): + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, key) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def gemm( a: AnyTensor, diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 3545608a9..1b06eeb12 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -298,6 +298,10 @@ def __rmul__(self, lhs): # numbers on the lhs. return self.__mul__(lhs) + def __getitem__(self, key): + from ..ops import get_index + return get_index(self, key) + REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} @@ -433,9 +437,7 @@ def to_planar(self) -> "PlanarQuantizedTensor": it should override this method to implement properly or raise NotImplementedError. """ - return PlanarQuantizedTensor( - name=self.name, shape=self.shape, layout=self.unpack() - ) + return PlanarQuantizedTensor(name=self.name, shape=self.shape, layout=self.unpack()) def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata: """By default all QuantizedTensors serialize as a generic PlanarQuantizedTensor.