fix(ONNX): avoids resizing fixed dimensions #3945
base: main
Conversation
Force-pushed from bb3f80f to 6baa8d5
Force-pushed from 6baa8d5 to ab7e021
- "result" -> "outputTensor" - "type" -> more like "blueprint" since it includes shape and element data type
Force-pushed from ab7e021 to 7aec80b
I think the main structural question is about the need for adding the BaseTensorType method. If it were useful elsewhere (I have some doubts, since we would need to know too much about the two tensor shapes prior to using it, namely that they are present and that they have the same rank), I would consider keeping it; however, the code is simplified here by not using it, and I suspect the same would be true in other circumstances where it might be used.
auto this_dimensions = /**/ getSizes();
auto that_dimensions = that.getSizes();
BaseTensorType might not have sizes, and this will cause a crash when called. I would do:
- auto this_dimensions = /**/ getSizes();
- auto that_dimensions = that.getSizes();
+ auto selfSizes = getOptionalSizes();
+ auto otherSizes = other.getOptionalSizes();
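To make this concrete, here is a minimal sketch of the guarded comparison, assuming getOptionalSizes() returns a std::optional<ArrayRef<int64_t>> that is populated only when the type has sizes. The method name elementwiseShapeComparison is hypothetical, since the hunk below only shows the return type:

// Hypothetical sketch, not the PR's code. Assumes
// getOptionalSizes() -> std::optional<ArrayRef<int64_t>>.
std::vector<std::optional<bool>>
BaseTensorType::elementwiseShapeComparison(BaseTensorType other) const {
  auto selfSizes = getOptionalSizes();
  auto otherSizes = other.getOptionalSizes();
  // Bail out instead of crashing when either shape is absent or ranks differ.
  if (!selfSizes || !otherSizes || selfSizes->size() != otherSizes->size())
    return {};
  std::vector<std::optional<bool>> result;
  for (auto [selfDim, otherDim] : llvm::zip(*selfSizes, *otherSizes)) {
    if (selfDim == kUnknownSize || otherDim == kUnknownSize)
      result.push_back(std::nullopt); // dynamic dim: unknown at compile time
    else
      result.push_back(selfDim == otherDim); // both static: compare directly
  }
  return result;
}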
Note also the variable naming and camel-casing conventions. The variables self and other are used more typically than this and that in this codebase (easier to distinguish).
@@ -84,6 +84,10 @@ class BaseTensorType : public Type {
   /// Enable isa/dyn_cast for BaseTensorType.
   static bool classof(Type type);

+  /// The element-wise comparison of each dimension/size in `that` tensor
+  std::vector<std::optional<bool>>
Use SmallVector instead of std::vector. The methods are the same, and it is better for small containers like this.
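For illustration, the swap is just at the declaration; llvm::SmallVector picks a sensible inline capacity when the element-count parameter is omitted:

#include "llvm/ADT/SmallVector.h"

// std::vector always heap-allocates; SmallVector stores a few elements
// inline, which suits a result holding one entry per tensor dimension.
llvm::SmallVector<std::optional<bool>> comparison;
comparison.push_back(true);         // same push_back/iteration interface
comparison.push_back(std::nullopt); // as std::vector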
@@ -2686,12 +2686,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       });
   patterns.onOp(
       "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
+        Torch::ValueTensorType outputTensor_blueprint;
I don't understand the renaming of this variable. This isn't a blueprint; it's the result type.
return rewriter.notifyMatchFailure(
    binder.op, "Sizes for batch and channel dimensions must be "
               "statically defined");
}
We definitely do not want to constrain this conversion to static batch and channel dims. This was the reason for needing to write asserts into the helper function getValueList in the dynamic case.
It might be fine to just put runtime asserts in right after the last match failure. Something like:
Value inputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstZero);
Value inputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstOne);
Value outputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstZero);
Value outputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstOne);
Value cmpDimZero = rewriter.create<Torch::AtenEqIntOp>(loc, inputDimZero, outputDimZero);
Value cmpDimOne = ...
rewriter.create<Torch::RuntimeAssertOp>(loc, cmpDimZero, rewriter.getStringAttr("message"));
// same for DimOne
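Filled out, that outline might read as follows — a hedged expansion, assuming loc, input, output, and the constant-index values cstZero/cstOne are already in scope; the assert messages are placeholders:

// Hedged expansion of the outline above; not the PR's code.
Value inputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstZero);
Value outputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstZero);
Value cmpDimZero = rewriter.create<Torch::AtenEqIntOp>(loc, inputDimZero, outputDimZero);
rewriter.create<Torch::RuntimeAssertOp>(
    loc, cmpDimZero,
    rewriter.getStringAttr("Resize: batch dimension must not change"));

Value inputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstOne);
Value outputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstOne);
Value cmpDimOne = rewriter.create<Torch::AtenEqIntOp>(loc, inputDimOne, outputDimOne);
rewriter.create<Torch::RuntimeAssertOp>(
    loc, cmpDimOne,
    rewriter.getStringAttr("Resize: channel dimension must not change"));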
By the way, if one of the two dims has input/output sizes that are static and equal, then these asserts will fold out, so there isn't a pressing need to check again for static dims.
for (auto eachDimensionComparison : shapeComparisonForFixedDimensions) {
  if (eachDimensionComparison == std::nullopt) {
Since you need to loop over the result of the shape comparison anyway, it would be more efficient to not define the helper function at all, and do
for (int64_t dim = 0; dim < 2; dim++) {
  if (inputSizes[dim] == Torch::kUnknownSize ||
      outputSizes[dim] == Torch::kUnknownSize)
    continue; // you need to implement the runtime asserts, but at least still check the other dim if static.
  if (inputSizes[dim] != outputSizes[dim])
    return rewriter.notifyMatchFailure(...
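Pieced together, the inline check might look like this sketch, assuming inputSizes and outputSizes are ArrayRef<int64_t> of rank at least 2, with the dynamic case covered by runtime asserts like the ones above (the failure message is illustrative):

// Sketch of the inlined batch/channel check; not the PR's code.
for (int64_t dim = 0; dim < 2; dim++) {
  // Dynamic dims cannot be checked statically; defer to runtime asserts.
  if (inputSizes[dim] == Torch::kUnknownSize ||
      outputSizes[dim] == Torch::kUnknownSize)
    continue;
  // Both sizes are static: a mismatch means the op resizes a fixed dim.
  if (inputSizes[dim] != outputSizes[dim])
    return rewriter.notifyMatchFailure(
        binder.op, "Resize must not change batch or channel dimensions");
}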
Definitely more machine-efficient! That's exactly what I had played around with at first.
Some general first-principles questions:
- At this level of abstraction, are we still able to optimize for minimal cognitive load at the cost of machine efficiency?
- Or is this already at the level where we have to optimize for machine runtime, even if it means more cognitive load for the dev?
I was actually suggesting this for both readability and (very modestly) compiler performance. The runtime performance won't be affected either way.
In general, it is good to have runtime performance in mind at this level, but know that many things do indeed get optimized out later on (see, for example, my comment about the folding of runtime asserts for the static dim case). I tend to take a pessimistic view of what will and won't be optimized away, at least when I don't actually know, and will try to generate a cleaner pattern if it doesn't cost a huge amount in code complexity for the perceived benefit.