Skip to content

Commit

Permalink
Fix typo in all_reduce constraint comment (#1860)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist authored Nov 28, 2023
1 parent bfc711d commit 23116b3
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,20 +536,20 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
ArrayRef<int64_t> allowedDimensions) {
int64_t numInputs = inputTypes.size();

// all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2)
return emitOptionalError(loc, "Reduction-region must take ", numInputs * 2,
" parameters, but takes ",
block.getArguments().size(), " parameter(s)");

// all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (block.getTerminator()->getOperands().empty())
return emitOptionalError(
loc, "The reduction-region expected to return some value(s)");

// all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) !=
numInputs)
Expand All @@ -558,7 +558,7 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
block.getTerminator()->getOperands().size(),
" instead");

// all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
SmallVector<ShapedType> accumulatorSubShapes;
for (Value retOperand : block.getTerminator()->getOperands()) {
Expand All @@ -573,7 +573,7 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
}

for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
// all_reduce_c6, reduce_c2, reduce_scatter_c7, reduce_window_c13,
// all_reduce_c5, reduce_c2, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
block.getArgument(inputIdx).getType()))
Expand All @@ -583,7 +583,7 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
block.getArgument(inputIdx).getType(), " vs ",
accumulatorSubShapes[inputIdx]);

// all_reduce_c6, reduce_c2, reduce_scatter_c7, reduce_window_c13,
// all_reduce_c5, reduce_c2, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c3, select_and_scatter_c10
if (!compatibleShapeAndElementType(
accumulatorSubShapes[inputIdx],
Expand All @@ -596,7 +596,7 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
block.getArgument(numInputs + inputIdx).getType(), " vs ",
accumulatorSubShapes[inputIdx]);

// all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
initValueTypes[inputIdx],
Expand Down

0 comments on commit 23116b3

Please sign in to comment.