Implement a quantized kernel for Einstein summation with exactly two inputs

Commit 6315229, 1 parent: 89d5d52
Showing 4 changed files with 517 additions and 0 deletions.
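Before the diff, a rough sketch of how the new op might be invoked (hypothetical; the import path, eager call style, and concrete shapes are assumptions chosen to match the einsum string "mek,menk->men" used as an example in the kernel source):

    import torch
    from sharktank import kernels  # assumed export of the registered op

    bs = 32  # block size; qs packs two 4-bit values per byte
    a = torch.rand(4, 32, 128, dtype=torch.float16)                       # LHS "mek"
    d = torch.rand(4, 32, 11, 128 // bs, 1, dtype=torch.float16)          # block scales
    qs = torch.randint(0, 256, (4, 32, 11, 128 // bs, bs // 2), dtype=torch.uint8)  # packed quants
    m = torch.rand(4, 32, 11, 128 // bs, 1, dtype=torch.float16)          # block offsets
    out = kernels.einsum_2args_q4(a, d, qs, m, "mek,menk->men")           # shape [4, 32, 11]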
@@ -0,0 +1,266 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .base import *

import torch

__all__ = [
    "einsum_2args_q4",
]


def einsum_util(einsum_str):
    es_in, es_out = einsum_str.split("->")
    es_in0, es_in1 = es_in.split(",")
    es_set = set(es_out)
    es_set = es_set.union(es_in0)
    es_set = es_set.union(es_in1)
    size = len(es_set)
    imap = dict()
    lmap = dict()
    for i in range(len(es_out)):
        imap[i] = es_out[i]
        lmap[es_out[i]] = i
    count = len(es_out)
    for c in es_set:
        if c not in lmap:
            imap[count] = c
            lmap[c] = count
            count += 1

    assert count == len(es_set)

    in0_idx = [lmap[i] for i in es_in0]
    in1_idx = [lmap[i] for i in es_in1]
    out_idx = [lmap[i] for i in es_out]

    input_idx_str = ", ".join(["d" + str(i) for i in range(size)])
    in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx])
    in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx])
    out_idx_str = ", ".join(["d" + str(i) for i in out_idx])

    iterators = ", ".join(
        ['"parallel"' if i in out_idx else '"reduction"' for i in range(size)]
    )

    affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>"
    affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>"
    affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>"

    indexing_maps = f"""{affine_map_in0},
{affine_map_in1},
{affine_map_out}
"""

    out_dyn_dim_size_str = ""
    for c in es_out:
        if c in es_in0:
            out_dyn_dim_size_str += "%a" + str(es_in0.find(c)) + ","
        elif c in es_in1:
            if es_in1.find(c) == len(es_in1) - 1:
                out_dyn_dim_size_str += "%b_unblocked_dim,"
            else:
                out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + ","
        else:
printf("invalid einsum string") | ||
exit(1) | ||
    out_dyn_dim_size_str = out_dyn_dim_size_str[:-1]
    return (
        (in0_idx, in1_idx, out_idx),
        iterators,
        indexing_maps,
        out_dyn_dim_size_str,
    )
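

# Illustrative example (comment only; the values follow from einsum_util above):
# einsum_util("mek,menk->men") returns
#   index lists:   in0 = [0, 1, 3], in1 = [0, 1, 2, 3], out = [0, 1, 2]
#   iterators:     '"parallel", "parallel", "parallel", "reduction"'
#   indexing maps: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
#                  affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
#                  affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#   dynamic out dims: "%a0,%a1,%b2"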

@CustomOp.register(library=LIBRARY)
class einsum_2args_q4(CustomOp):
"""Generic block scaled matmul with transposed RHS. | ||
This corresponds to the BlockScaledLayout and operates on planar `d` | ||
and `qs` tensors as specified there: | ||
* `d`: `[N, K // BLOCK_SIZE, 1]` | ||
* `qs`: `[N, K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8) | ||
* `m`: `[N, K // BLOCK_SIZE, 1]` | ||
The LHS is expected to be a 3d tensor of shape [B, M, K]. The kernel | ||
will be specialized for all values of N, K and LHS dtype. | ||
""" | ||

    signature = (
        "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)"
    )

    def select(self, ksel: KernelSelection):
        a_desc = ksel.arg_tensor(0)  # Shape per es_in0, e.g. [b, ] m, k
        d_desc = ksel.arg_tensor(1)  # Shape [..., K // BLOCK_SIZE, 1]
        qs_desc = ksel.arg_tensor(2)  # Shape [..., K // BLOCK_SIZE, BLOCK_SIZE // 2]
        m_desc = ksel.arg_tensor(3)  # Shape [..., K // BLOCK_SIZE, 1]
        einsum_str = ksel.attr_str(4).v

        # a arg
        a_dims = a_desc.t.shape
        torch._check(
            a_desc.t.dtype.is_floating_point,
            lambda: f"einsum_2args_q4 arg 'a': Expected floating point (got {a_desc.t.dtype})",
        )

        # qs arg
        *qs_dims, qs_group0, qs_bs_div_2 = qs_desc.t.shape
        block_size = qs_bs_div_2 * 2

        # d arg
        *d_dims, d_group0, d_one = d_desc.t.shape
        torch._check(
            d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims),
            lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})",
        )

        # m arg
        *m_dims, m_group0, m_one = m_desc.t.shape
        torch._check(
            m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims),
            lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})",
        )

        # einsum_str
        torch._check(
            einsum_str.count(",") == 1 and einsum_str.count("->") == 1,
            lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')",
        )

        es_in, es_out = einsum_str.split("->")
        es_in0, es_in1 = es_in.split(",")
        es_set = set(es_out)

        shp = qs_desc.t.shape
        print(shp)
        b_dims = list(shp[:-2]) + [shp[-2] * block_size]
        print(b_dims)
        torch._check(
            len(es_in0) == len(a_desc.t.shape)
            # The quantized shape is larger until the blocks are collapsed.
            and len(es_in1) == len(qs_desc.t.shape) - 1,
            lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
        )
        torch._check(
            len(es_in0) == len(set(es_in0))
            and len(es_in1) == len(set(es_in1))
            and len(es_in0) != 0
            and len(es_in1) != 0,
            lambda: f"einsum_2args_q4 arg 'einsum_str': Unsupported einsum str (got '{einsum_str}')",
        )

        # Check corresponding dimensions match
        for i in range(len(es_in0)):
            a_dim = a_dims[i]
            c = es_in0[i]
            pos = es_in1.find(c)
            if pos >= 0:
                b_dim = b_dims[pos]
                torch._check(
                    a_dim == b_dim,
                    lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dim for idx {c} (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
                )

        # Determine the output shape by referencing corresponding input shapes
        out_dims = []
        for c in es_out:
            pos0 = es_in0.find(c)
            pos1 = es_in1.find(c)
            if pos0 >= 0:
                out_dims.append(a_dims[pos0])
            elif pos1 >= 0:
                out_dims.append(b_dims[pos1])
            else:
                # TODO: It is unclear whether einsum notation supports broadcast
                # in outputs; disabled for now.
                torch._check(
                    False,
                    lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')",
                )

        # Specialize on the block size.
        qs_desc.specialize_dims(-1)
        d_desc.specialize_dims(-1)
        m_desc.specialize_dims(-1)

        # Output shape is determined by the einsum output indices.
        c_desc = ksel.return_new_tensor(out_dims, dtype=a_desc.t.dtype)

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        a = kb.arg_value(0)
        a_tensor_type = RankedTensorType(a.type)
        d = kb.arg_value(1)
        d_tensor_type = RankedTensorType(d.type)
        qs = kb.arg_value(2)
        qs_tensor_type = RankedTensorType(qs.type)
        einsum_str = ksel.arg_descs[4].v
        # einsum_str = "mek,menk->men"

        es_in, es_out = einsum_str.split("->")
        es_in0, es_in1 = es_in.split(",")

        es_name = "_".join([es_in0, es_in1, es_out])

        (
            (es_0, es_1, es_2),
            einsum_iterators,
            einsum_indexing_maps,
            oddss,
        ) = einsum_util(einsum_str)

        rank1 = len(es_1)
        dequant_iterators = ", ".join(
            ['"parallel"' for i in range(rank1 + 1)]
        )  # rank + 1 because of the group dimensions
        input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)])
        broadcast_idx_str = ", ".join(
            ["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)]
        )
        affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>"
        affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>"
        dequant_indexing_maps = f"""{affine_map_broadcast},
{affine_map_broadcast},
{affine_map_parallel},
{affine_map_parallel}"""

        size_str = "x".join("?" for i in range(rank1 - 2))

        rank = a_tensor_type.rank
        *n_dims, group0, bs_i8 = qs_tensor_type.shape
        bs = bs_i8 * 2  # 2 nibbles per byte.
        group = group0 * bs
        a_type_str = str(a_tensor_type.element_type)
        scale_type_str = str(d_tensor_type.element_type)

        template_file = "einsum_2args_q4.mlir"
        target_function_name = f"sharktank_einsum_2args_q4_{es_name}_{bs}_{a_type_str}"

        target_function = inline_template_function(
            kb,
            template_file,
            target_function_name,
            bs=bs,
            bs_i8=bs_i8,
            a_type=a_type_str,
            scale_type=scale_type_str,
            dequant_indexing_maps=dequant_indexing_maps,
            dequant_iterator_types=dequant_iterators,
            einsum_indexing_maps=einsum_indexing_maps,
            einsum_iterator_types=einsum_iterators,
            es_name=es_name,
            a_size=len(es_in0),
            b_size=len(es_in1),
            c_size=len(es_out),
            out_dyn_dim_size_str=oddss,
        )
        print(target_function)
        kb.yield_results(*call_function(target_function, *kb.arg_bindings))
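To make the template specialization concrete: with the example einsum string "mek,menk->men" (the commented-out value above), a block size of 32, and an f16 LHS, the generate step renders

    es_name              = "mek_menk_men"
    target_function_name = "sharktank_einsum_2args_q4_mek_menk_men_32_f16"

which matches the util.func declared in the MLIR template below.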
sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir (108 additions, 0 deletions)
@@ -0,0 +1,108 @@
// Copyright 2024 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

{% set accum_type = "f32" %}

!lowp_type = i4
!a_type = {{a_type}}
!scale_type = {{scale_type}}
!accum_type = {{accum_type}}
!a_tensor_type = tensor<{% for i in range(a_size) %}?x{% endfor %}!a_type>
!qs_raw_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs_i8}}xi8>
!qs_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!lowp_type>
!d_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
!m_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
!accum_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!accum_type>
!c_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!a_type>
!b_grouped_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!a_type>
!b_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}!a_type>
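
// Illustrative rendering (comment only): for es = "mek,menk->men"
// (a_size = 3, b_size = 4, c_size = 3) with bs = 32 and f16 types, the
// aliases above resolve to, for example:
//   !a_tensor_type      = tensor<?x?x?xf16>
//   !qs_raw_tensor_type = tensor<?x?x?x?x16xi8>
//   !qs_tensor_type     = tensor<?x?x?x?x32xi4>
//   !d_tensor_type      = tensor<?x?x?x?x1xf16>
//   !b_tensor_type      = tensor<?x?x?x?xf16>
//   !c_tensor_type      = tensor<?x?x?xf16>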

module {

util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
    %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
    -> !c_tensor_type {
  %zero = arith.constant 0.0: !accum_type
  // todo: loop
{% for i in range(a_size) %}
  %k{{i}} = arith.constant {{i}} : index
{% endfor %}
{% for i in range(a_size, b_size) %}
  %k{{i}} = arith.constant {{i}} : index
{% endfor %}
{% for i in range(a_size) %}
  %a{{i}} = tensor.dim %a, %k{{i}}: !a_tensor_type
{% endfor %}
{% for i in range(b_size) %}
  %b{{i}} = tensor.dim %qs_raw, %k{{i}}: !qs_raw_tensor_type
{% endfor %}
  %bs = arith.constant {{bs}} : index
  %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index

  //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
  %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b_unblocked_dim{{"}"}}

  // Dequantize.
  // todo: loop
  %b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type
  %b_grouped_dequant = linalg.generic {
      indexing_maps = [
          {{dequant_indexing_maps}}],
      iterator_types = [{{dequant_iterator_types}}] }
      ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
      outs(%b_grouped : !b_grouped_tensor_type) {
  ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
      %q_element_ext = arith.extui %q_element : !lowp_type to i32
      %q_element_fp = arith.uitofp %q_element_ext : i32 to !a_type
{% if scale_type == a_type %}
      %q_element_scaled = arith.mulf %q_element_fp, %d_element : !a_type
      %q_element_offset = arith.addf %q_element_scaled, %m_element : !a_type
{% else %}
      %d_element_ext = arith.extf %d_element : !scale_type to !a_type
      %m_element_ext = arith.extf %m_element : !scale_type to !a_type
      %q_element_scaled = arith.mulf %q_element_fp, %d_element_ext : !a_type
      %q_element_offset = arith.addf %q_element_scaled, %m_element_ext : !a_type
{% endif %}
      linalg.yield %q_element_offset : !a_type
  } -> !b_grouped_tensor_type

  // Collapse %b to the same unblocked structure.
  // todo: loop
  %b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type
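
  // Illustrative note (comment only): with b_size = 3 the reassociation above
  // is [[0], [1], [2, 3]], merging the trailing (K / bs) and bs dims of the
  // dequantized grouped tensor into a single unblocked K dim.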

  // Einsum
  // todo: loop, right dimensions
  %result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type
  %result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type
  %result = linalg.generic {
      indexing_maps = [
          {{einsum_indexing_maps}}],
      iterator_types = [{{einsum_iterator_types}}] }
      ins(%a, %b_unblocked : !a_tensor_type, !b_tensor_type)
      outs(%result_fill : !accum_tensor_type) {
  ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type):
      %bmm_mul = arith.mulf %a_element, %b_element : !a_type
{% if accum_type == a_type %}
      %bmm_accum = arith.addf %bmm_mul, %out : !a_type
{% else %}
      %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type
      %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type
{% endif %}
      linalg.yield %bmm_accum : !accum_type
  } -> !accum_tensor_type

  // Cast.
  // todo: loop, right dimensions
  %result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type
  %result_cast = linalg.copy
      ins(%result : !accum_tensor_type)
      outs(%result_cast_empty : !c_tensor_type) -> !c_tensor_type

  //iree_input.tensor.trace "foobar" = [%a : !a_tensor_type, %d : !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type, %b_grouped_dequant: !b_grouped_tensor_type]
  util.return %result_cast : !c_tensor_type
}

}