
Commit

Hook the einsum kernel and the __getitem__ operation into their implementations
KyleHerndon committed Sep 25, 2024
1 parent 6315229 commit 3fdd4fc
Showing 6 changed files with 140 additions and 56 deletions.
79 changes: 29 additions & 50 deletions sharktank/sharktank/kernels/einsum_2args_q4.py
@@ -27,30 +27,28 @@ def einsum_util(einsum_str):
lmap[es_out[i]] = i
count = len(es_out)
for c in es_set:
if c not in lmap:
if c not in lmap:
imap[count] = c
lmap[c] = count
count += 1

assert count == len(es_set)

assert(count == len(es_set))
in0_idx = [lmap[i] for i in es_in0]
in1_idx = [lmap[i] for i in es_in1]
out_idx = [lmap[i] for i in es_out]

input_idx_str = ", ".join(["d" + str(i) for i in range(size)])
in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx])
in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx])
out_idx_str = ", ".join(["d" + str(i) for i in out_idx])

iterators = ", ".join(
['"parallel"' if i in out_idx else '"reduction"' for i in range(size)]
)


input_idx_str = ", ".join(["d"+str(i) for i in range(size)])
in0_idx_str = ", ".join(["d"+str(i) for i in in0_idx])
in1_idx_str = ", ".join(["d"+str(i) for i in in1_idx])
out_idx_str = ", ".join(["d"+str(i) for i in out_idx])

iterators = ", ".join(["\"parallel\"" if i in out_idx else "\"reduction\"" for i in range(size)])

affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>"
affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>"
affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>"

indexing_maps = f"""{affine_map_in0},
{affine_map_in1},
{affine_map_out}
@@ -69,13 +67,7 @@ def einsum_util(einsum_str):
printf("invalid einsum string")
exit(1)
out_dyn_dim_size_str = out_dyn_dim_size_str[:-1]
return (
(in0_idx, in1_idx, out_idx),
iterators,
indexing_maps,
out_dyn_dim_size_str,
)

return (in0_idx, in1_idx, out_idx), iterators, (affine_map_in0, affine_map_in1, affine_map_out), indexing_maps, out_dyn_dim_size_str

@CustomOp.register(library=LIBRARY)
class einsum_2args_q4(CustomOp):
@@ -92,9 +84,7 @@ class einsum_2args_q4(CustomOp):
will be specialized for all values of N, K and LHS dtype.
"""

signature = (
"einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)"
)
signature = "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)"

def select(self, ksel: KernelSelection):
a_desc = ksel.arg_tensor(0) # Shape [b, ] m, k
@@ -117,20 +107,23 @@ def select(self, ksel: KernelSelection):
# d arg
*d_dims, d_group0, d_one = d_desc.t.shape
torch._check(
d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims),
d_group0 == qs_group0
and d_one == 1
and len(d_dims) == len(qs_dims),
lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})",
)

# m arg
*m_dims, m_group0, m_one = m_desc.t.shape
torch._check(
m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims),
m_desc.t.dtype == d_desc.t.dtype
and len(m_dims) == len(qs_dims),
lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})",
)

# einsum_str
torch._check(
einsum_str.count(",") == 1 and einsum_str.count("->") == 1,
einsum_str.count(",") == 1
and einsum_str.count("->") == 1,
lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')",
)

@@ -139,14 +132,10 @@ def select(self, ksel: KernelSelection):
es_set = set(es_out)

shp = qs_desc.t.shape
print(shp)
b_dims = list(shp[:-2]) + [shp[-2] * block_size]
print(b_dims)
torch._check(
len(es_in0) == len(a_desc.t.shape)
and len(es_in1)
== len(qs_desc.t.shape)
- 1, # The quantized shape is larger until the blocks are collapsed
and len(es_in1) == len(qs_desc.t.shape) - 1, # The quantized shape is larger until the blocks are collapsed
lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
)
torch._check(
@@ -203,28 +192,19 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
qs = kb.arg_value(2)
qs_tensor_type = RankedTensorType(qs.type)
einsum_str = ksel.arg_descs[4].v
# einsum_str = "mek,menk->men"
#einsum_str = "mek,menk->men"

es_in, es_out = einsum_str.split("->")
es_in0, es_in1 = es_in.split(",")

es_name = "_".join([es_in0, es_in1, es_out])

(
(es_0, es_1, es_2),
einsum_iterators,
einsum_indexing_maps,
oddss,
) = einsum_util(einsum_str)


(es_0, es_1, es_2), einsum_iterators, _, einsum_indexing_maps, oddss = einsum_util(einsum_str)

rank1 = len(es_1)
dequant_iterators = ", ".join(
['"parallel"' for i in range(rank1 + 1)]
) # rank + 1 because of the group dimensions
input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)])
broadcast_idx_str = ", ".join(
["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)]
)
dequant_iterators = ", ".join(["\"parallel\"" for i in range(rank1 + 1)]) # rank + 1 because of the group dimensions
input_idx_str = ", ".join(["d"+str(i) for i in range(rank1 + 1)])
broadcast_idx_str = ", ".join(["d"+str(i) if i != rank1 else "0" for i in range(rank1 + 1)])
affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>"
affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>"
dequant_indexing_maps = f"""{affine_map_broadcast},
@@ -262,5 +242,4 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
c_size=len(es_out),
out_dyn_dim_size_str=oddss,
)
print(target_function)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
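For context, the index bookkeeping that einsum_util performs can be reproduced in a few lines of standalone Python. This is a minimal sketch, not the kernel code itself; the function name parse_einsum and the sample spec are illustrative, and it assumes the same convention as above: letters that appear in the output become "parallel" iterators, the remaining letters "reduction".

# Illustrative sketch of einsum_util's index mapping (not the library code).
def parse_einsum(einsum_str: str):
    es_in, es_out = einsum_str.split("->")
    es_in0, es_in1 = es_in.split(",")
    letters = []  # output letters are numbered first, then reduction letters
    for c in es_out + es_in0 + es_in1:
        if c not in letters:
            letters.append(c)
    index = {c: i for i, c in enumerate(letters)}
    all_dims = ", ".join(f"d{i}" for i in range(len(letters)))
    dims = lambda term: ", ".join(f"d{index[c]}" for c in term)
    iterators = ", ".join(
        '"parallel"' if c in es_out else '"reduction"' for c in letters
    )
    maps = [
        f"affine_map<({all_dims}) -> ({dims(term)})>"
        for term in (es_in0, es_in1, es_out)
    ]
    return iterators, maps

# parse_einsum("mek,menk->men") yields
#   iterators: "parallel", "parallel", "parallel", "reduction"
#   maps[0]:   affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>      (lhs "mek")
#   maps[1]:   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>  (rhs "menk")
#   maps[2]:   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>      (out "men")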
5 changes: 3 additions & 2 deletions sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
@@ -25,6 +25,7 @@ module {
util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
%a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
-> !c_tensor_type {
%debug = tensor.empty() : tensor<1xf32>
%zero = arith.constant 0.0: !accum_type
// todo: loop
{% for i in range(a_size) %}
@@ -41,9 +42,9 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
{% endfor %}
%bs = arith.constant {{bs}} : index
%b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index

//%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b_unblocked_dim{{"}"}}
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}}

// Dequantize.
// todo: loop
13 changes: 13 additions & 0 deletions sharktank/sharktank/ops/custom_impls.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F

from ..kernels import (
einsum_2args_q4,
mmt_block_scaled_offset_q4_unsigned,
mmt_block_scaled_q8,
mmtfp,
@@ -44,6 +45,18 @@
# return mmtfp(lhs, rhs)


# Einsum


@einsum_2args.override(Tensor, QuantizedTensor)
def einsum_2args_QuantizedTensor(input0, input1, einsum_str):
unpacked = input1.unpack()
layout = input1.layout_type
if not isinstance(unpacked, BlockScaledI4Layout):
return NotImplemented
return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str)


# Quantized Matmul


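The intended call pattern for the override added above can be sketched as follows. The function and variable names are hypothetical, and it is assumed that einsum_2args is re-exported from sharktank.ops the same way the other signatures are; the same call then works whether the second operand is a plain torch.Tensor or a QuantizedTensor.

from sharktank import ops


def apply_weight(activation, weight, spec: str = "mek,menk->men"):
    # `weight` may be a plain torch.Tensor or a QuantizedTensor; the override
    # above unpacks a BlockScaledI4Layout and hands it to the fused
    # einsum_2args_q4 kernel, and returns NotImplemented for other layouts so
    # dispatch can fall through to the generic torch.einsum implementation.
    return ops.einsum_2args(activation, weight, spec)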
34 changes: 33 additions & 1 deletion sharktank/sharktank/ops/default_impls.py
@@ -14,7 +14,7 @@
import torch.nn.functional as F
from numbers import Number

from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor
from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor, PlanarQuantizedTensor, BlockScaledI4Layout
from ..types.tensors import unbox_tensor, AnyTensor
from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType
from .signatures import *
@@ -62,6 +62,11 @@ def conv2d_default(
conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default)
conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)

# Einsum
@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor))
def einsum_2args(x, y, einsum_str):
return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y))

# Elementwise
@elementwise.override(Tensor)
def elementwise_unary(operator, x):
@@ -133,6 +138,33 @@ def equal_default(a, b) -> bool:
return torch.equal(unbox_tensor(a), unbox_tensor(b))


@get_index.override(
AllOfType(Tensor, PrimitiveTensor)
)
def get_index_default(tensor, key: slice):
return unbox_tensor(tensor).__get_item__(key)


@get_index.override(QuantizedTensor)
def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice):
unpacked = tensor.unpack()
if isinstance(unpacked, BlockScaledI4Layout):
mul = 2
else:
return NotImplemented
new_d = unpacked._d[key]
new_qs = unpacked._qs[key]
if unpacked.m is not None:
new_m = unpacked.m[key]
dims = new_qs.shape
dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
return PlanarQuantizedTensor(shape=dims, layout=layout)


#get_index.override(PlanarQuantizedTensor, slice)(get_index_QuantizedTensor)


@gemm.override(AllOfType(Tensor, InferenceTensor))
def gemm(
a: AnyTensor,
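A small self-contained sketch of the shape bookkeeping in get_index_QuantizedTensor above: slicing acts on the packed planes, and the logical shape is then recomputed from the sliced qs plane by collapsing the trailing i8 block dimensions into a single dimension of num_blocks * block_size i4 values (the * 2 factor, two i4 values per i8 byte). The helper name and example shape are made up for illustration.

def sliced_logical_shape(sliced_qs_shape):
    # Mirrors dims[:-2] + (dims[-2] * dims[-1] * mul,) with mul = 2 for i4
    # values packed two per i8 byte.
    *outer, num_blocks, packed_i8 = sliced_qs_shape
    return (*outer, num_blocks * packed_i8 * 2)

# e.g. slicing a qs plane of shape (8, 11, 8, 16) down to (2, 11, 8, 16)
# yields a logical tensor shape of (2, 11, 256).
print(sliced_logical_shape((2, 11, 8, 16)))  # (2, 11, 256)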
57 changes: 57 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -21,9 +21,11 @@
"all_reduce",
"cat",
"conv2d",
"einsum_2args",
"elementwise",
"embedding_lookup",
"equal",
"get_index",
"gemm",
"group_norm_affine",
"layer_norm",
@@ -151,6 +153,35 @@ def _conv2d_trampoline(
d.fail(tensors)


@overridable
def einsum_2args(
input0: AnyTensor,
input1: AnyTensor,
einsum_str: str,
*,
accum_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""Executes a given Einstein summation notation string on the provided tensors.
Equivalent to:
```
y = torch.einsum(einsum_str, input0, input1)
```
"""
raise NotImplementedError


@einsum_2args.trampoline
def _einsum_trampoline(d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str):
tensors = (input0, input1)
for override in d.find_overrides(tensors):
result = override(input0, input1, einsum_str)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable
def elementwise(operator, *args: AnyTensor) -> AnyTensor:
"""Applies an elementwise operator against arguments."""
@@ -232,6 +263,32 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor):
d.fail(tensors)


@overridable
def get_index(
tensor: AnyTensor,
key: slice,
) -> torch.Tensor:
"""Indexes the tensor using the key.
Equivalent to:
```
out = tensor[key]
```
"""
raise NotImplementedError


@get_index.trampoline
def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice):
tensors = (tensor,)
for override in d.find_overrides(tensors):
result = override(tensor, key)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable
def gemm(
a: AnyTensor,
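The new overridable signature is extended with the same decorator protocol used elsewhere in this commit. The sketch below is hypothetical (the chosen tensor types and the body are illustrative only); it shows the registration pattern and the NotImplemented fall-through that the trampoline above relies on.

import torch

from sharktank.ops.signatures import einsum_2args
from sharktank.types import PrimitiveTensor
from sharktank.types.tensors import unbox_tensor


@einsum_2args.override(PrimitiveTensor, PrimitiveTensor)
def einsum_2args_primitive(input0, input1, einsum_str):
    # Overrides are tried in the order find_overrides yields them; returning
    # NotImplemented passes the call on to the next candidate.
    if "->" not in einsum_str:
        return NotImplemented
    return torch.einsum(einsum_str, unbox_tensor(input0), unbox_tensor(input1))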
8 changes: 5 additions & 3 deletions sharktank/sharktank/types/tensors.py
@@ -298,6 +298,10 @@ def __rmul__(self, lhs):
# numbers on the lhs.
return self.__mul__(lhs)

def __getitem__(self, key):
from ..ops import get_index
return get_index(self, key)


REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {}

@@ -433,9 +437,7 @@ def to_planar(self) -> "PlanarQuantizedTensor":
it should override this method to implement properly or raise
NotImplementedError.
"""
return PlanarQuantizedTensor(
name=self.name, shape=self.shape, layout=self.unpack()
)
return PlanarQuantizedTensor(name=self.name, shape=self.shape, layout=self.unpack())

def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata:
"""By default all QuantizedTensors serialize as a generic PlanarQuantizedTensor.
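With __getitem__ wired to ops.get_index, slicing goes through the overridable op, so the same expression works on any InferenceTensor subclass; a quantized tensor with a BlockScaledI4Layout is handled by get_index_QuantizedTensor in default_impls.py. A hypothetical helper, just to show the call shape:

def first_n(t, n: int):
    # `t` can be any InferenceTensor: __getitem__ forwards the key to
    # ops.get_index, which dispatches on the concrete tensor/layout type.
    return t[0:n]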
