-
Notifications
You must be signed in to change notification settings - Fork 469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
burn-import: add some tests for ConstantNode #2623
base: main
Are you sure you want to change the base?
Conversation
let device = Default::default(); | ||
let model = constant_f64::Model::<Backend>::new(&device); | ||
let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4])); | ||
let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below). I wasn't sure how to get PyTorch to just forward the constant by itself so these tests are adding the constant to the input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe by having the output return the constant only? But a simple constant addition works too.
In case you're curious, you could also manually define the onnx graph like the ConstOfShape script. PyTorch tends doesn't always have a 1-to-1 correspondence for ops, so in such cases it could be easier to define the graph manually.
I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below)
The floating point and integer data types are defined by the backend used. A model is not as statically defined like an ONNX graph. If you look at the other tests, the input(s) and output(s) are created using the Tensor
methods, not from TensorData
.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2623 +/- ##
==========================================
+ Coverage 82.06% 82.15% +0.09%
==========================================
Files 831 832 +1
Lines 106003 106749 +746
==========================================
+ Hits 86990 87699 +709
- Misses 19013 19050 +37 ☔ View full report in Codecov by Sentry. |
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
I've been trying to implement OneHot ONNX op (#1714) in a WIP draft branch. The ONNX model ends up containing a constant integer vector
values=[0, 1]
used by the OneHot op, this vector was causing issues when trying to test the model. I looked atConstantNode
and these are the tests so far I could get working while investigating.Issues for
ConstantNode
#2624 - constant tensors aren't populated with values
#2625 - generated code for const int tensors doesn't compile
Changes
ConstantNode::tensor_ty_tokens
for tests, but this PR otherwise shouldn't be changing howConstantNode
currently worksTesting
Ran added tests
I checked the .onnx models contain the expected scalar constants using Netron
Screenshots