Add interpreter for MapOp (#1373)
We have the following constraints in the spec:

```
(I1) inputs: variadic number of tensors.
(I2) dimensions: 1-dimensional tensor constant of type `si64`.
(I3) computation: function.
(C1) All `inputs` and `result` have the same shape.
(C2) size(`inputs`) = N >= 1.
(C3) `dimensions = [0, ..., R-1]`, where `R` = rank(`inputs[0]`).
(C4) `computation` has type `(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>`
where `Ek` = element_type(`inputs[k]`) and `E'` = element_type(`result`).
```
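
For orientation, here is a well-formed `map` op that satisfies all of these
constraints; it is the same i64 example this change settles on for the spec.
The inline comments tying lines back to constraints are annotations added
here for illustration, not part of the spec example:

```mlir
// C1/C2: two inputs; all inputs and the result are tensor<2x2xi64>.
%result = "stablehlo.map"(%input0, %input1) ({
  // C4: scalar block arguments whose element types match the inputs,
  // and a single scalar result matching the op result's element type.
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  // C3: dimensions = [0, ..., R-1], with R = rank(inputs[0]) = 2.
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
```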

These constraints will be comprehensively covered by the following
tests:

```
I1: a) inputs is not a variadic tensor. (Covered by ODS).
I2: a) dimensions is not a 1-dimensional tensor.
    b) element_type(dimensions) != si64. (Covered by ODS).
I3: a) computation is not a function. (Covered by ODS).
C1: a) Not all `inputs` have the same shape. (Covered by ODS).
    b) `inputs` and `result` do not have the same shape. (Covered by ODS).
C2: size(`inputs`) < 1. (Covered by ODS).
C3: a) `dimensions != [0, ..., R-1]`, where `R` = rank(`inputs[0]`).
C4: a) `computation` does not have type `(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>`
where `Ek` = element_type(`inputs[k]`) and `E'` = element_type(`result`).
```

Dropping the cases already covered by ODS leaves us with the following
test cases:

```
I2a: dimensions is not a 1-dimensional tensor.
C3a: `dimensions != [0, ..., R-1]`, where `R` = rank(`inputs[0]`).
C4a: `computation` does not have type `(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>`
where `Ek` = element_type(`inputs[k]`) and `E'` = element_type(`result`).
```
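
To make the interpreter semantics concrete: the new `interpret_map.mlir` test
maps `stablehlo.multiply` over every index of two 2x2 i64 inputs, so the
expected value at each index is the element-wise product: 0*4 = 0, 1*5 = 5,
2*6 = 12, and 3*7 = 21, i.e. `[[0, 5], [12, 21]]`.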

closes #1106
ghpvnist authored Apr 13, 2023
1 parent e0ec2bb commit c0769c2
Showing 8 changed files with 109 additions and 69 deletions.
13 changes: 7 additions & 6 deletions docs/spec.md
@@ -3200,24 +3200,25 @@ and will likely be removed in the future
 * (C2) size(`inputs`) $=$ N $\ge$ 1.
 * (C3) `dimensions = [0, ..., R-1]`, where `R` $=$ rank(`inputs[0]`).
 * (C4) `computation` has type `(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>`
-  where `Ek` $=$ element_type(`inputs[k]`) and `E'` $=$
-  element_type(`result`).
+  where `Ek` $=$ element_type(`inputs[k]`) and `E'` $=$ element_type(`result`).
 
 #### Examples
 
 ```mlir
 // %input0: [[0, 1], [2, 3]]
 // %input1: [[4, 5], [6, 7]]
 %result = "stablehlo.map"(%input0, %input1) ({
-  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
-    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i32>
-    stablehlo.return %0 : tensor<i32>
+  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
+    stablehlo.return %0 : tensor<i64>
 }) {
   dimensions = dense<[0, 1]> : tensor<2xi64>
-} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
 // %result: [[0, 5], [12, 21]]
 ```
 
+&nbsp;[More Examples](../stablehlo/tests/interpret_map.mlir)
+
 ### maximum
 
 #### Semantics
2 changes: 1 addition & 1 deletion docs/status.md
@@ -102,7 +102,7 @@ one of the following tracking labels.
 | log | yes | yes | yes | yes | yes |
 | log_plus_one | yes | yes | yes | yes | no |
 | logistic | yes | yes | yes | yes | yes |
-| map | yes | revisit | yes | no | no |
+| map | yes | revisit | yes | no | yes |
 | maximum | yes | yes | yes | yes | yes |
 | minimum | yes | yes | yes | yes | yes |
 | multiply | yes | yes | yes | yes | yes |
16 changes: 8 additions & 8 deletions stablehlo/dialect/StablehloOps.td
@@ -2375,7 +2375,7 @@ def StableHLO_GetDimensionSizeOp: StableHLO_Op<"get_dimension_size",
 }
 
 def StableHLO_MapOp: StableHLO_ShapedInterfaceOp<"map",
-    [RecursiveMemoryEffects, SameOperandsAndResultShape,
+    [RecursiveMemoryEffects, SameOperandsAndResultShape /*map_c1, map_c2*/,
     SingleBlockImplicitTerminator<"ReturnOp">, InferTensorTypeWithReify]> {
   let summary = "Map operation";
   let description = [{
@@ -2388,19 +2388,19 @@ def StableHLO_MapOp: StableHLO_ShapedInterfaceOp<"map",
     Example:
     ```mlir
     %result = "stablehlo.map"(%input0, %input1) ({
-      ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
-        %0 = stablehlo.multiply %arg0, %arg1 : tensor<i32>
-        stablehlo.return %0 : tensor<i32>
+      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+        %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
+        stablehlo.return %0 : tensor<i64>
     }) {
       dimensions = dense<[0, 1]> : tensor<2xi64>
-    } : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+    } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
     ```
   }];
   let arguments = (ins
-    Variadic<HLO_Tensor>:$inputs,
-    I64ElementsAttr:$dimensions
+    Variadic<HLO_Tensor>:$inputs /*map_i1*/,
+    I64ElementsAttr:$dimensions /*map_i2*/
  );
-  let regions = (region SizedRegion<1>:$computation);
+  let regions = (region SizedRegion<1>:$computation /*map_i3*/);
  let results = (outs HLO_Tensor);
 }
25 changes: 13 additions & 12 deletions stablehlo/dialect/TypeInference.cpp
@@ -2353,10 +2353,15 @@ LogicalResult inferMapOp(
     std::optional<Location> location, ValueRange inputs,
     DenseIntElementsAttr dimensions, Region& computation,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  // map_i2
+  if (dimensions.getType().getRank() != 1)
+    return emitOptionalError(location,
+                             "dimensions should be rank 1 but got rank ",
+                             dimensions.getType().getRank());
+
   if (failed(verifyRegionNotEmpty(location, computation))) return failure();
 
-  // Checks if the number of `operands` match the arity of the map `computation`
-  // region.
+  // map_c4
   auto& computationBlock = computation.front();
   auto computationArgs = computationBlock.getArguments();
   if (inputs.size() != computationArgs.size())
@@ -2365,8 +2370,7 @@
                              "map computation, but got: ",
                              inputs.size(), " and ", computationArgs.size());
 
-  // The parameters of computation should all be scalars and match the element
-  // type of operands.
+  // map_c4
   for (const auto& indexedArg : llvm::enumerate(computationArgs)) {
     auto argType = indexedArg.value().getType().dyn_cast<RankedTensorType>();
     if (!argType || argType.getRank() != 0)
@@ -2386,24 +2390,22 @@
                              argType.getElementType());
   }
 
-  // Mapped computation must return single output
+  // map_c4
   auto computationOutputs = computationBlock.getTerminator()->getOperands();
   if (computationOutputs.size() != 1)
     return emitOptionalError(location,
                              "computation must return single output, but got: ",
                              computationOutputs.size());
 
-  // The output of computation must be scalar and have the same element type
-  // as op result.
+  // map_c4
   auto computationOutputType =
       computationOutputs[0].getType().dyn_cast<RankedTensorType>();
   if (!computationOutputType || computationOutputType.getRank() != 0)
     return emitOptionalError(location,
                              "computation must return 0-rank tensor, but got: ",
                              computationOutputs[0].getType());
 
-  // Checks that the requested map dimension numbers are monotonically
-  // increasing.
+  // map_c3
   for (const auto& indexedValue :
        llvm::enumerate(dimensions.getValues<int64_t>())) {
     if (indexedValue.value() != static_cast<int64_t>(indexedValue.index()))
@@ -2413,9 +2415,7 @@
                              dimensions);
   }
 
-  // Checks that number of dimensions of operands matches the size of
-  // `dimensions` since we currently only support mapping across all
-  // dimensions: i.e., scalar map functions.
+  // map_c3
   ArrayRef<int64_t> resultShape;
   bool allInputsUnranked = true;
   for (auto operand : inputs) {
@@ -2434,6 +2434,7 @@
     }
   }
 
+  // map_c4
   if (allInputsUnranked)
     inferredReturnShapes.emplace_back(computationOutputType.getElementType());
   else
21 changes: 21 additions & 0 deletions stablehlo/reference/Ops.cpp
@@ -176,6 +176,11 @@ SmallVector<Tensor> eval(
       Tensor runtimeResult =
           evalLogisticOp(runtimeOperand, logisticOp.getType());
       scope.add(op.getResults(), {runtimeResult});
+    } else if (auto mapOp = dyn_cast<MapOp>(op)) {
+      SmallVector<Tensor> runtimeInputs = scope.find(mapOp.getInputs());
+      auto runtimeResults = evalMapOp(runtimeInputs, mapOp.getComputation(),
+                                      scope, mapOp.getType());
+      scope.add(op.getResults(), {runtimeResults});
     } else if (auto maxOp = dyn_cast<MaxOp>(op)) {
       Tensor runtimeLhs = scope.find(maxOp.getLhs());
       Tensor runtimeRhs = scope.find(maxOp.getRhs());
@@ -568,6 +573,22 @@ Tensor evalLogisticOp(const Tensor &operand, ShapedType resultType) {
   return result;
 }
 
+Tensor evalMapOp(ArrayRef<Tensor> inputs, Region &computation, Scope &scope,
+                 ShapedType resultType) {
+  Tensor result(resultType);
+  for (auto resultIt = inputs[0].index_begin();
+       resultIt != inputs[0].index_end(); ++resultIt) {
+    SmallVector<Tensor> args;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      auto tensor = Tensor(computation.getArgument(i).getType());
+      tensor.set({}, inputs[i].get(*resultIt));
+      args.push_back(tensor);
+    }
+    result.set(*resultIt, eval(computation, args, &scope)[0].get({}));
+  }
+  return result;
+}
+
 Tensor evalMaxOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) {
   Tensor result(resultType);
   for (auto it = result.index_begin(); it != result.index_end(); ++it)
2 changes: 2 additions & 0 deletions stablehlo/reference/Ops.h
@@ -62,6 +62,8 @@ Tensor evalIotaOp(Axis iotaDimension, ShapedType resultType);
 Tensor evalIsFiniteOp(const Tensor &operand, ShapedType resultType);
 Tensor evalLogOp(const Tensor &operand, ShapedType resultType);
 Tensor evalLogisticOp(const Tensor &operand, ShapedType resultType);
+Tensor evalMapOp(ArrayRef<Tensor> inputs, Region &computation, Scope &scope,
+                 ShapedType resultType);
 Tensor evalMaxOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType);
 Tensor evalMinOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType);
 Tensor evalMultiplyOp(const Tensor &lhs, const Tensor &rhs,
15 changes: 15 additions & 0 deletions stablehlo/tests/interpret_map.mlir
@@ -0,0 +1,15 @@
+// RUN: stablehlo-translate --interpret -split-input-file %s
+
+func.func @map_op_test_si64() {
+  %input0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>
+  %input1 = stablehlo.constant dense<[[4, 5], [6, 7]]> : tensor<2x2xi64>
+  %result = "stablehlo.map"(%input0, %input1) ({
+    ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+      %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
+      stablehlo.return %0 : tensor<i64>
+  }) {
+    dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+  check.expect_eq_const %result, dense<[[0, 5], [12, 21]]> : tensor<2x2xi64>
+  func.return
+}
84 changes: 42 additions & 42 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -1636,54 +1636,43 @@ func.func @map(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
 
 // -----
 
-// CHECK-LABEL: func @map_heterogeneous_inputs
-func.func @map_heterogeneous_inputs(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2xf32> {
-  %0 = "stablehlo.map"(%arg0, %arg1) ({
-    ^bb0(%arg2: tensor<f32>, %arg3: tensor<i32>):
-      "stablehlo.return"(%arg2) : (tensor<f32>) -> ()
-  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xf32>, tensor<2xi32>) -> tensor<2xf32>
-  func.return %0 : tensor<2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @map_scalar_operands
-func.func @map_scalar_operands(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+func.func @map_c3(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
+  // expected-error@+1 {{requires monotonically increasing dimension numbers, but got: dense<[1, 0]> : tensor<2xi64>}}
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-      %1 = stablehlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
+      %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
       "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  func.return %0 : tensor<f32>
+  }) {dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
+  func.return %0 : tensor<4x5xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @map_unranked
-func.func @map_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+func.func @map_c3(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
+  // expected-error@+1 {{applied to a subset of dimensions currently not supported: operand dimensions = 2, requested map dimensions size = 3}}
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-      %1 = stablehlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
+      %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
       "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-  func.return %0 : tensor<*xf32>
+  }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
+  func.return %0 : tensor<4x5xf32>
 }
 
 // -----
 
-func.func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+func.func @map_c4(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   // expected-error@+1 {{expects number of operands to match the arity of map computation, but got: 2 and 1}}
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg: tensor<f32>):
-      %1 = stablehlo.add %arg, %arg {name = "add"} : tensor<f32>
+      %1 = stablehlo.add %arg, %arg : tensor<f32>
       "stablehlo.return"(%1) : (tensor<f32>) -> ()
   }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   func.return %0 : tensor<4xf32>
 }
 
 // -----
 
-func.func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
+func.func @map_c4(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
   // expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}}
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<5xf32>):
@@ -1695,7 +1684,7 @@ func.func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
 
 // -----
 
-func.func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
+func.func @map_c4(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
   // expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}}
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
@@ -1707,7 +1696,7 @@ func.func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
 
 // -----
 
-func.func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
+func.func @map_c4(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
   // expected-error@+1 {{computation must return single output, but got: 0}}
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
@@ -1719,7 +1708,7 @@ func.func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
 
 // -----
 
-func.func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
+func.func @map_c4(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
   // expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}}
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
@@ -1731,38 +1720,49 @@ func.func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
 
 // -----
 
-func.func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
-  // expected-error@+1 {{inferred type(s) 'tensor<4x5xi32>' are incompatible with return type(s) of operation 'tensor<4x5xf32>'}}
+// CHECK-LABEL: func @map_heterogeneous_inputs
+func.func @map_heterogeneous_inputs(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2xf32> {
   %0 = "stablehlo.map"(%arg0, %arg1) ({
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<i32>):
+      "stablehlo.return"(%arg2) : (tensor<f32>) -> ()
+  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xf32>, tensor<2xi32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @map_i2(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
+  // expected-error@+1 {{dimensions should be rank 1 but got rank 2}}
+  %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-      %1 = stablehlo.constant dense<2> : tensor<i32>
-      "stablehlo.return"(%1) : (tensor<i32>) -> ()
-  }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
+      %1 = stablehlo.constant dense<2.0> : tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
   func.return %0 : tensor<4x5xf32>
 }
 
 // -----
 
-func.func @map_invalid_dimension_numbers(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
-  // expected-error@+1 {{requires monotonically increasing dimension numbers, but got: dense<[1, 0]> : tensor<2xi64>}}
+// CHECK-LABEL: func @map_scalar_operands
+func.func @map_scalar_operands(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-      %1 = stablehlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
+      %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
       "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
-  func.return %0 : tensor<4x5xf32>
+  }) {dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
 }
 
 // -----
 
-func.func @map_mismatch_arguments_and_dimensions(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
-  // expected-error@+1 {{applied to a subset of dimensions currently not supported: operand dimensions = 2, requested map dimensions size = 3}}
+// CHECK-LABEL: func @map_unranked
+func.func @map_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
   %0 = "stablehlo.map"(%arg0, %arg1) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-      %1 = stablehlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
+      %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
      "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
-  func.return %0 : tensor<4x5xf32>
+  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
 }
 
 // -----
