Skip to content

Commit

Permalink
Rename function to closely match spec
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist committed Feb 6, 2024
1 parent d50b6fc commit b17ef51
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
4 changes: 4 additions & 0 deletions stablehlo/reference/Element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,10 @@ Element convert(Type type, bool value) {
value ? static_cast<uint64_t>(1) : static_cast<uint64_t>(0));
}

/// Returns converted Element object of type `type` from source APInt `value`,
/// interpreting the bits of `value` as signed iff `isSigned`.
///
/// NOTE(review): llvm::APSInt's constructor takes `isUnsigned` as its second
/// parameter (defaulting to true), so the flag must be negated here — passing
/// `isSigned` through directly would invert the intended signedness.
Element convert(Type type, APInt value, bool isSigned) {
  return convert(type, APSInt(value, /*isUnsigned=*/!isSigned));
}

Element convert(Type type, APSInt value) {
if (isSupportedBooleanType(type)) return Element(type, !value.isZero());
if (isSupportedIntegerType(type))
Expand Down
10 changes: 10 additions & 0 deletions stablehlo/reference/Element.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ Element convert(Type type, uint64_t value);
/// behavior is TBD (#180).
Element convert(Type type, APFloat value);

/// Returns converted Element object of type `type` from source APInt `value`.
/// If the value cannot be exactly represented in the destination type, then the
/// behavior is TBD (#180).
Element convert(Type type, APInt value, bool isSigned = false);

/// Returns converted Element object of type `type` from source APSInt `value`.
/// If the value cannot be exactly represented in the destination type, then the
/// behavior is TBD (#180).
Element convert(Type type, APSInt value);

/// Returns converted Element object of type `type` from source double `value`.
/// If the value cannot be exactly represented in the destination type, then the
/// behavior is TBD (#180).
Expand Down
36 changes: 20 additions & 16 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,17 @@ SmallVector<SmallVector<uint32_t>> getReplicaGroups(
return replicaGroups;
}

Tensor makeScalar(const Element &initValue) {
/// Wraps a single Element into a rank-0 (scalar) Tensor whose element type
/// matches the element's type.
Tensor constant(Element initValue) {
  auto scalarType = RankedTensorType::get({}, initValue.getType());
  Tensor scalar(scalarType);
  scalar.set({}, initValue);
  return scalar;
}

/// Convenience overload: converts `value` into an Element of `elementType`
/// and wraps the result in a rank-0 Tensor.
template <typename T>
Tensor constant(T value, Type elementType) {
  auto element = convert(elementType, value);
  return constant(element);
}

Tensor makeSplat(ShapedType type, const Element &initValue) {
Tensor result(type);
for (auto indexIt = result.index_begin(); indexIt != result.index_end();
Expand All @@ -182,9 +187,9 @@ SmallVector<Tensor> split(const Tensor &x, int64_t numResults, Axis axis,
SmallVector<Tensor> results;
for (auto i = 0; i < numResults; ++i) {
SmallVector<Tensor> inputStartIndices(
x.getRank(), makeScalar(convert(IntegerType::get(context, 64), 0.0)));
inputStartIndices[axis] = makeScalar(
convert(IntegerType::get(context, 64), i * resultShape[axis]));
x.getRank(), constant(0.0, IntegerType::get(context, 64)));
inputStartIndices[axis] =
constant(i * resultShape[axis], IntegerType::get(context, 64));

auto result = evalDynamicSliceOp(
x, inputStartIndices, resultShape,
Expand Down Expand Up @@ -909,7 +914,7 @@ Tensor evalAllReduceOp(const Tensor &operand,
++resultIt) {
Tensor resultElement;
for (const auto &groupOperand : groupOperands) {
auto groupOperandElement = makeScalar(groupOperand.get(*resultIt));
auto groupOperandElement = constant(groupOperand.get(*resultIt));
if (resultElement)
resultElement = eval(computation, {resultElement, groupOperandElement},
/*fallback=*/nullptr, process, &scope)[0]
Expand Down Expand Up @@ -1099,9 +1104,8 @@ Tensor evalCollectivePermuteOp(
}

if (result) return result;
return evalBroadcastInDimOp(
makeScalar(convert(operand.getElementType(), 0.0)), {},
operand.getType());
return evalBroadcastInDimOp(constant(0.0, operand.getElementType()), {},
operand.getType());
}

Tensor evalCompareOp(const Tensor &lhs, const Tensor &rhs,
Expand Down Expand Up @@ -1533,7 +1537,7 @@ Tensor evalPartitionIdOp(Process *process, MLIRContext *context) {
"partition_id is only supported when run via interpreter.run_parallel");
auto partitionId = process->getId().partitionId;
auto elementType = IntegerType::get(context, 32, IntegerType::Unsigned);
return makeScalar(Element(elementType, APInt(32, partitionId)));
return constant(APInt(32, partitionId), elementType);
}

Tensor evalPopulationCountOp(const Tensor &operand, ShapedType resultType) {
Expand Down Expand Up @@ -1694,7 +1698,7 @@ Tensor evalReplicaIdOp(Process *process, MLIRContext *context) {
"replica_id is only supported when run via interpreter.run_parallel");
auto replicaId = process->getId().replicaId;
auto elementType = IntegerType::get(context, 32, IntegerType::Unsigned);
return makeScalar(Element(elementType, APInt(32, replicaId)));
return constant(APInt(32, replicaId), elementType);
}

Tensor evalReshapeOp(const Tensor &operand, ShapedType resultType) {
Expand Down Expand Up @@ -1793,9 +1797,9 @@ SmallVector<Tensor> evalScatterOp(

SmallVector<InterpreterValue> updateComputationArgs;
for (const auto &result : results)
updateComputationArgs.push_back(makeScalar(result.get(resultIndex)));
updateComputationArgs.push_back(constant(result.get(resultIndex)));
for (const auto &update : updates)
updateComputationArgs.push_back(makeScalar(update.get(updateIndex)));
updateComputationArgs.push_back(constant(update.get(updateIndex)));

auto updatedValues = eval(updateComputation, updateComputationArgs,
/*fallback=*/nullptr, process, &scope);
Expand Down Expand Up @@ -1834,8 +1838,8 @@ Tensor evalSelectAndScatterOp(const Tensor &operand, const Tensor &source,
selectedIndex = operandIndex;
}

InterpreterValue selectedInterpreterVal(makeScalar(selectedVal.value()));
InterpreterValue currInterpreterVal(makeScalar(currVal));
InterpreterValue selectedInterpreterVal(constant(selectedVal.value()));
InterpreterValue currInterpreterVal(constant(currVal));
auto selectResult =
eval(select, {selectedInterpreterVal, currInterpreterVal},
/*fallback=*/nullptr, process, &scope);
Expand Down Expand Up @@ -1960,8 +1964,8 @@ SmallVector<Tensor> evalSortOp(ArrayRef<Tensor> inputs, Axis dimension,
lhsIndex[adjustedDimension] = lhsHandle;
rhsIndex[adjustedDimension] = rhsHandle;
for (const auto &input : inputs) {
args.emplace_back(makeScalar(input.get(lhsIndex)));
args.emplace_back(makeScalar(input.get(rhsIndex)));
args.emplace_back(constant(input.get(lhsIndex)));
args.emplace_back(constant(input.get(rhsIndex)));
}
auto comparatorResult =
eval(comparator, args, /*fallback=*/nullptr, process, &scope);
Expand Down

0 comments on commit b17ef51

Please sign in to comment.