From 227f8c45e333ddea1e0d0a9f7cf01d5c6563b438 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Sat, 21 Sep 2024 01:02:18 -0500 Subject: [PATCH 1/7] Implement a quantized kernel for einstein summation with exactly two inputs --- sharktank/sharktank/kernels/__init__.py | 1 + .../sharktank/kernels/einsum_2args_q4.py | 266 ++++++++++++++++++ .../kernels/templates/einsum_2args_q4.mlir | 108 +++++++ sharktank/tests/kernels/einsum_q4_test.py | 142 ++++++++++ 4 files changed, 517 insertions(+) create mode 100644 sharktank/sharktank/kernels/einsum_2args_q4.py create mode 100644 sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir create mode 100644 sharktank/tests/kernels/einsum_q4_test.py diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py index 308e20ef4..beb7e90a2 100644 --- a/sharktank/sharktank/kernels/__init__.py +++ b/sharktank/sharktank/kernels/__init__.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .attention import * +from .einsum_2args_q4 import * from .mmtfp import * from .mmt_block_scaled_offset_q4 import * from .mmt_block_scaled_q8 import * diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py new file mode 100644 index 000000000..75d76b8cd --- /dev/null +++ b/sharktank/sharktank/kernels/einsum_2args_q4.py @@ -0,0 +1,266 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .base import * + +import torch + +__all__ = [ + "einsum_2args_q4", +] + + +def einsum_util(einsum_str): + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + es_set = set(es_out) + es_set = es_set.union(es_in0) + es_set = es_set.union(es_in1) + size = len(es_set) + imap = dict() + lmap = dict() + for i in range(len(es_out)): + imap[i] = es_out[i] + lmap[es_out[i]] = i + count = len(es_out) + for c in es_set: + if c not in lmap: + imap[count] = c + lmap[c] = count + count += 1 + + assert count == len(es_set) + + in0_idx = [lmap[i] for i in es_in0] + in1_idx = [lmap[i] for i in es_in1] + out_idx = [lmap[i] for i in es_out] + + input_idx_str = ", ".join(["d" + str(i) for i in range(size)]) + in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx]) + in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx]) + out_idx_str = ", ".join(["d" + str(i) for i in out_idx]) + + iterators = ", ".join( + ['"parallel"' if i in out_idx else '"reduction"' for i in range(size)] + ) + + affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>" + affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>" + affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>" + + indexing_maps = f"""{affine_map_in0}, + {affine_map_in1}, + {affine_map_out} +""" + + out_dyn_dim_size_str = "" + for c in es_out: + if c in es_in0: + out_dyn_dim_size_str += "%a" + str(es_in0.find(c)) + "," + elif c in es_in1: + if es_in1.find(c) == len(es_in1) - 1: + out_dyn_dim_size_str += "%b_unblocked_dim," + else: + out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + "," + else: + printf("invalid einsum string") + exit(1) + out_dyn_dim_size_str = out_dyn_dim_size_str[:-1] + return ( + (in0_idx, in1_idx, out_idx), + iterators, + indexing_maps, + out_dyn_dim_size_str, + ) + + +@CustomOp.register(library=LIBRARY) +class einsum_2args_q4(CustomOp): + """Generic block scaled matmul with 
transposed RHS. + + This corresponds to the BlockScaledLayout and operates on planar `d` + and `qs` tensors as specified there: + + * `d`: `[N, K // BLOCK_SIZE, 1]` + * `qs`: `[N, K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8) + * `m`: `[N, K // BLOCK_SIZE, 1]` + + The LHS is expected to be a 3d tensor of shape [B, M, K]. The kernel + will be specialized for all values of N, K and LHS dtype. + """ + + signature = ( + "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)" + ) + + def select(self, ksel: KernelSelection): + a_desc = ksel.arg_tensor(0) # Shape [b, ] m, k + d_desc = ksel.arg_tensor(1) # Shape [N, K // BLOCK_SIZE, 1] + qs_desc = ksel.arg_tensor(2) # Shape [N, K // BLOCK_SIZE, BLOCK_SIZE // 2] + m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1] + einsum_str = ksel.attr_str(4).v + + # a arg + a_dims = a_desc.t.shape + torch._check( + a_desc.t.dtype.is_floating_point, + lambda: f"einsum_2args_q4 arg 'a': Expected floating point (got {a_desc.t.dtype})", + ) + + # qs arg + *qs_dims, qs_group0, qs_bs_div_2 = qs_desc.t.shape + block_size = qs_bs_div_2 * 2 + + # d arg + *d_dims, d_group0, d_one = d_desc.t.shape + torch._check( + d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims), + lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})", + ) + + # m arg + *m_dims, m_group0, m_one = m_desc.t.shape + torch._check( + m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims), + lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})", + ) + + # einsum_str + torch._check( + einsum_str.count(",") == 1 and einsum_str.count("->") == 1, + lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')", + ) + + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + es_set = set(es_out) + + shp = qs_desc.t.shape + print(shp) + b_dims = list(shp[:-2]) + [shp[-2] * block_size] + print(b_dims) + torch._check( + len(es_in0) == len(a_desc.t.shape) + and len(es_in1) + == len(qs_desc.t.shape) + - 1, # The quantized shape is larger until the blocks are collapsed + lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})", + ) + torch._check( + len(es_in0) == len(set(es_in0)) + and len(es_in1) == len(set(es_in1)) + and len(es_in0) != 0 + and len(es_in1) != 0, + lambda: f"einsum_2args_q4 arg 'einsum_str': Unsupported einsum str (got '{einsum_str}')", + ) + + # Check corresponding dimensions match + for i in range(len(es_in0)): + a_dim = a_dims[i] + c = es_in0[i] + pos = es_in1.find(c) + if pos >= 0: + b_dim = b_dims[pos] + torch._check( + a_dim == b_dim, + lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dim for idx {c} (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})", + ) + + # Determine the output shape by referencing corresponding input shapes + out_dims = [] + for c in es_out: + pos0 = es_in0.find(c) + pos1 = es_in1.find(c) + a_dim = a_dims[pos0] + b_dim = b_dims[pos1] + if pos0 >= 0: + out_dims.append(a_dim) + elif pos1 >= 0: + out_dims.append(b_dim) + else: + # TODO: I'm not sure if einsum notation supports broadcast in outputs, disabling it for now + torch._check( + False, + lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')", + ) + + # Specialize on BS + qs_desc.specialize_dims(-1) + d_desc.specialize_dims(-1) + m_desc.specialize_dims(-1) + + # 
Shape batch..., m, n + c_desc = ksel.return_new_tensor(out_dims, dtype=a_desc.t.dtype) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + a = kb.arg_value(0) + a_tensor_type = RankedTensorType(a.type) + d = kb.arg_value(1) + d_tensor_type = RankedTensorType(d.type) + qs = kb.arg_value(2) + qs_tensor_type = RankedTensorType(qs.type) + einsum_str = ksel.arg_descs[4].v + # einsum_str = "mek,menk->men" + + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + + es_name = "_".join([es_in0, es_in1, es_out]) + + ( + (es_0, es_1, es_2), + einsum_iterators, + einsum_indexing_maps, + oddss, + ) = einsum_util(einsum_str) + + rank1 = len(es_1) + dequant_iterators = ", ".join( + ['"parallel"' for i in range(rank1 + 1)] + ) # rank + 1 because of the group dimensions + input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)]) + broadcast_idx_str = ", ".join( + ["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)] + ) + affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>" + affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>" + dequant_indexing_maps = f"""{affine_map_broadcast}, + {affine_map_broadcast}, + {affine_map_parallel}, + {affine_map_parallel}""" + + size_str = "x".join("?" for i in range(rank1 - 2)) + + rank = a_tensor_type.rank + *n_dims, group0, bs_i8 = qs_tensor_type.shape + bs = bs_i8 * 2 # 2 nibbles per byte. + group = group0 * bs + a_type_str = str(a_tensor_type.element_type) + scale_type_str = str(d_tensor_type.element_type) + + template_file = "einsum_2args_q4.mlir" + target_function_name = f"sharktank_einsum_2args_q4_{es_name}_{bs}_{a_type_str}" + + target_function = inline_template_function( + kb, + template_file, + target_function_name, + bs=bs, + bs_i8=bs_i8, + a_type=a_type_str, + scale_type=scale_type_str, + dequant_indexing_maps=dequant_indexing_maps, + dequant_iterator_types=dequant_iterators, + einsum_indexing_maps=einsum_indexing_maps, + einsum_iterator_types=einsum_iterators, + es_name=es_name, + a_size=len(es_in0), + b_size=len(es_in1), + c_size=len(es_out), + out_dyn_dim_size_str=oddss, + ) + print(target_function) + kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir new file mode 100644 index 000000000..9c5a49fa8 --- /dev/null +++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir @@ -0,0 +1,108 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +{% set accum_type = "f32" %} + +!lowp_type = i4 +!a_type = {{a_type}} +!scale_type = {{scale_type}} +!accum_type = {{accum_type}} +!a_tensor_type = tensor<{% for i in range(a_size) %}?x{% endfor %}!a_type> +!qs_raw_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs_i8}}xi8> +!qs_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!lowp_type> +!d_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type> +!m_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type> +!accum_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!accum_type> +!c_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!a_type> +!b_grouped_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!a_type> +!b_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}!a_type> + +module { + +util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( + %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) + -> !c_tensor_type { + %zero = arith.constant 0.0: !accum_type + // todo: loop + {% for i in range(a_size) %} + %k{{i}} = arith.constant {{i}} : index + {% endfor %} + {% for i in range(a_size, b_size) %} + %k{{i}} = arith.constant {{i}} : index + {% endfor %} + {% for i in range(a_size) %} + %a{{i}} = tensor.dim %a, %k{{i}}: !a_tensor_type + {% endfor %} + {% for i in range(b_size) %} + %b{{i}} = tensor.dim %qs_raw, %k{{i}}: !qs_raw_tensor_type + {% endfor %} + %bs = arith.constant {{bs}} : index + %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index + + //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b_unblocked_dim{{"}"}} + + // Dequantize. + // todo: loop + %b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type + %b_grouped_dequant = linalg.generic { + indexing_maps = [ + {{dequant_indexing_maps}}], + iterator_types = [{{dequant_iterator_types}}] } + ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type) + outs(%b_grouped : !b_grouped_tensor_type) { + ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type): + %q_element_ext = arith.extui %q_element : !lowp_type to i32 + %q_element_fp = arith.uitofp %q_element_ext : i32 to !a_type + {% if scale_type == a_type %} + %q_element_scaled = arith.mulf %q_element_fp, %d_element : !a_type + %q_element_offset = arith.addf %q_element_scaled, %m_element : !a_type + {% else %} + %d_element_ext = arith.extf %d_element : !scale_type to !a_type + %m_element_ext = arith.extf %m_element : !scale_type to !a_type + %q_element_scaled = arith.mulf %q_element_fp, %d_element_ext : !a_type + %q_element_offset = arith.addf %q_element_scaled, %m_element_ext : !a_type + {% endif %} + linalg.yield %q_element_offset : !a_type + } -> !b_grouped_tensor_type + + // Collapse %b to the same unblocked structure. 
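+  // The reassociation below keeps every leading dimension separate and merges
+  // the trailing [K // BLOCK_SIZE, BLOCK_SIZE] pair (indices b_size-1 and
+  // b_size) into a single unblocked K dimension, e.g. [[0], [1], [2, 3]]
+  // when b_size == 3.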
+ // todo: loop + %b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type + + // Einsum + // todo: loop, right dimensions + %result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type + %result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type + %result = linalg.generic { + indexing_maps = [ + {{einsum_indexing_maps}}], + iterator_types = [{{einsum_iterator_types}}] } + ins(%a, %b_unblocked : !a_tensor_type, !b_tensor_type) + outs(%result_fill : !accum_tensor_type) { + ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type): + %bmm_mul = arith.mulf %a_element, %b_element : !a_type + {% if accum_type == a_type %} + %bmm_accum = arith.addf %bmm_mul, %out : !a_type + {% else %} + %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type + %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type + {% endif %} + linalg.yield %bmm_accum : !accum_type + } -> !accum_tensor_type + + // Cast. + // todo: loop, right dimensions + %result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type + %result_cast = linalg.copy + ins(%result : !accum_tensor_type) + outs(%result_cast_empty : !c_tensor_type) -> !c_tensor_type + + //iree_input.tensor.trace "foobar" = [%a : !a_tensor_type, %d : !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type, %b_grouped_dequant: !b_grouped_tensor_type] + util.return %result_cast : !c_tensor_type +} + +} diff --git a/sharktank/tests/kernels/einsum_q4_test.py b/sharktank/tests/kernels/einsum_q4_test.py new file mode 100644 index 000000000..be44261a7 --- /dev/null +++ b/sharktank/tests/kernels/einsum_q4_test.py @@ -0,0 +1,142 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest +from parameterized import parameterized + +import torch + +from shark_turbine import aot +from sharktank import kernels +from sharktank.types import layout_utils + + +class einsum_2args_q4_test(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_mk_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 320], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "mk,menk->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. 
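+        # Reference path: promote each packed uint8 in qs to two i4 nibbles
+        # (values 0..15), dequantize per block as b = d * qs + m, then
+        # flatten(3) merges the 10 blocks of 32 values back into K = 320
+        # before comparing against a plain torch.einsum on the dequantized b.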
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_mek_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 4, 320], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "mek,menk->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. + qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_me_men_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 4], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "me,men->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. 
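+        # Same reference path as above; here flatten(2) merges the 10 blocks
+        # of 32 dequantized values into the blocked n axis of size 320.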
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(2) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + def testExportDynamicDims(self): + class MyModule(torch.nn.Module): + def forward(self, a, d, qs, m): + return kernels.einsum_2args_q4(a, d, qs, m, "ij,jk->ik") + + mod = MyModule() + ep = torch.export.export( + mod, + args=( + torch.rand([16, 320], dtype=torch.float32), + torch.rand([320, 2, 1], dtype=torch.float16), + (torch.rand([320, 2, 16], dtype=torch.float32) * 32).to(torch.uint8), + torch.rand([320, 2, 1], dtype=torch.float16), + ), + dynamic_shapes={ + "a": {}, + "d": {}, + "qs": {}, + "m": {}, + }, + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertIn("@sharktank_einsum_2args_q4_ij_jk_ik_32_f32", asm) + + def testExportStaticDims(self): + class MyModule(torch.nn.Module): + def forward(self, a, d, qs, m): + return kernels.einsum_2args_q4(a, d, qs, m, "mek,menk->men") + + mod = MyModule() + ep = torch.export.export( + mod, + args=( + torch.rand([4, 16, 320], dtype=torch.float32), + torch.rand([4, 16, 2, 10, 1], dtype=torch.float16), + (torch.rand([4, 16, 2, 10, 16], dtype=torch.float32) * 32).to( + torch.uint8 + ), + torch.rand([4, 16, 2, 10, 1], dtype=torch.float16), + ), + ) + output = aot.export(ep) + output.print_readable() + output.verify() + asm = str(output.mlir_module) + self.assertIn("@sharktank_einsum_2args_q4_mek_menk_men_32_f32", asm) + + +if __name__ == "__main__": + unittest.main() From fa79dc91375f78142754df2c909c3c7297412f39 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Tue, 24 Sep 2024 19:21:12 -0500 Subject: [PATCH 2/7] Hook the einsum op and the get_index` op into their implementations --- .../sharktank/kernels/einsum_2args_q4.py | 4 -- .../kernels/templates/einsum_2args_q4.mlir | 3 +- sharktank/sharktank/ops/custom_impls.py | 13 +++++ sharktank/sharktank/ops/default_impls.py | 36 +++++++++++- sharktank/sharktank/ops/signatures.py | 58 +++++++++++++++++++ sharktank/sharktank/types/tensors.py | 5 ++ 6 files changed, 113 insertions(+), 6 deletions(-) diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py index 75d76b8cd..1c96a49e0 100644 --- a/sharktank/sharktank/kernels/einsum_2args_q4.py +++ b/sharktank/sharktank/kernels/einsum_2args_q4.py @@ -127,7 +127,6 @@ def select(self, ksel: KernelSelection): m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims), lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})", ) - # einsum_str torch._check( einsum_str.count(",") == 1 and einsum_str.count("->") == 1, @@ -139,9 +138,7 @@ def select(self, ksel: KernelSelection): es_set = set(es_out) shp = qs_desc.t.shape - print(shp) b_dims = list(shp[:-2]) + [shp[-2] * block_size] - print(b_dims) torch._check( len(es_in0) == len(a_desc.t.shape) and len(es_in1) @@ -262,5 +259,4 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): c_size=len(es_out), out_dyn_dim_size_str=oddss, ) - print(target_function) kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir index 9c5a49fa8..cc4fac190 100644 --- a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir +++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir @@ 
-25,6 +25,7 @@ module { util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) -> !c_tensor_type { + %debug = tensor.empty() : tensor<1xf32> %zero = arith.constant 0.0: !accum_type // todo: loop {% for i in range(a_size) %} @@ -43,7 +44,7 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type - %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b_unblocked_dim{{"}"}} + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} // Dequantize. // todo: loop diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py index fe6ae27b1..d1df15cbc 100644 --- a/sharktank/sharktank/ops/custom_impls.py +++ b/sharktank/sharktank/ops/custom_impls.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from ..kernels import ( + einsum_2args_q4, mmt_block_scaled_offset_q4_unsigned, mmt_block_scaled_q8, mmtfp, @@ -44,6 +45,18 @@ # return mmtfp(lhs, rhs) +# Einsum + + +@einsum_2args.override(Tensor, QuantizedTensor) +def einsum_2args_QuantizedTensor(input0, input1, einsum_str): + unpacked = input1.unpack() + layout = input1.layout_type + if not isinstance(unpacked, BlockScaledI4Layout): + return NotImplemented + return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str) + + # Quantized Matmul diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 050b1384b..936886acd 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -14,7 +14,13 @@ import torch.nn.functional as F from numbers import Number -from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor +from ..types import ( + PrimitiveTensor, + QuantizedTensor, + InferenceTensor, + PlanarQuantizedTensor, + BlockScaledI4Layout, +) from ..types.tensors import unbox_tensor, AnyTensor from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType from .signatures import * @@ -62,6 +68,12 @@ def conv2d_default( conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default) conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default) +# Einsum +@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor)) +def einsum_2args(x, y, einsum_str): + return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y)) + + # Elementwise @elementwise.override(Tensor) def elementwise_unary(operator, x, *args, **kwargs): @@ -145,6 +157,28 @@ def flatten_default( return torch.flatten(unbox_tensor(input), start_dim, end_dim) +@get_index.override(AllOfType(Tensor, PrimitiveTensor)) +def get_index_default(tensor, key: slice): + return unbox_tensor(tensor).__get_item__(key) + + +@get_index.override(QuantizedTensor) +def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice): + unpacked = tensor.unpack() + if isinstance(unpacked, BlockScaledI4Layout): + mul = 2 + else: + return NotImplemented + new_d = unpacked._d[key] + new_qs = unpacked._qs[key] + if unpacked.m is not None: + new_m = 
unpacked.m[key] + dims = new_qs.shape + dims = dims[:-2] + (dims[-2] * dims[-1] * mul,) + layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m) + return PlanarQuantizedTensor(shape=dims, layout=layout) + + @gemm.override(AllOfType(Tensor, InferenceTensor)) def gemm( a: AnyTensor, diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index cb7ea7bc4..ed40451ea 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -21,11 +21,13 @@ "all_reduce", "cat", "conv2d", + "einsum_2args", "elementwise", "embedding_lookup", "equal", "expand", "flatten", + "get_index", "gemm", "group_norm_affine", "layer_norm", @@ -165,6 +167,37 @@ def _conv2d_trampoline( d.fail(tensors) +@overridable +def einsum_2args( + input0: AnyTensor, + input1: AnyTensor, + einsum_str: str, + *, + accum_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Executes a given Einstein summation notation string on the provided tensors. + + Equivalent to: + ``` + y = torch.einsum(einsum_str, input0, input1) + ``` + """ + raise NotImplementedError + + +@einsum_2args.trampoline +def _einsum_trampoline( + d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str +): + tensors = (input0, input1) + for override in d.find_overrides(tensors): + result = override(input0, input1, einsum_str) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def elementwise(operator, *args, **kwargs) -> AnyTensor: """Applies an elementwise operator against arguments.""" @@ -252,6 +285,7 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor): @overridable +<<<<<<< HEAD def expand(tensor: AnyTensor, shape: List[int]) -> AnyTensor: """See torch.Tensor.expand""" ... @@ -264,6 +298,27 @@ def _expand_trampoline( tensors = (tensor,) for override in d.find_overrides(tensors): result = override(tensor, shape) +======= +def get_index( + tensor: AnyTensor, + key: slice, +) -> torch.Tensor: + """Indexes the tensor using the key. + + Equivalent to: + ``` + out = tensor[key] + ``` + """ + raise NotImplementedError + + +@get_index.trampoline +def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice): + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, key) +>>>>>>> 8de4a21 (Hook the einsum op and the get_index` op into their implementations) if result is not NotImplemented: return override, result else: @@ -271,6 +326,7 @@ def _expand_trampoline( @overridable +<<<<<<< HEAD def flatten(input: AnyTensor, start_dim: int = 0, end_dim: int = -1) -> AnyTensor: """See torch.flatten""" ... 
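# A minimal usage sketch of the new get_index hook (the variable names below
# are illustrative, not taken from this patch): subscripting an InferenceTensor
# now routes through ops.get_index, and the QuantizedTensor override slices
# the planar d/qs/m tensors and rebuilds a PlanarQuantizedTensor, so a
# gathered weight can still reach the quantized einsum kernel afterwards.
#
#     selected = quantized_weight[expert_ids, :, :]  # InferenceTensor.__getitem__
#     result = ops.einsum_2args(activations, selected, "mek,menk->men")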
@@ -290,6 +346,8 @@ def _flatten_trampoline( @overridable +======= +>>>>>>> 8de4a21 (Hook the einsum op and the get_index` op into their implementations) def gemm( a: AnyTensor, b: AnyTensor, diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 0fe1f3461..0083e6a2d 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -412,6 +412,11 @@ def __floordiv__(self, rhs): from ..ops import elementwise return elementwise(torch.floor_divide, self, rhs) + + def __getitem__(self, key): + from ..ops import get_index + + return get_index(self, key) REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} From b293169e052bb2d6787688e1d7809ec9f156444b Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 26 Sep 2024 17:43:56 -0500 Subject: [PATCH 3/7] Clean-up todos and other comments --- .../sharktank/kernels/einsum_2args_q4.py | 19 ++++++++----------- .../kernels/templates/einsum_2args_q4.mlir | 5 ----- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py index 1c96a49e0..76d8ad61c 100644 --- a/sharktank/sharktank/kernels/einsum_2args_q4.py +++ b/sharktank/sharktank/kernels/einsum_2args_q4.py @@ -66,8 +66,7 @@ def einsum_util(einsum_str): else: out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + "," else: - printf("invalid einsum string") - exit(1) + raise Exception("Invalid einsum string") out_dyn_dim_size_str = out_dyn_dim_size_str[:-1] return ( (in0_idx, in1_idx, out_idx), @@ -79,17 +78,16 @@ def einsum_util(einsum_str): @CustomOp.register(library=LIBRARY) class einsum_2args_q4(CustomOp): - """Generic block scaled matmul with transposed RHS. + """Einsum that takes two tensor inputs and returns one tensor. - This corresponds to the BlockScaledLayout and operates on planar `d` - and `qs` tensors as specified there: + The first input is expected to be a normal tensor. - * `d`: `[N, K // BLOCK_SIZE, 1]` - * `qs`: `[N, K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8) - * `m`: `[N, K // BLOCK_SIZE, 1]` + The second input corresponds to the BlockScaledLayout and operates on planar `d` + and `qs` tensors as specified there: - The LHS is expected to be a 3d tensor of shape [B, M, K]. The kernel - will be specialized for all values of N, K and LHS dtype. 
+ * `d`: `[..., K // BLOCK_SIZE, 1]` + * `qs`: `[..., K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8) + * `m`: `[..., K // BLOCK_SIZE, 1]` """ signature = ( @@ -178,7 +176,6 @@ def select(self, ksel: KernelSelection): elif pos1 >= 0: out_dims.append(b_dim) else: - # TODO: I'm not sure if einsum notation supports broadcast in outputs, disabling it for now torch._check( False, lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')", diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir index cc4fac190..47ca6b331 100644 --- a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir +++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir @@ -27,7 +27,6 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( -> !c_tensor_type { %debug = tensor.empty() : tensor<1xf32> %zero = arith.constant 0.0: !accum_type - // todo: loop {% for i in range(a_size) %} %k{{i}} = arith.constant {{i}} : index {% endfor %} @@ -47,7 +46,6 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} // Dequantize. - // todo: loop %b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type %b_grouped_dequant = linalg.generic { indexing_maps = [ @@ -71,11 +69,9 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( } -> !b_grouped_tensor_type // Collapse %b to the same unblocked structure. - // todo: loop %b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type // Einsum - // todo: loop, right dimensions %result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type %result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type %result = linalg.generic { @@ -96,7 +92,6 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( } -> !accum_tensor_type // Cast. - // todo: loop, right dimensions %result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type %result_cast = linalg.copy ins(%result : !accum_tensor_type) From 993466e3819494a160bb61121170c192b18d02d9 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 26 Sep 2024 17:48:21 -0500 Subject: [PATCH 4/7] Replace instances of torch.einsum with ops.einsum_2args for compatibility with other tensor types. 
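A rough sketch of the dispatch this change relies on (the tensor names and the
import form are illustrative, not taken from this patch): ops.einsum_2args is
an overridable signature, so the same call site now serves both plain and
quantized weights.

    from sharktank import ops

    # Plain torch.Tensor / PrimitiveTensor inputs fall through to the default
    # override, which simply calls torch.einsum(einsum_str, x, y).
    out = ops.einsum_2args(activations, dense_weight, "mek,menk->men")

    # A QuantizedTensor whose layout unpacks to BlockScaledI4Layout instead
    # hits the override in ops/custom_impls.py, which forwards the planar
    # d/qs/m tensors to the einsum_2args_q4 kernel.
    out = ops.einsum_2args(activations, quantized_weight, "mek,menk->men")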
--- sharktank/sharktank/layers/ffn_moe_block.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index 266537c89..3466e3695 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -12,6 +12,7 @@ from .base import ThetaLayer from .linear import LinearLayer from ..types import Theta, DefaultPrimitiveTensor +from ..ops import einsum_2args __all__ = [ "FFNMOE", @@ -36,24 +37,24 @@ def __init__( def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"): inputs = inputs[:, :] weights = weights[experts, :, :] - matmul = torch.einsum(einstring, inputs, weights.float()) + matmul = einsum_2args(inputs, weights, einstring) return matmul def bigger_mmg(self, inputs, weights, experts): inputs = inputs[:, :] weights = weights[experts, :, :] - matmul = torch.einsum("mek,menk->men", inputs, weights.float()) + matmul = einsum_2args(inputs, weights, "mek,menk->men") return matmul def one_hot_matmul(self, inputs, weights, experts): - matmul = torch.einsum("mk,bnk->bmn", inputs, weights) + matmul = einsum_2args(inputs, weights, "mk,bnk->bmn") # Post mix the experts oh = ( torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8) .transpose(0, 1) .to(torch.float32) ) - output = torch.einsum("bm,bmn->mn", oh, matmul) + output = einsum_2args(oh, matmul, "bm,bmn->mn") return output def forward( @@ -75,7 +76,7 @@ def forward( ffn_down = self.pre_matmul_gather( ffn_gate * ffn_up, self.ffn_down, experts, einstring="mek,menk->men" ) - ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down) + ffn_down = einsum_2args(expert_gate, ffn_down, "me,men->men") return torch.sum(ffn_down, dim=1) From 2463b2ba6bdc532716883e78970e5e3ea5f34973 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 30 Sep 2024 14:47:14 -0500 Subject: [PATCH 5/7] Fix bad marge --- sharktank/sharktank/ops/default_impls.py | 2 +- sharktank/sharktank/ops/signatures.py | 13 +++++++------ sharktank/tests/kernels/einsum_q4_test.py | 1 - 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 936886acd..a2fcd2813 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -158,7 +158,7 @@ def flatten_default( @get_index.override(AllOfType(Tensor, PrimitiveTensor)) -def get_index_default(tensor, key: slice): +def get_index_default(tensor, key): return unbox_tensor(tensor).__get_item__(key) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index ed40451ea..59f7672c7 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -285,7 +285,6 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor): @overridable -<<<<<<< HEAD def expand(tensor: AnyTensor, shape: List[int]) -> AnyTensor: """See torch.Tensor.expand""" ... 
@@ -298,7 +297,13 @@ def _expand_trampoline( tensors = (tensor,) for override in d.find_overrides(tensors): result = override(tensor, shape) -======= + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable def get_index( tensor: AnyTensor, key: slice, @@ -318,7 +323,6 @@ def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice) tensors = (tensor,) for override in d.find_overrides(tensors): result = override(tensor, key) ->>>>>>> 8de4a21 (Hook the einsum op and the get_index` op into their implementations) if result is not NotImplemented: return override, result else: @@ -326,7 +330,6 @@ def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice) @overridable -<<<<<<< HEAD def flatten(input: AnyTensor, start_dim: int = 0, end_dim: int = -1) -> AnyTensor: """See torch.flatten""" ... @@ -346,8 +349,6 @@ def _flatten_trampoline( @overridable -======= ->>>>>>> 8de4a21 (Hook the einsum op and the get_index` op into their implementations) def gemm( a: AnyTensor, b: AnyTensor, diff --git a/sharktank/tests/kernels/einsum_q4_test.py b/sharktank/tests/kernels/einsum_q4_test.py index be44261a7..d94ec5851 100644 --- a/sharktank/tests/kernels/einsum_q4_test.py +++ b/sharktank/tests/kernels/einsum_q4_test.py @@ -132,7 +132,6 @@ def forward(self, a, d, qs, m): ), ) output = aot.export(ep) - output.print_readable() output.verify() asm = str(output.mlir_module) self.assertIn("@sharktank_einsum_2args_q4_mek_menk_men_32_f32", asm) From fa22631cad18cc0d5655872e4182a0966ceb1344 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 30 Sep 2024 14:50:38 -0500 Subject: [PATCH 6/7] Fix trailing whitespace --- sharktank/sharktank/types/tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 0083e6a2d..f1f93901d 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -412,7 +412,7 @@ def __floordiv__(self, rhs): from ..ops import elementwise return elementwise(torch.floor_divide, self, rhs) - + def __getitem__(self, key): from ..ops import get_index From 4a3bd631ad3f97ed7169eaacf3f64ec291f236b7 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Tue, 1 Oct 2024 15:34:54 -0500 Subject: [PATCH 7/7] Fix tensors breaking with einsum_2args changes --- sharktank/sharktank/layers/ffn_moe_block.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index 3466e3695..0536302cf 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -64,13 +64,9 @@ def forward( expert_gate: torch.Tensor, ): if self.use_grok: - ffn_gate = F.gelu( - self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts) - ) + ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts)) else: - ffn_gate = F.silu( - self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts) - ) + ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts)) ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) ffn_down = self.pre_matmul_gather(