burn-import: add some tests for ConstantNode #2623

Open · wants to merge 17 commits into main
4 changes: 4 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -16,6 +16,10 @@ fn main() {
.input("tests/clip/clip_opset16.onnx")
.input("tests/clip/clip_opset7.onnx")
.input("tests/concat/concat.onnx")
.input("tests/constant/constant_f32.onnx")
.input("tests/constant/constant_f64.onnx")
.input("tests/constant/constant_i32.onnx")
.input("tests/constant/constant_i64.onnx")
.input("tests/constant_of_shape/constant_of_shape.onnx")
.input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
.input("tests/conv1d/conv1d.onnx")
63 changes: 63 additions & 0 deletions crates/burn-import/onnx-tests/tests/constant/constant.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn

CONST_VALUE = 2


class ConstantModel(nn.Module):
    def __init__(self, const_dtype: torch.dtype):
        super().__init__()
        self.const = torch.tensor(CONST_VALUE).to(const_dtype)

    def forward(self, x):
        return x + self.const


def export_model(model: ConstantModel, dummy_input: torch.Tensor, file_name: str):
    model.eval()
    torch.onnx.export(
        model,
        dummy_input,
        file_name,
        verbose=False,
        opset_version=16,
        do_constant_folding=False,
    )
    print(f"Finished exporting model to {file_name}")

    # Output some test data for demonstration
    test_input = dummy_input.clone()
    print(dummy_input.dtype, "test input:", test_input)
    output = model.forward(test_input)
    print(dummy_input.dtype, "test output:", output)
    print("")


def main():
    device = torch.device("cpu")
    shape = (2, 3, 4)

    model_f32 = ConstantModel(torch.float32)
    f32_input = torch.randn(shape, dtype=torch.float32, device=device)
    export_model(model_f32, f32_input, "constant_f32.onnx")

    model_f64 = ConstantModel(torch.float64)
    f64_input = torch.randn(shape, dtype=torch.float64, device=device)
    export_model(model_f64, f64_input, "constant_f64.onnx")

    model_i32 = ConstantModel(torch.int32)
    i32_input = torch.randint(
        low=-10, high=10, size=shape, device=device, dtype=torch.int32
    )
    export_model(model_i32, i32_input, "constant_i32.onnx")

    model_i64 = ConstantModel(torch.int64)
    i64_input = torch.randint(
        low=-10, high=10, size=shape, device=device, dtype=torch.int64
    )
    export_model(model_i64, i64_input, "constant_i64.onnx")


if __name__ == "__main__":
    main()
Binary files not shown (the four exported models: constant_f32.onnx, constant_f64.onnx, constant_i32.onnx, constant_i64.onnx).
44 changes: 44 additions & 0 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -25,6 +25,10 @@ include_models!(
    clip_opset16,
    clip_opset7,
    concat,
    constant_f32,
    constant_f64,
    constant_i32,
    constant_i64,
    constant_of_shape,
    constant_of_shape_full_like,
    conv1d,
@@ -2183,6 +2187,46 @@ mod tests {
        assert_eq!(expected_shape, output.shape());
    }

    #[test]
    fn add_constant_f32() {
        let device = Default::default();
        let model = constant_f32::Model::<Backend>::new(&device);
        let input = TensorData::zeros::<f32, _>(Shape::from([2, 3, 4]));
        let expected_output = TensorData::full::<f32, _>(Shape::from([2, 3, 4]), 2f32);
        let output = model.forward(input.into());
        assert_eq!(expected_output, output.to_data());
    }

    #[test]
    fn add_constant_f64() {
        let device = Default::default();
        let model = constant_f64::Model::<Backend>::new(&device);
        let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
        let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32);
@jameshiew (author) commented on Dec 17, 2024:

I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below). I wasn't sure how to get PyTorch to just forward the constant by itself, so these tests add the constant to the input.
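
One way to narrow down where the coercion happens is to inspect the exported graph directly; a minimal sketch using the onnx package (the file name comes from the export script above, and this check is illustrative, not part of the PR):

import onnx
from onnx import TensorProto

# Sketch: print the dtype carried by the exported graph to see whether the
# f64 constant survives torch.onnx.export or is downcast along the way.
model = onnx.load("constant_f64.onnx")
for node in model.graph.node:
    if node.op_type == "Constant":
        for attr in node.attribute:
            if attr.name == "value":
                print("Constant dtype:", TensorProto.DataType.Name(attr.t.data_type))
# The value may also be stored as an initializer rather than a Constant node.
for init in model.graph.initializer:
    print("initializer", init.name, TensorProto.DataType.Name(init.data_type))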

A maintainer (Member) replied:

Maybe by having the output return the constant only? But a simple constant addition works too.

In case you're curious, you could also manually define the ONNX graph, like the ConstOfShape script. PyTorch doesn't always have a 1-to-1 correspondence for ops, so in such cases it can be easier to define the graph manually.
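
For reference, a hand-built version of the f64 graph could look roughly like this (a minimal sketch using the onnx helper API; graph and file names are illustrative, not from this PR):

import onnx
from onnx import TensorProto, helper

# Sketch: define the Constant + Add graph by hand instead of exporting from
# PyTorch, similar in spirit to the ConstOfShape script. This keeps full
# control over dtypes, so no export-time coercion can occur.
const = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["const"],
    value=helper.make_tensor("value", TensorProto.DOUBLE, dims=[], vals=[2.0]),
)
add = helper.make_node("Add", inputs=["x", "const"], outputs=["y"])
graph = helper.make_graph(
    [const, add],
    "constant_f64_manual",
    inputs=[helper.make_tensor_value_info("x", TensorProto.DOUBLE, [2, 3, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.DOUBLE, [2, 3, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
onnx.checker.check_model(model)
onnx.save(model, "constant_f64_manual.onnx")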

> I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below)

The floating-point and integer data types are defined by the backend used; a model is not statically typed the way an ONNX graph is. If you look at the other tests, the input(s) and output(s) are created using the Tensor methods, not from TensorData.

        let output = model.forward(input.into());
        assert_eq!(expected_output, output.to_data());
    }

    #[test]
    fn add_constant_i32() {
        let device = Default::default();
        let model = constant_i32::Model::<Backend>::new(&device);
        let input = TensorData::zeros::<i32, _>(Shape::from([2, 3, 4]));
        let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2i64);
        let output = model.forward(input.into());
        assert_eq!(expected_output, output.to_data());
    }

    #[test]
    fn add_constant_i64() {
        let device = Default::default();
        let model = constant_i64::Model::<Backend>::new(&device);
        let input = TensorData::zeros::<i64, _>(Shape::from([2, 3, 4]));
        let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2i64);
        let output = model.forward(input.into());
        assert_eq!(expected_output, output.to_data());
    }

    #[test]
    fn constant_of_shape() {
        // This tests that the shape is passed directly to the model