From 612671496886c4cb3c660674894f0cef9a58eabe Mon Sep 17 00:00:00 2001 From: James Hiew Date: Mon, 16 Dec 2024 12:18:19 +0000 Subject: [PATCH 01/17] i64 f32 bool [f32] --- crates/burn-import/src/burn/node/constant.rs | 193 ++++++++++++++++++- 1 file changed, 192 insertions(+), 1 deletion(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index 40db6d260b..18ee925631 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -165,4 +165,195 @@ impl NodeCodegen for ConstantNode { } } -// TODO add test missing for constant node (@antimora 8/2/2023) +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{ + graph::BurnGraph, node::test::assert_tokens, ScalarKind, ScalarType, TensorType, + }; + use burn::record::FullPrecisionSettings; + use burn::tensor::TensorData; + + #[test] + fn test_codegen_constant_scalar_int() { + let mut graph = BurnGraph::::default(); + + graph.register(ConstantNode::new( + "const_int".to_owned(), + ConstantValue::Int64(42i64), + Type::Scalar(ScalarType::new("output", ScalarKind::Int64)), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> i64 { + let output: i64 = 42i64; + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_constant_scalar_float() { + let mut graph = BurnGraph::::default(); + + graph.register(ConstantNode::new( + "const_float".to_owned(), + ConstantValue::Float32(3.14f32), + Type::Scalar(ScalarType::new("output", ScalarKind::Float32)), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> f32 { + let output: f32 = 3.14f32; + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_constant_scalar_bool() { + let mut graph = BurnGraph::::default(); + + graph.register(ConstantNode::new( + "const_bool".to_owned(), + ConstantValue::Bool(true), + Type::Scalar(ScalarType::new("output", ScalarKind::Bool)), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> bool { + let output: bool = true; + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_constant_tensor() { + let mut graph = BurnGraph::::default(); + + let tensor_type = TensorType::new_float_with_shape("const_tensor", 1, Some(vec![1])); + let data = TensorData::from([2f32, 2f32, 2f32, 2f32]); + + graph.register(ConstantNode::new( + "const_tensor".to_owned(), + ConstantValue::Tensor(tensor_type.clone(), data), + Type::Tensor(TensorType::new_float_with_shape("output", 1, Some(vec![1]))), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + const_tensor: burn::module::Param>, + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + let const_tensor: burn::module::Param> = burn::nn::Initializer::Zeros.init([1], device).set_require_grad(false); + + Self { + const_tensor, + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> Tensor { + let output = self.const_tensor.val(); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} From e8f5c0d370a6a2410a87f290582a071fef6db817 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Mon, 16 Dec 2024 13:24:35 +0000 Subject: [PATCH 02/17] Test all types --- crates/burn-import/src/burn/node/constant.rs | 232 +++++++++++++++++-- 1 file changed, 212 insertions(+), 20 deletions(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index 18ee925631..4666de2117 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -175,13 +175,13 @@ mod tests { use burn::tensor::TensorData; #[test] - fn test_codegen_constant_scalar_int() { + fn test_codegen_constant_scalar_float32() { let mut graph = BurnGraph::::default(); graph.register(ConstantNode::new( - "const_int".to_owned(), - ConstantValue::Int64(42i64), - Type::Scalar(ScalarType::new("output", ScalarKind::Int64)), + "const_float32".to_owned(), + ConstantValue::Float32(3.14f32), + Type::Scalar(ScalarType::new("output", ScalarKind::Float32)), )); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -198,7 +198,7 @@ mod tests { device: burn::module::Ignored, } - impl Model { + impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { Self { @@ -208,8 +208,8 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> i64 { - let output: i64 = 42i64; + pub fn forward(&self) -> f32 { + let output: f32 = 3.14f32; output } } @@ -219,13 +219,13 @@ mod tests { } #[test] - fn test_codegen_constant_scalar_float() { + fn test_codegen_constant_scalar_float64() { let mut graph = 
BurnGraph::::default(); graph.register(ConstantNode::new( - "const_float".to_owned(), - ConstantValue::Float32(3.14f32), - Type::Scalar(ScalarType::new("output", ScalarKind::Float32)), + "const_float64".to_owned(), + ConstantValue::Float64(std::f64::consts::PI), + Type::Scalar(ScalarType::new("output", ScalarKind::Float64)), )); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -242,7 +242,7 @@ mod tests { device: burn::module::Ignored, } - impl Model { + impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { Self { @@ -252,8 +252,96 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> f32 { - let output: f32 = 3.14f32; + pub fn forward(&self) -> f64 { + let output: f64 = 3.141592653589793f64; + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_constant_scalar_int32() { + let mut graph = BurnGraph::::default(); + + graph.register(ConstantNode::new( + "const_int32".to_owned(), + ConstantValue::Int32(123i32), + Type::Scalar(ScalarType::new("output", ScalarKind::Int32)), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> i32 { + let output: i32 = 123i32; + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_constant_scalar_int64() { + let mut graph = BurnGraph::::default(); + + graph.register(ConstantNode::new( + "const_int64".to_owned(), + ConstantValue::Int64(42i64), + Type::Scalar(ScalarType::new("output", ScalarKind::Int64)), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> i64 { + let output: i64 = 42i64; output } } @@ -286,7 +374,7 @@ mod tests { device: burn::module::Ignored, } - impl Model { + impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { Self { @@ -307,16 +395,16 @@ mod tests { } #[test] - fn test_codegen_constant_tensor() { + fn test_codegen_constant_tensor_float() { let mut graph = BurnGraph::::default(); - let tensor_type = TensorType::new_float_with_shape("const_tensor", 1, Some(vec![1])); + let tensor_type = TensorType::new_float_with_shape("const_tensor", 1, Some(vec![4])); let data = TensorData::from([2f32, 2f32, 2f32, 2f32]); graph.register(ConstantNode::new( "const_tensor".to_owned(), ConstantValue::Tensor(tensor_type.clone(), data), - Type::Tensor(TensorType::new_float_with_shape("output", 1, Some(vec![1]))), + Type::Tensor(TensorType::new_float_with_shape("output", 1, Some(vec![4]))), )); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -334,10 +422,10 @@ mod tests { device: burn::module::Ignored, } - impl Model { + impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let const_tensor: burn::module::Param> = burn::nn::Initializer::Zeros.init([1], device).set_require_grad(false); + let const_tensor: burn::module::Param> = burn::nn::Initializer::Zeros.init([4], device).set_require_grad(false); Self { const_tensor, @@ -356,4 +444,108 @@ mod tests { assert_tokens(graph.codegen(), expected); } + + #[test] + fn test_codegen_constant_tensor_int() { + let mut graph = BurnGraph::::default(); + + let tensor_type = TensorType::new_int_with_shape("const_tensor_int", 1, Some(vec![3])); + let data = TensorData::from([1i32, 2i32, 3i32]); + + graph.register(ConstantNode::new( + "const_tensor_int".to_owned(), + ConstantValue::Tensor(tensor_type.clone(), data), + Type::Tensor(TensorType::new_int_with_shape("output", 1, Some(vec![3]))), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::tensor::Int; + + #[derive(Module, Debug)] + pub struct Model { + const_tensor_int: burn::module::Param>, + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + let const_tensor_int: burn::module::Param> = burn::nn::Initializer::Zeros.init([3], device).set_require_grad(false); + + Self { + const_tensor_int, + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> Tensor { + let output = self.const_tensor_int.val(); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_constant_tensor_bool() { + let mut graph = BurnGraph::::default(); + + let tensor_type = TensorType::new_bool_with_shape("const_tensor_bool", 1, Some(vec![2])); + let data = TensorData::from([true, false]); + + graph.register(ConstantNode::new( + "const_tensor_bool".to_owned(), + ConstantValue::Tensor(tensor_type.clone(), data), + Type::Tensor(TensorType::new_bool_with_shape("output", 1, Some(vec![2]))), + )); + + graph.register_input_output(vec![], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::tensor::Bool; + + #[derive(Module, Debug)] + pub struct Model { + const_tensor_bool: burn::module::Param>, + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + let const_tensor_bool: burn::module::Param> = burn::nn::Initializer::Zeros.init([2], device).set_require_grad(false); + + Self { + const_tensor_bool, + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self) -> Tensor { + let output = self.const_tensor_bool.val(); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } } From 5e680f954d7322ca2829e992e52abadfb5a05a62 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Mon, 16 Dec 2024 13:29:27 +0000 Subject: [PATCH 03/17] Add test for 3D tensor --- crates/burn-import/src/burn/node/constant.rs | 282 ++++++++----------- 1 file changed, 116 insertions(+), 166 deletions(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index 4666de2117..cba0e1df3c 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -47,6 +47,17 @@ impl ConstantValue { } } } + + pub fn tensor_ty_tokens(&self) -> TokenStream { + match self { + ConstantValue::Tensor(tensor_type, _) => { + let ty = tensor_type.ty(); + quote! { #ty } + } + _ => panic!("Not a tensor constant"), + } + } + pub fn val_tokens(&self) -> TokenStream { match self { ConstantValue::Float32(val) => quote! 
{ #val }, @@ -173,20 +184,10 @@ mod tests { }; use burn::record::FullPrecisionSettings; use burn::tensor::TensorData; + use quote::ToTokens; - #[test] - fn test_codegen_constant_scalar_float32() { - let mut graph = BurnGraph::::default(); - - graph.register(ConstantNode::new( - "const_float32".to_owned(), - ConstantValue::Float32(3.14f32), - Type::Scalar(ScalarType::new("output", ScalarKind::Float32)), - )); - - graph.register_input_output(vec![], vec!["output".to_string()]); - - let expected = quote! { + fn expected_scalar_snippet(ty: TokenStream, val: TokenStream) -> TokenStream { + quote! { use burn::{ module::Module, tensor::{backend::Backend, Tensor}, @@ -208,158 +209,97 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> f32 { - let output: f32 = 3.14f32; + pub fn forward(&self) -> #ty { + let output: #ty = #val; output } } - }; - - assert_tokens(graph.codegen(), expected); + } } - #[test] - fn test_codegen_constant_scalar_float64() { + fn test_codegen_constant_scalar(value: ConstantValue, scalar_kind: ScalarKind) { let mut graph = BurnGraph::::default(); + let val = value.val_tokens(); + let ty = value.ty_tokens(); graph.register(ConstantNode::new( - "const_float64".to_owned(), - ConstantValue::Float64(std::f64::consts::PI), - Type::Scalar(ScalarType::new("output", ScalarKind::Float64)), + "constant_scalar".to_owned(), + value, + Type::Scalar(ScalarType::new("output", scalar_kind)), )); graph.register_input_output(vec![], vec!["output".to_string()]); - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - device: burn::module::Ignored, - } - - impl Model { - #[allow(unused_variables)] - pub fn new(device: &B::Device) -> Self { - Self { - phantom: core::marker::PhantomData, - device: burn::module::Ignored(device.clone()), - } - } - - #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> f64 { - let output: f64 = 3.141592653589793f64; - output - } - } - }; - + let expected = expected_scalar_snippet(ty, val); assert_tokens(graph.codegen(), expected); } #[test] - fn test_codegen_constant_scalar_int32() { - let mut graph = BurnGraph::::default(); - - graph.register(ConstantNode::new( - "const_int32".to_owned(), - ConstantValue::Int32(123i32), - Type::Scalar(ScalarType::new("output", ScalarKind::Int32)), - )); - - graph.register_input_output(vec![], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - device: burn::module::Ignored, - } - - impl Model { - #[allow(unused_variables)] - pub fn new(device: &B::Device) -> Self { - Self { - phantom: core::marker::PhantomData, - device: burn::module::Ignored(device.clone()), - } - } + fn test_codegen_constant_scalar_float32() { + test_codegen_constant_scalar(ConstantValue::Float32(3.14f32), ScalarKind::Float32); + } - #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> i32 { - let output: i32 = 123i32; - output - } - } - }; + #[test] + fn test_codegen_constant_scalar_float64() { + test_codegen_constant_scalar( + ConstantValue::Float64(std::f64::consts::PI), + ScalarKind::Float64, + ); + } - assert_tokens(graph.codegen(), expected); + #[test] + fn test_codegen_constant_scalar_int32() { + test_codegen_constant_scalar(ConstantValue::Int32(123i32), ScalarKind::Int32); } #[test] fn test_codegen_constant_scalar_int64() { - let mut graph = BurnGraph::::default(); - - graph.register(ConstantNode::new( - "const_int64".to_owned(), - ConstantValue::Int64(42i64), - Type::Scalar(ScalarType::new("output", ScalarKind::Int64)), - )); - - graph.register_input_output(vec![], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - device: burn::module::Ignored, - } - - impl Model { - #[allow(unused_variables)] - pub fn new(device: &B::Device) -> Self { - Self { - phantom: core::marker::PhantomData, - device: burn::module::Ignored(device.clone()), - } - } + test_codegen_constant_scalar(ConstantValue::Int64(42i64), ScalarKind::Int64); + } - #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> i64 { - let output: i64 = 42i64; - output - } - } - }; + #[test] + fn test_codegen_constant_scalar_bool() { + test_codegen_constant_scalar(ConstantValue::Bool(true), ScalarKind::Bool); + } - assert_tokens(graph.codegen(), expected); + fn shape_to_tokens(shape: &[usize]) -> TokenStream { + let dims = shape.iter().map(|d| { + let lit = proc_macro2::Literal::usize_unsuffixed(*d); + quote! { #lit } + }); + quote! { [#(#dims),*] } } #[test] - fn test_codegen_constant_scalar_bool() { + fn test_codegen_constant_tensor_float() { let mut graph = BurnGraph::::default(); + let const_tensor = Ident::new("const_tensor", Span::call_site()); + let dimensions = 1; + let shape = vec![4]; + let data = TensorData::from([2f32, 2f32, 2f32, 2f32]); + let tensor_type = TensorType::new_float_with_shape( + const_tensor.to_string(), + dimensions, + Some(shape.clone()), + ); + let value = ConstantValue::Tensor(tensor_type.clone(), data); + graph.register(ConstantNode::new( - "const_bool".to_owned(), - ConstantValue::Bool(true), - Type::Scalar(ScalarType::new("output", ScalarKind::Bool)), + const_tensor.to_string(), + value.clone(), + Type::Tensor(TensorType::new_float_with_shape( + "output", + dimensions, + Some(shape.clone()), + )), )); + let con = const_tensor.to_token_stream(); + let ty = value.ty_tokens(); + let tensor_ty = value.tensor_ty_tokens(); + let shp = shape_to_tokens(&shape); + graph.register_input_output(vec![], vec!["output".to_string()]); let expected = quote! 
{ @@ -370,6 +310,7 @@ mod tests { #[derive(Module, Debug)] pub struct Model { + #con: #ty, phantom: core::marker::PhantomData, device: burn::module::Ignored, } @@ -377,15 +318,18 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { + let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false); + Self { + #con, phantom: core::marker::PhantomData, device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> bool { - let output: bool = true; + pub fn forward(&self) -> #tensor_ty { + let output = self.#con.val(); output } } @@ -395,16 +339,16 @@ mod tests { } #[test] - fn test_codegen_constant_tensor_float() { + fn test_codegen_constant_tensor_int() { let mut graph = BurnGraph::::default(); - let tensor_type = TensorType::new_float_with_shape("const_tensor", 1, Some(vec![4])); - let data = TensorData::from([2f32, 2f32, 2f32, 2f32]); + let tensor_type = TensorType::new_int_with_shape("const_tensor_int", 1, Some(vec![3])); + let data = TensorData::from([1i32, 2i32, 3i32]); graph.register(ConstantNode::new( - "const_tensor".to_owned(), + "const_tensor_int".to_owned(), ConstantValue::Tensor(tensor_type.clone(), data), - Type::Tensor(TensorType::new_float_with_shape("output", 1, Some(vec![4]))), + Type::Tensor(TensorType::new_int_with_shape("output", 1, Some(vec![3]))), )); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -414,10 +358,11 @@ mod tests { module::Module, tensor::{backend::Backend, Tensor}, }; + use burn::tensor::Int; #[derive(Module, Debug)] pub struct Model { - const_tensor: burn::module::Param>, + const_tensor_int: burn::module::Param>, phantom: core::marker::PhantomData, device: burn::module::Ignored, } @@ -425,18 +370,18 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let const_tensor: burn::module::Param> = burn::nn::Initializer::Zeros.init([4], device).set_require_grad(false); + let const_tensor_int: burn::module::Param> = burn::nn::Initializer::Zeros.init([3], device).set_require_grad(false); Self { - const_tensor, + const_tensor_int, phantom: core::marker::PhantomData, device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> Tensor { - let output = self.const_tensor.val(); + pub fn forward(&self) -> Tensor { + let output = self.const_tensor_int.val(); output } } @@ -446,16 +391,16 @@ mod tests { } #[test] - fn test_codegen_constant_tensor_int() { + fn test_codegen_constant_tensor_bool() { let mut graph = BurnGraph::::default(); - let tensor_type = TensorType::new_int_with_shape("const_tensor_int", 1, Some(vec![3])); - let data = TensorData::from([1i32, 2i32, 3i32]); + let tensor_type = TensorType::new_bool_with_shape("const_tensor_bool", 1, Some(vec![2])); + let data = TensorData::from([true, false]); graph.register(ConstantNode::new( - "const_tensor_int".to_owned(), + "const_tensor_bool".to_owned(), ConstantValue::Tensor(tensor_type.clone(), data), - Type::Tensor(TensorType::new_int_with_shape("output", 1, Some(vec![3]))), + Type::Tensor(TensorType::new_bool_with_shape("output", 1, Some(vec![2]))), )); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -465,11 +410,11 @@ mod tests { module::Module, tensor::{backend::Backend, Tensor}, }; - use burn::tensor::Int; + use burn::tensor::Bool; #[derive(Module, Debug)] pub struct Model { - const_tensor_int: burn::module::Param>, 
+ const_tensor_bool: burn::module::Param>, phantom: core::marker::PhantomData, device: burn::module::Ignored, } @@ -477,18 +422,18 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let const_tensor_int: burn::module::Param> = burn::nn::Initializer::Zeros.init([3], device).set_require_grad(false); + let const_tensor_bool: burn::module::Param> = burn::nn::Initializer::Zeros.init([2], device).set_require_grad(false); Self { - const_tensor_int, + const_tensor_bool, phantom: core::marker::PhantomData, device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> Tensor { - let output = self.const_tensor_int.val(); + pub fn forward(&self) -> Tensor { + let output = self.const_tensor_bool.val(); output } } @@ -498,16 +443,21 @@ mod tests { } #[test] - fn test_codegen_constant_tensor_bool() { + fn test_codegen_constant_tensor_3d() { let mut graph = BurnGraph::::default(); - let tensor_type = TensorType::new_bool_with_shape("const_tensor_bool", 1, Some(vec![2])); - let data = TensorData::from([true, false]); + let tensor_type = + TensorType::new_bool_with_shape("const_tensor_3d", 3, Some(vec![2, 2, 2])); + let data = TensorData::from([[[true, false], [true, false], [true, false]]]); graph.register(ConstantNode::new( - "const_tensor_bool".to_owned(), + "const_tensor_3d".to_owned(), ConstantValue::Tensor(tensor_type.clone(), data), - Type::Tensor(TensorType::new_bool_with_shape("output", 1, Some(vec![2]))), + Type::Tensor(TensorType::new_bool_with_shape( + "output", + 3, + Some(vec![2, 2, 2]), + )), )); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -521,7 +471,7 @@ mod tests { #[derive(Module, Debug)] pub struct Model { - const_tensor_bool: burn::module::Param>, + const_tensor_3d: burn::module::Param>, phantom: core::marker::PhantomData, device: burn::module::Ignored, } @@ -529,18 +479,18 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let const_tensor_bool: burn::module::Param> = burn::nn::Initializer::Zeros.init([2], device).set_require_grad(false); + let const_tensor_3d: burn::module::Param> = burn::nn::Initializer::Zeros.init([2, 2, 2], device).set_require_grad(false); Self { - const_tensor_bool, + const_tensor_3d, phantom: core::marker::PhantomData, device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> Tensor { - let output = self.const_tensor_bool.val(); + pub fn forward(&self) -> Tensor { + let output = self.const_tensor_3d.val(); output } } From 34ff5c40ad2c1d0fa1f635346d9af9e25a0d0e47 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Mon, 16 Dec 2024 14:52:17 +0000 Subject: [PATCH 04/17] Rework tensor tests --- crates/burn-import/src/burn/node/constant.rs | 120 +++++++++++++------ 1 file changed, 83 insertions(+), 37 deletions(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index cba0e1df3c..d4522ea818 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -217,6 +217,14 @@ mod tests { } } + fn shape_to_tokens(shape: &[usize]) -> TokenStream { + let dims = shape.iter().map(|d| { + let lit = proc_macro2::Literal::usize_unsuffixed(*d); + quote! { #lit } + }); + quote! 
{ [#(#dims),*] } + } + fn test_codegen_constant_scalar(value: ConstantValue, scalar_kind: ScalarKind) { let mut graph = BurnGraph::::default(); let val = value.val_tokens(); @@ -262,14 +270,6 @@ mod tests { test_codegen_constant_scalar(ConstantValue::Bool(true), ScalarKind::Bool); } - fn shape_to_tokens(shape: &[usize]) -> TokenStream { - let dims = shape.iter().map(|d| { - let lit = proc_macro2::Literal::usize_unsuffixed(*d); - quote! { #lit } - }); - quote! { [#(#dims),*] } - } - #[test] fn test_codegen_constant_tensor_float() { let mut graph = BurnGraph::::default(); @@ -342,15 +342,32 @@ mod tests { fn test_codegen_constant_tensor_int() { let mut graph = BurnGraph::::default(); - let tensor_type = TensorType::new_int_with_shape("const_tensor_int", 1, Some(vec![3])); + let const_tensor = Ident::new("const_tensor_int", Span::call_site()); + let dimensions = 1; + let shape = vec![3]; let data = TensorData::from([1i32, 2i32, 3i32]); + let tensor_type = TensorType::new_int_with_shape( + const_tensor.to_string(), + dimensions, + Some(shape.clone()), + ); + let value = ConstantValue::Tensor(tensor_type.clone(), data); graph.register(ConstantNode::new( - "const_tensor_int".to_owned(), - ConstantValue::Tensor(tensor_type.clone(), data), - Type::Tensor(TensorType::new_int_with_shape("output", 1, Some(vec![3]))), + const_tensor.to_string(), + value.clone(), + Type::Tensor(TensorType::new_int_with_shape( + "output", + dimensions, + Some(shape.clone()), + )), )); + let con = const_tensor.to_token_stream(); + let ty = value.ty_tokens(); + let tensor_ty = value.tensor_ty_tokens(); + let shp = shape_to_tokens(&shape); + graph.register_input_output(vec![], vec!["output".to_string()]); let expected = quote! { @@ -362,7 +379,7 @@ mod tests { #[derive(Module, Debug)] pub struct Model { - const_tensor_int: burn::module::Param>, + #con: #ty, phantom: core::marker::PhantomData, device: burn::module::Ignored, } @@ -370,18 +387,18 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let const_tensor_int: burn::module::Param> = burn::nn::Initializer::Zeros.init([3], device).set_require_grad(false); + let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false); Self { - const_tensor_int, + #con, phantom: core::marker::PhantomData, device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> Tensor { - let output = self.const_tensor_int.val(); + pub fn forward(&self) -> #tensor_ty { + let output = self.#con.val(); output } } @@ -394,15 +411,32 @@ mod tests { fn test_codegen_constant_tensor_bool() { let mut graph = BurnGraph::::default(); - let tensor_type = TensorType::new_bool_with_shape("const_tensor_bool", 1, Some(vec![2])); + let const_tensor = Ident::new("const_tensor_bool", Span::call_site()); + let dimensions = 1; + let shape = vec![2]; let data = TensorData::from([true, false]); + let tensor_type = TensorType::new_bool_with_shape( + const_tensor.to_string(), + dimensions, + Some(shape.clone()), + ); + let value = ConstantValue::Tensor(tensor_type.clone(), data); graph.register(ConstantNode::new( - "const_tensor_bool".to_owned(), - ConstantValue::Tensor(tensor_type.clone(), data), - Type::Tensor(TensorType::new_bool_with_shape("output", 1, Some(vec![2]))), + const_tensor.to_string(), + value.clone(), + Type::Tensor(TensorType::new_bool_with_shape( + "output", + dimensions, + Some(shape.clone()), + )), )); + let con = const_tensor.to_token_stream(); + let ty = 
value.ty_tokens(); + let tensor_ty = value.tensor_ty_tokens(); + let shp = shape_to_tokens(&shape); + graph.register_input_output(vec![], vec!["output".to_string()]); let expected = quote! { @@ -414,7 +448,7 @@ mod tests { #[derive(Module, Debug)] pub struct Model { - const_tensor_bool: burn::module::Param>, + #con: #ty, phantom: core::marker::PhantomData, device: burn::module::Ignored, } @@ -422,18 +456,18 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let const_tensor_bool: burn::module::Param> = burn::nn::Initializer::Zeros.init([2], device).set_require_grad(false); + let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false); Self { - const_tensor_bool, + #con, phantom: core::marker::PhantomData, device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> Tensor { - let output = self.const_tensor_bool.val(); + pub fn forward(&self) -> #tensor_ty { + let output = self.#con.val(); output } } @@ -446,20 +480,32 @@ mod tests { fn test_codegen_constant_tensor_3d() { let mut graph = BurnGraph::::default(); - let tensor_type = - TensorType::new_bool_with_shape("const_tensor_3d", 3, Some(vec![2, 2, 2])); + let const_tensor = Ident::new("const_tensor_3d", Span::call_site()); + let dimensions = 3; + let shape = vec![2, 2, 2]; let data = TensorData::from([[[true, false], [true, false], [true, false]]]); + let tensor_type = TensorType::new_bool_with_shape( + const_tensor.to_string(), + dimensions, + Some(shape.clone()), + ); + let value = ConstantValue::Tensor(tensor_type.clone(), data); graph.register(ConstantNode::new( - "const_tensor_3d".to_owned(), - ConstantValue::Tensor(tensor_type.clone(), data), + const_tensor.to_string(), + value.clone(), Type::Tensor(TensorType::new_bool_with_shape( "output", - 3, - Some(vec![2, 2, 2]), + dimensions, + Some(shape.clone()), )), )); + let con = const_tensor.to_token_stream(); + let ty = value.ty_tokens(); + let tensor_ty = value.tensor_ty_tokens(); + let shp = shape_to_tokens(&shape); + graph.register_input_output(vec![], vec!["output".to_string()]); let expected = quote! 
{ @@ -471,7 +517,7 @@ mod tests { #[derive(Module, Debug)] pub struct Model { - const_tensor_3d: burn::module::Param>, + #con: #ty, phantom: core::marker::PhantomData, device: burn::module::Ignored, } @@ -479,18 +525,18 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let const_tensor_3d: burn::module::Param> = burn::nn::Initializer::Zeros.init([2, 2, 2], device).set_require_grad(false); + let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false); Self { - const_tensor_3d, + #con, phantom: core::marker::PhantomData, device: burn::module::Ignored(device.clone()), } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self) -> Tensor { - let output = self.const_tensor_3d.val(); + pub fn forward(&self) -> #tensor_ty { + let output = self.#con.val(); output } } From b6d6be4faf74487606239bfdf2c470f1a7395d53 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Mon, 16 Dec 2024 15:02:32 +0000 Subject: [PATCH 05/17] Renames --- crates/burn-import/src/burn/node/constant.rs | 84 +++++++++++--------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index d4522ea818..0d2d616e65 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -186,7 +186,11 @@ mod tests { use burn::tensor::TensorData; use quote::ToTokens; - fn expected_scalar_snippet(ty: TokenStream, val: TokenStream) -> TokenStream { + fn expected_tokens_constant_scalar( + ty: TokenStream, + val: TokenStream, + output: TokenStream, + ) -> TokenStream { quote! { use burn::{ module::Module, @@ -210,46 +214,39 @@ mod tests { #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self) -> #ty { - let output: #ty = #val; - output + let #output: #ty = #val; + #output } } } } - fn shape_to_tokens(shape: &[usize]) -> TokenStream { - let dims = shape.iter().map(|d| { - let lit = proc_macro2::Literal::usize_unsuffixed(*d); - quote! { #lit } - }); - quote! 
{ [#(#dims),*] } - } - - fn test_codegen_constant_scalar(value: ConstantValue, scalar_kind: ScalarKind) { + fn assert_codegen_constant_scalar(constant: ConstantValue, scalar_kind: ScalarKind) { let mut graph = BurnGraph::::default(); - let val = value.val_tokens(); - let ty = value.ty_tokens(); + let val = constant.val_tokens(); + let ty = constant.ty_tokens(); + let output = Ident::new("output", Span::call_site()); graph.register(ConstantNode::new( "constant_scalar".to_owned(), - value, - Type::Scalar(ScalarType::new("output", scalar_kind)), + constant, + Type::Scalar(ScalarType::new(output.to_string(), scalar_kind)), )); - graph.register_input_output(vec![], vec!["output".to_string()]); + graph.register_input_output(vec![], vec![output.to_string()]); - let expected = expected_scalar_snippet(ty, val); + let expected = expected_tokens_constant_scalar(ty, val, output.to_token_stream()); assert_tokens(graph.codegen(), expected); } #[test] fn test_codegen_constant_scalar_float32() { - test_codegen_constant_scalar(ConstantValue::Float32(3.14f32), ScalarKind::Float32); + assert_codegen_constant_scalar(ConstantValue::Float32(3.14f32), ScalarKind::Float32); } #[test] fn test_codegen_constant_scalar_float64() { - test_codegen_constant_scalar( + assert_codegen_constant_scalar( ConstantValue::Float64(std::f64::consts::PI), ScalarKind::Float64, ); @@ -257,17 +254,26 @@ mod tests { #[test] fn test_codegen_constant_scalar_int32() { - test_codegen_constant_scalar(ConstantValue::Int32(123i32), ScalarKind::Int32); + assert_codegen_constant_scalar(ConstantValue::Int32(123i32), ScalarKind::Int32); } #[test] fn test_codegen_constant_scalar_int64() { - test_codegen_constant_scalar(ConstantValue::Int64(42i64), ScalarKind::Int64); + assert_codegen_constant_scalar(ConstantValue::Int64(42i64), ScalarKind::Int64); } #[test] fn test_codegen_constant_scalar_bool() { - test_codegen_constant_scalar(ConstantValue::Bool(true), ScalarKind::Bool); + assert_codegen_constant_scalar(ConstantValue::Bool(true), ScalarKind::Bool); + assert_codegen_constant_scalar(ConstantValue::Bool(false), ScalarKind::Bool); + } + + fn shape_to_tokens(shape: &[usize]) -> TokenStream { + let dims = shape.iter().map(|d| { + let lit = proc_macro2::Literal::usize_unsuffixed(*d); + quote! { #lit } + }); + quote! 
{ [#(#dims),*] } } #[test] @@ -283,11 +289,11 @@ mod tests { dimensions, Some(shape.clone()), ); - let value = ConstantValue::Tensor(tensor_type.clone(), data); + let constant = ConstantValue::Tensor(tensor_type.clone(), data); graph.register(ConstantNode::new( const_tensor.to_string(), - value.clone(), + constant.clone(), Type::Tensor(TensorType::new_float_with_shape( "output", dimensions, @@ -296,8 +302,8 @@ mod tests { )); let con = const_tensor.to_token_stream(); - let ty = value.ty_tokens(); - let tensor_ty = value.tensor_ty_tokens(); + let ty = constant.ty_tokens(); + let tensor_ty = constant.tensor_ty_tokens(); let shp = shape_to_tokens(&shape); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -351,11 +357,11 @@ mod tests { dimensions, Some(shape.clone()), ); - let value = ConstantValue::Tensor(tensor_type.clone(), data); + let constant = ConstantValue::Tensor(tensor_type.clone(), data); graph.register(ConstantNode::new( const_tensor.to_string(), - value.clone(), + constant.clone(), Type::Tensor(TensorType::new_int_with_shape( "output", dimensions, @@ -364,8 +370,8 @@ mod tests { )); let con = const_tensor.to_token_stream(); - let ty = value.ty_tokens(); - let tensor_ty = value.tensor_ty_tokens(); + let ty = constant.ty_tokens(); + let tensor_ty = constant.tensor_ty_tokens(); let shp = shape_to_tokens(&shape); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -420,11 +426,11 @@ mod tests { dimensions, Some(shape.clone()), ); - let value = ConstantValue::Tensor(tensor_type.clone(), data); + let constant = ConstantValue::Tensor(tensor_type.clone(), data); graph.register(ConstantNode::new( const_tensor.to_string(), - value.clone(), + constant.clone(), Type::Tensor(TensorType::new_bool_with_shape( "output", dimensions, @@ -433,8 +439,8 @@ mod tests { )); let con = const_tensor.to_token_stream(); - let ty = value.ty_tokens(); - let tensor_ty = value.tensor_ty_tokens(); + let ty = constant.ty_tokens(); + let tensor_ty = constant.tensor_ty_tokens(); let shp = shape_to_tokens(&shape); graph.register_input_output(vec![], vec!["output".to_string()]); @@ -489,11 +495,11 @@ mod tests { dimensions, Some(shape.clone()), ); - let value = ConstantValue::Tensor(tensor_type.clone(), data); + let constant = ConstantValue::Tensor(tensor_type.clone(), data); graph.register(ConstantNode::new( const_tensor.to_string(), - value.clone(), + constant.clone(), Type::Tensor(TensorType::new_bool_with_shape( "output", dimensions, @@ -502,8 +508,8 @@ mod tests { )); let con = const_tensor.to_token_stream(); - let ty = value.ty_tokens(); - let tensor_ty = value.tensor_ty_tokens(); + let ty = constant.ty_tokens(); + let tensor_ty = constant.tensor_ty_tokens(); let shp = shape_to_tokens(&shape); graph.register_input_output(vec![], vec!["output".to_string()]); From 330b376b01ea6844d77b69c45e21a481bba0eee2 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Mon, 16 Dec 2024 15:28:58 +0000 Subject: [PATCH 06/17] Don't use const pi --- crates/burn-import/src/burn/node/constant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index 0d2d616e65..d30e9a8a95 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -247,7 +247,7 @@ mod tests { #[test] fn test_codegen_constant_scalar_float64() { assert_codegen_constant_scalar( - ConstantValue::Float64(std::f64::consts::PI), + 
ConstantValue::Float64(3.14159265358979323846264338327950288f64), ScalarKind::Float64, ); } From 42d26a5dee9394c962bbd75d366822f6a3a0dea7 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Tue, 17 Dec 2024 13:03:44 +0000 Subject: [PATCH 07/17] cargo clippy --fix --- crates/burn-import/src/burn/node/constant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index d30e9a8a95..aed1c483ac 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -247,7 +247,7 @@ mod tests { #[test] fn test_codegen_constant_scalar_float64() { assert_codegen_constant_scalar( - ConstantValue::Float64(3.14159265358979323846264338327950288f64), + ConstantValue::Float64(3.141_592_653_589_793_f64), ScalarKind::Float64, ); } From f01d56995ba07e816ffd3eee2e6f54ecbf2fd6e4 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Tue, 17 Dec 2024 13:08:12 +0000 Subject: [PATCH 08/17] Don't use approx pi in example to avoid intellij warning --- crates/burn-import/src/burn/node/constant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index aed1c483ac..ce88e74f68 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -247,7 +247,7 @@ mod tests { #[test] fn test_codegen_constant_scalar_float64() { assert_codegen_constant_scalar( - ConstantValue::Float64(3.141_592_653_589_793_f64), + ConstantValue::Float64(3.111_222_333_444_555_f64), ScalarKind::Float64, ); } From 7c775d12790fe4af52f05ed54fd56638dfed0684 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Tue, 17 Dec 2024 13:11:42 +0000 Subject: [PATCH 09/17] Add shape_to_tokens docstring --- crates/burn-import/src/burn/node/constant.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index ce88e74f68..10c55f4407 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -268,6 +268,7 @@ mod tests { assert_codegen_constant_scalar(ConstantValue::Bool(false), ScalarKind::Bool); } + /// Transforms &[1usize, 2usize, 3usize] into literal tokens [1, 2, 3]. fn shape_to_tokens(shape: &[usize]) -> TokenStream { let dims = shape.iter().map(|d| { let lit = proc_macro2::Literal::usize_unsuffixed(*d); From 3bf1a949a1d7b0e018189078dcf70d2fc59b8224 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Tue, 17 Dec 2024 13:12:11 +0000 Subject: [PATCH 10/17] Update shape_to_tokens docstring --- crates/burn-import/src/burn/node/constant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index 10c55f4407..cac7e26ee5 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -268,7 +268,7 @@ mod tests { assert_codegen_constant_scalar(ConstantValue::Bool(false), ScalarKind::Bool); } - /// Transforms &[1usize, 2usize, 3usize] into literal tokens [1, 2, 3]. + /// Transforms e.g. `&[1usize, 2usize, 3usize]` into literal tokens [1, 2, 3]. 
     fn shape_to_tokens(shape: &[usize]) -> TokenStream {
         let dims = shape.iter().map(|d| {
             let lit = proc_macro2::Literal::usize_unsuffixed(*d);

From ac47449bdbe84ac91dcf3e7a272d2def4263a6e3 Mon Sep 17 00:00:00 2001
From: James Hiew
Date: Tue, 17 Dec 2024 13:16:45 +0000
Subject: [PATCH 11/17] Fix test_codegen_constant_tensor_3d shape

---
 crates/burn-import/src/burn/node/constant.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs
index cac7e26ee5..572120d2c6 100644
--- a/crates/burn-import/src/burn/node/constant.rs
+++ b/crates/burn-import/src/burn/node/constant.rs
@@ -489,7 +489,7 @@ mod tests {
 
         let const_tensor = Ident::new("const_tensor_3d", Span::call_site());
         let dimensions = 3;
-        let shape = vec![2, 2, 2];
+        let shape = vec![1, 3, 2];
         let data = TensorData::from([[[true, false], [true, false], [true, false]]]);
         let tensor_type = TensorType::new_bool_with_shape(
             const_tensor.to_string(),

From 27f1ae080af413e21358c71a90550e60e8b3f018 Mon Sep 17 00:00:00 2001
From: James Hiew
Date: Tue, 17 Dec 2024 13:42:28 +0000
Subject: [PATCH 12/17] Add test ONNX

---
 crates/burn-import/onnx-tests/build.rs        |  1 +
 .../onnx-tests/tests/constant/constant.onnx   | Bin 0 -> 203 bytes
 .../onnx-tests/tests/constant/constant.py     | 36 ++++++++++++++++++
 .../burn-import/onnx-tests/tests/test_onnx.rs | 11 ++++++
 4 files changed, 48 insertions(+)
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant.onnx
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant.py

diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index f265f3944f..ae258f307a 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -16,6 +16,7 @@ fn main() {
         .input("tests/clip/clip_opset16.onnx")
         .input("tests/clip/clip_opset7.onnx")
         .input("tests/concat/concat.onnx")
+        .input("tests/constant/constant.onnx")
         .input("tests/constant_of_shape/constant_of_shape.onnx")
         .input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
         .input("tests/conv1d/conv1d.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant.onnx b/crates/burn-import/onnx-tests/tests/constant/constant.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..237292fc19f98e47f2e8c2cda6150b52490952ab
GIT binary patch
literal 203
zcmd... [base85 payload lost in extraction]

diff --git a/crates/burn-import/onnx-tests/tests/constant/constant.py b/crates/burn-import/onnx-tests/tests/constant/constant.py
new file mode 100644
index 0000000000..3ce4efbfa5
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/constant/constant.py
@@ -0,0 +1,36 @@
[36 added lines lost in extraction; the script is rewritten wholesale in PATCH 13/17 below]
diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs
--- a/crates/burn-import/onnx-tests/tests/test_onnx.rs
+++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -25,6 +25,7 @@ include_models!(
     clip_opset16,
     clip_opset7,
     concat,
+    constant,
     constant_of_shape,
     constant_of_shape_full_like,
     conv1d,
@@ -2184,5 +2185,15 @@ mod tests {
     }
 
+    #[test]
+    fn constant() {
+        let device = Default::default();
+        let model = constant::Model::<Backend>::new(&device);
+        let input = TensorData::zeros::<f32, _>(Shape::from([2, 3]));
+        let expected_output = TensorData::full::<f32, _>(Shape::from([2, 3]), 2.0);
+        let output = model.forward(input.into());
+        assert_eq!(expected_output, output.to_data());
+    }
+
     #[test]
     fn constant_of_shape() {
         // This tests shape is being passed directly to the model
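Note on PATCH 12/17: the new .input("tests/constant/constant.onnx") line is what wires the model into the test crate. onnx-tests' build.rs feeds every registered ONNX file through burn-import's ModelGen at build time, and the include_models! entry added to test_onnx.rs then pulls in the Rust module generated for it. A minimal sketch of that build step, assuming the same ModelGen builder the existing .input(...) chain in build.rs already uses:

    // build.rs (sketch): generate a Rust module from an ONNX file at build time.
    use burn_import::onnx::ModelGen;

    fn main() {
        ModelGen::new()
            // Path relative to the crate root, mirroring the .input(...) calls above.
            .input("tests/constant/constant.onnx")
            // Assumed output directory for the generated module (lands under OUT_DIR).
            .out_dir("model/")
            .run_from_script();
    }

The generated constant::Model::<B> is then constructed from a backend device and exercised like any hand-written burn module, which is exactly what the new fn constant() test does.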
From 8dad35f2e9e0411ade898a8900f98ed1fd2c1309 Mon Sep 17 00:00:00 2001
From: James Hiew
Date: Tue, 17 Dec 2024 14:36:09 +0000
Subject: [PATCH 13/17] Test constant scalar types apart from bool

---
 crates/burn-import/onnx-tests/build.rs        |  5 +-
 .../onnx-tests/tests/constant/constant.py     | 67 ++++++++++++------
 .../{constant.onnx => constant_f32.onnx}      | Bin 203 -> 211 bytes
 .../tests/constant/constant_f64.onnx          | Bin 0 -> 215 bytes
 .../tests/constant/constant_i32.onnx          | Bin 0 -> 211 bytes
 .../tests/constant/constant_i64.onnx          | Bin 0 -> 215 bytes
 .../burn-import/onnx-tests/tests/test_onnx.rs | 44 ++++++++++--
 7 files changed, 90 insertions(+), 26 deletions(-)
 rename crates/burn-import/onnx-tests/tests/constant/{constant.onnx => constant_f32.onnx} (61%)
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_f64.onnx
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_i32.onnx
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_i64.onnx

diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index ae258f307a..acdbc466ef 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -16,7 +16,10 @@ fn main() {
         .input("tests/clip/clip_opset16.onnx")
         .input("tests/clip/clip_opset7.onnx")
         .input("tests/concat/concat.onnx")
-        .input("tests/constant/constant.onnx")
+        .input("tests/constant/constant_f32.onnx")
+        .input("tests/constant/constant_f64.onnx")
+        .input("tests/constant/constant_i32.onnx")
+        .input("tests/constant/constant_i64.onnx")
         .input("tests/constant_of_shape/constant_of_shape.onnx")
         .input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
         .input("tests/conv1d/conv1d.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant.py b/crates/burn-import/onnx-tests/tests/constant/constant.py
index 3ce4efbfa5..af283ce33e 100644
--- a/crates/burn-import/onnx-tests/tests/constant/constant.py
+++ b/crates/burn-import/onnx-tests/tests/constant/constant.py
@@ -3,34 +3,61 @@
 import torch
 import torch.nn as nn
 
[removed lines of the PATCH 12/17 script lost in extraction; additions below]
+CONST_VALUE = 2
+
+
+class ConstantModel(nn.Module):
+    def __init__(self, const_dtype: torch.dtype):
+        super().__init__()
+        self.const = CONST_VALUE
+
+    def forward(self, x):
+        return x + self.const
+
+
+def export_model(model: ConstantModel, dummy_input: torch.Tensor, file_name: str):
+    model.eval()
+    torch.onnx.export(
+        model,
+        dummy_input,
+        file_name,
+        verbose=False,
+        opset_version=16,
+        do_constant_folding=False,
+    )
+    print(f"Finished exporting model to {file_name}")
+
+    # Output some test data for demonstration
+    test_input = dummy_input.clone()
+    print(dummy_input.dtype, "test input:", test_input)
+    output = model.forward(test_input)
+    print(dummy_input.dtype, "test output:", output)
+    print("")
+
+
+def main():
+    device = torch.device("cpu")
+    shape = (2, 3, 4)
+
+    model_f32 = ConstantModel(torch.float32)
+    f32_input = torch.randn(shape, dtype=torch.float32, device=device)
+    export_model(model_f32, f32_input, "constant_f32.onnx")
+
+    model_f64 = ConstantModel(torch.float64)
+    f64_input = torch.randn(shape, dtype=torch.float64, device=device)
+    export_model(model_f64, f64_input, "constant_f64.onnx")
+
+    model_i32 = ConstantModel(torch.int32)
+    i32_input = torch.randint(
+        low=-10, high=10, size=shape, device=device, dtype=torch.int32
+    )
+    export_model(model_i32, i32_input, "constant_i32.onnx")
+
+    model_i64 = ConstantModel(torch.int64)
+    i64_input = torch.randint(
+        low=-10, high=10, size=shape, device=device, dtype=torch.int64
+    )
+    export_model(model_i64, i64_input, "constant_i64.onnx")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_f32.onnx
similarity index 61%
rename from crates/burn-import/onnx-tests/tests/constant/constant.onnx
rename to crates/burn-import/onnx-tests/tests/constant/constant_f32.onnx
index 237292fc19f98e47f2e8c2cda6150b52490952ab..77e64f21ee3c78da8aeae5a75cd867c1c5a3c7cb 100644
GIT binary patch
delta 92
zcmX@jc$razgF}eDpt2;tC^ZaV|z9Y!XgP0s;V`!Vbd#

delta 84
zcmcc2c$!g!gF}eDpt2;tC^
YAr3Ak4rU-`NfPE_G(s12ViFJl04l@|X8-^I

diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_f64.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_f64.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..ed3ed7fe69e880a0fc415cd1c6fc8386ab465646
GIT binary patch
literal 215
zcmd8A40ss@=F!=xg

literal 0
HcmV?d00001

diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_i32.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_i32.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..48791d1937e01a2aba8be71109c9f07db7258291
GIT binary patch
literal 211
zcmdVXciLU65wDH;^AWAU;<)hAZAGt=VCO%CgH>+AOHY}sxX)U

literal 0
HcmV?d00001

diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_i64.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_i64.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..49e8ade9a189cfdb6284349535b4d5e9fb91a197
GIT binary patch
literal 215
zcmd8A40s!%NFs}dr

literal 0
HcmV?d00001

diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs
index 1d4e4fb719..06ca7ee600 100644
--- a/crates/burn-import/onnx-tests/tests/test_onnx.rs
+++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -25,7 +25,10 @@ include_models!(
     clip_opset16,
     clip_opset7,
     concat,
-    constant,
+    constant_f32,
+    constant_f64,
+    constant_i32,
+    constant_i64,
     constant_of_shape,
     constant_of_shape_full_like,
     conv1d,
@@ -136,6 +139,7 @@ mod tests {
 
     use burn::tensor::{Bool, Int, Shape, Tensor, TensorData};
 
+    use burn_ndarray::NdArray;
     use float_cmp::ApproxEq;
 
     type Backend = burn_ndarray::NdArray<f32>;
@@ -2185,11 +2189,41 @@ mod tests {
     }
 
     #[test]
-    fn constant() {
+    fn constant_f32() {
         let device = Default::default();
-        let model = constant::Model::<Backend>::new(&device);
-        let input = TensorData::zeros::<f32, _>(Shape::from([2, 3]));
-        let expected_output = TensorData::full::<f32, _>(Shape::from([2, 3]), 2.0);
+        let model = constant_f32::Model::<Backend>::new(&device);
+        let input = TensorData::zeros::<f32, _>(Shape::from([2, 3, 4]));
+        let expected_output = TensorData::full::<f32, _>(Shape::from([2, 3, 4]), 2f32);
+        let output = model.forward(input.into());
+        assert_eq!(expected_output, output.to_data());
+    }
+
+    #[test]
+    fn constant_f64() {
+        let device = Default::default();
+        let model = constant_f64::Model::<Backend>::new(&device);
+        let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
+        let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32);
+        let output = model.forward(input.into());
+        assert_eq!(expected_output, output.to_data());
+    }
+
+    #[test]
+    fn constant_i32() {
+        let device = Default::default();
+        let model = constant_i32::Model::<Backend>::new(&device);
+        let input = TensorData::zeros::<i32, _>(Shape::from([2, 3, 4]));
+        let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2i64);
+        let output = model.forward(input.into());
+        assert_eq!(expected_output, output.to_data());
+    }
+
+    #[test]
+    fn constant_i64() {
+        let device = Default::default();
+        let model = constant_i64::Model::<Backend>::new(&device);
+        let input = TensorData::zeros::<i64, _>(Shape::from([2, 3, 4]));
+        let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2i64);
         let output = model.forward(input.into());
         assert_eq!(expected_output, output.to_data());
     }
constant_i64() { + let device = Default::default(); + let model = constant_i64::Model::::new(&device); + let input = TensorData::zeros::(Shape::from([2, 3, 4])); + let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2i64); let output = model.forward(input.into()); assert_eq!(expected_output, output.to_data()); } From 7e82c377c95bed70cf83d3fd3e4c82cbfc689a75 Mon Sep 17 00:00:00 2001 From: James Hiew Date: Tue, 17 Dec 2024 15:25:40 +0000 Subject: [PATCH 14/17] WIP constant tensor --- crates/burn-import/onnx-tests/build.rs | 2 + .../onnx-tests/tests/constant/constant.py | 2 +- .../tests/constant/constant_tensor.py | 54 ++++++++++++++++++ .../tests/constant/constant_tensor_f32.onnx | Bin 0 -> 219 bytes .../tests/constant/constant_tensor_f64.onnx | Bin 0 -> 235 bytes .../burn-import/onnx-tests/tests/test_onnx.rs | 12 ++++ 6 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor.py create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor_f32.onnx create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor_f64.onnx diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index acdbc466ef..dd996969fc 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -20,6 +20,8 @@ fn main() { .input("tests/constant/constant_f64.onnx") .input("tests/constant/constant_i32.onnx") .input("tests/constant/constant_i64.onnx") + .input("tests/constant/constant_tensor_f32.onnx") + .input("tests/constant/constant_tensor_f64.onnx") .input("tests/constant_of_shape/constant_of_shape.onnx") .input("tests/constant_of_shape/constant_of_shape_full_like.onnx") .input("tests/conv1d/conv1d.onnx") diff --git a/crates/burn-import/onnx-tests/tests/constant/constant.py b/crates/burn-import/onnx-tests/tests/constant/constant.py index af283ce33e..d6f1a8b0ea 100644 --- a/crates/burn-import/onnx-tests/tests/constant/constant.py +++ b/crates/burn-import/onnx-tests/tests/constant/constant.py @@ -9,7 +9,7 @@ class ConstantModel(nn.Module): def __init__(self, const_dtype: torch.dtype): super().__init__() - self.const = CONST_VALUE + self.const = torch.tensor(CONST_VALUE).to(const_dtype) def forward(self, x): return x + self.const diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py b/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py new file mode 100644 index 0000000000..c92ef430cb --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn + +CONST_VALUE = torch.tensor([[2, 2], + [2, 2]]) +CONST_SHAPE = CONST_VALUE.shape + +class ConstantTensorModel(nn.Module): + def __init__(self, const_dtype: torch.dtype): + super().__init__() + self.const_tensor = CONST_VALUE.to(const_dtype) + + def forward(self, x): + return self.const_tensor + x + + +def export_model(model: nn.Module, dummy_input: torch.Tensor, file_name: str): + model.eval() + torch.onnx.export( + model, + dummy_input, + file_name, + verbose=False, + opset_version=16, + do_constant_folding=False, + ) + print(f"Finished exporting model to {file_name}") + + # Output some test data for demonstration + test_input = dummy_input.clone() + print(dummy_input.dtype, "test input:", test_input) + output = model.forward(test_input) + print(dummy_input.dtype, "test output:", output) + print("") + + +def main(): + device = torch.device("cpu") + + # Export 
---
 crates/burn-import/onnx-tests/build.rs        |  2 +
 .../onnx-tests/tests/constant/constant.py     |  2 +-
 .../tests/constant/constant_tensor.py         | 54 ++++++++++++++++++
 .../tests/constant/constant_tensor_f32.onnx   | Bin 0 -> 219 bytes
 .../tests/constant/constant_tensor_f64.onnx   | Bin 0 -> 235 bytes
 .../burn-import/onnx-tests/tests/test_onnx.rs | 12 ++++
 6 files changed, 69 insertions(+), 1 deletion(-)
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor.py
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor_f32.onnx
 create mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor_f64.onnx

diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index acdbc466ef..dd996969fc 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ ... @@ fn main() {
         .input("tests/constant/constant_f64.onnx")
         .input("tests/constant/constant_i32.onnx")
         .input("tests/constant/constant_i64.onnx")
+        .input("tests/constant/constant_tensor_f32.onnx")
+        .input("tests/constant/constant_tensor_f64.onnx")
         .input("tests/constant_of_shape/constant_of_shape.onnx")
         .input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
         .input("tests/conv1d/conv1d.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant.py b/crates/burn-import/onnx-tests/tests/constant/constant.py
index af283ce33e..d6f1a8b0ea 100644
--- a/crates/burn-import/onnx-tests/tests/constant/constant.py
+++ b/crates/burn-import/onnx-tests/tests/constant/constant.py
@@ ... @@ class ConstantModel(nn.Module):
     def __init__(self, const_dtype: torch.dtype):
         super().__init__()
-        self.const = CONST_VALUE
+        self.const = torch.tensor(CONST_VALUE).to(const_dtype)

     def forward(self, x):
         return x + self.const
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py b/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py
new file mode 100644
index 0000000000..c92ef430cb
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+
+import torch
+import torch.nn as nn
+
+CONST_VALUE = torch.tensor([[2, 2],
+                            [2, 2]])
+CONST_SHAPE = CONST_VALUE.shape
+
+class ConstantTensorModel(nn.Module):
+    def __init__(self, const_dtype: torch.dtype):
+        super().__init__()
+        self.const_tensor = CONST_VALUE.to(const_dtype)
+
+    def forward(self, x):
+        return self.const_tensor + x
+
+
+def export_model(model: nn.Module, dummy_input: torch.Tensor, file_name: str):
+    model.eval()
+    torch.onnx.export(
+        model,
+        dummy_input,
+        file_name,
+        verbose=False,
+        opset_version=16,
+        do_constant_folding=False,
+    )
+    print(f"Finished exporting model to {file_name}")
+
+    # Output some test data for demonstration
+    test_input = dummy_input.clone()
+    print(dummy_input.dtype, "test input:", test_input)
+    output = model.forward(test_input)
+    print(dummy_input.dtype, "test output:", output)
+    print("")
+
+
+def main():
+    device = torch.device("cpu")
+
+    # Export with a float32 tensor constant
+    model_f32 = ConstantTensorModel(torch.float32)
+    f32_input = torch.randn(CONST_SHAPE, dtype=torch.float32, device=device)
+    export_model(model_f32, f32_input, "constant_tensor_f32.onnx")
+
+    # Export with a float64 tensor constant
+    model_f64 = ConstantTensorModel(torch.float64)
+    f64_input = torch.randn(CONST_SHAPE, dtype=torch.float64, device=device)
+    export_model(model_f64, f64_input, "constant_tensor_f64.onnx")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f32.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f32.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..e908564881dcd185d0df93cd07a065023dc404e8
GIT binary patch
literal 219
zcmdN&2{3vIFfcGUAmas$EL{3rSPbOi&dn1O^4S8ifvUVM5{VnIffEUH32E*=g>Ar3B}bs(H1%*ALVgeK_3Bp?6)GMq7u

literal 0
HcmV?d00001

diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f64.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f64.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..2e46918bcde30030769f1ee048f3911a5899033f
GIT binary patch
literal 235
zcmdN&32=KUFhGC%S@
z^npwzW*{NNm7AEE7oT2~SdbAVi>i>1i-&_-h=U7gJqRZWb1@nTp$R%M2?ziH>(?>+

literal 0
HcmV?d00001

diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs
index 06ca7ee600..f5854f51af 100644
--- a/crates/burn-import/onnx-tests/tests/test_onnx.rs
+++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ ... @@ include_models!(
     constant_f64,
     constant_i32,
     constant_i64,
+    constant_tensor_f32,
+    constant_tensor_f64,
     constant_of_shape,
     constant_of_shape_full_like,
     conv1d,
@@ ... @@
         assert_eq!(expected_output, output.to_data());
     }

+    #[test]
+    fn constant_tensor_f32() {
+        let device = Default::default();
+        let model = constant_tensor_f32::Model::<Backend>::new(&device);
+        let input = TensorData::zeros::<f32, _>(Shape::from([2, 2]));
+        let expected_output = TensorData::full::<f32, _>(Shape::from([2, 2]), 2f32);
+        let output = model.forward(input.into());
+        assert_eq!(expected_output, output.to_data());
+    }
+
     #[test]
     fn constant_of_shape() {
         // This tests shape is being passed directly to the model

From a07d099c64381ed60dd188af7b48e515208d4527 Mon Sep 17 00:00:00 2001
From: James Hiew
Date: Tue, 17 Dec 2024 15:27:47 +0000
Subject: [PATCH 15/17] Remove tensor stuff
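
Backs out the WIP tensor-constant fixtures from the previous commit
(the exporter script, both .onnx files, and their build.rs and
test_onnx.rs entries), keeping only the constant.py dtype fix. A quick
way to confirm no stale references remain (an illustrative sketch, run
from crates/burn-import/onnx-tests, not part of this change):

    from pathlib import Path

    # Neither the build script nor the test harness should still
    # mention the removed constant_tensor fixtures.
    for path in ("build.rs", "tests/test_onnx.rs"):
        assert "constant_tensor" not in Path(path).read_text(), path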
---
 crates/burn-import/onnx-tests/build.rs        |  2 -
 .../tests/constant/constant_tensor.py         | 54 ------------------
 .../tests/constant/constant_tensor_f32.onnx   | Bin 219 -> 0 bytes
 .../tests/constant/constant_tensor_f64.onnx   | Bin 235 -> 0 bytes
 .../burn-import/onnx-tests/tests/test_onnx.rs | 12 ----
 5 files changed, 68 deletions(-)
 delete mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor.py
 delete mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor_f32.onnx
 delete mode 100644 crates/burn-import/onnx-tests/tests/constant/constant_tensor_f64.onnx

diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index dd996969fc..acdbc466ef 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ ... @@ fn main() {
         .input("tests/constant/constant_f64.onnx")
         .input("tests/constant/constant_i32.onnx")
         .input("tests/constant/constant_i64.onnx")
-        .input("tests/constant/constant_tensor_f32.onnx")
-        .input("tests/constant/constant_tensor_f64.onnx")
         .input("tests/constant_of_shape/constant_of_shape.onnx")
         .input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
         .input("tests/conv1d/conv1d.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py b/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py
deleted file mode 100644
index c92ef430cb..0000000000
--- a/crates/burn-import/onnx-tests/tests/constant/constant_tensor.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env python3
-
-import torch
-import torch.nn as nn
-
-CONST_VALUE = torch.tensor([[2, 2],
-                            [2, 2]])
-CONST_SHAPE = CONST_VALUE.shape
-
-class ConstantTensorModel(nn.Module):
-    def __init__(self, const_dtype: torch.dtype):
-        super().__init__()
-        self.const_tensor = CONST_VALUE.to(const_dtype)
-
-    def forward(self, x):
-        return self.const_tensor + x
-
-
-def export_model(model: nn.Module, dummy_input: torch.Tensor, file_name: str):
-    model.eval()
-    torch.onnx.export(
-        model,
-        dummy_input,
-        file_name,
-        verbose=False,
-        opset_version=16,
-        do_constant_folding=False,
-    )
-    print(f"Finished exporting model to {file_name}")
-
-    # Output some test data for demonstration
-    test_input = dummy_input.clone()
-    print(dummy_input.dtype, "test input:", test_input)
-    output = model.forward(test_input)
-    print(dummy_input.dtype, "test output:", output)
-    print("")
-
-
-def main():
-    device = torch.device("cpu")
-
-    # Export with a float32 tensor constant
-    model_f32 = ConstantTensorModel(torch.float32)
-    f32_input = torch.randn(CONST_SHAPE, dtype=torch.float32, device=device)
-    export_model(model_f32, f32_input, "constant_tensor_f32.onnx")
-
-    # Export with a float64 tensor constant
-    model_f64 = ConstantTensorModel(torch.float64)
-    f64_input = torch.randn(CONST_SHAPE, dtype=torch.float64, device=device)
-    export_model(model_f64, f64_input, "constant_tensor_f64.onnx")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f32.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f32.onnx
deleted file mode 100644
index e908564881dcd185d0df93cd07a065023dc404e8..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 219
zcmdN&2{3vIFfcGUAmas$EL{3rSPbOi&dn1O^4S8ifvUVM5{VnIffEUH32E*=g>Ar3B}bs(H1%*ALVgeK_3Bp?6)GMq7u

diff --git a/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f64.onnx b/crates/burn-import/onnx-tests/tests/constant/constant_tensor_f64.onnx
deleted file mode 100644
index 2e46918bcde30030769f1ee048f3911a5899033f..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 235
zcmdN&32=KUFhGC%S@
z^npwzW*{NNm7AEE7oT2~SdbAVi>i>1i-&_-h=U7gJqRZWb1@nTp$R%M2?ziH>(?>+

diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs
index f5854f51af..06ca7ee600 100644
--- a/crates/burn-import/onnx-tests/tests/test_onnx.rs
+++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ ... @@ include_models!(
     constant_f64,
     constant_i32,
     constant_i64,
-    constant_tensor_f32,
-    constant_tensor_f64,
     constant_of_shape,
     constant_of_shape_full_like,
     conv1d,
@@ ... @@
-    #[test]
-    fn constant_tensor_f32() {
-        let device = Default::default();
-        let model = constant_tensor_f32::Model::<Backend>::new(&device);
-        let input = TensorData::zeros::<f32, _>(Shape::from([2, 2]));
-        let expected_output = TensorData::full::<f32, _>(Shape::from([2, 2]), 2f32);
-        let output = model.forward(input.into());
-        assert_eq!(expected_output, output.to_data());
-    }
-
     #[test]
     fn constant_of_shape() {
         // This tests shape is being passed directly to the model

From 5efd749f9b6d5a3a1d08711c0086123b0ea0db72 Mon Sep 17 00:00:00 2001
From: James Hiew
Date: Tue, 17 Dec 2024 15:28:58 +0000
Subject: [PATCH 16/17] cargo fmt
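
Formatting pass over test_onnx.rs. The one content change is dropping
the direct `use burn_ndarray::NdArray;` import, which looks redundant
given that the tests reach the backend through the fully qualified
`type Backend = burn_ndarray::NdArray<f32>;` alias. Assuming no other
manual edits, the rest should be reproducible with `cargo fmt --all`.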
---
 crates/burn-import/onnx-tests/tests/test_onnx.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs
index 06ca7ee600..bda1fbd93b 100644
--- a/crates/burn-import/onnx-tests/tests/test_onnx.rs
+++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ ... @@
     use burn::tensor::{Bool, Int, Shape, Tensor, TensorData};
-    use burn_ndarray::NdArray;
+
     use float_cmp::ApproxEq;

     type Backend = burn_ndarray::NdArray<f32>;

From 756fd3c82a65d152596a836608169f1adbdfa31c Mon Sep 17 00:00:00 2001
From: James Hiew
Date: Tue, 17 Dec 2024 15:40:46 +0000
Subject: [PATCH 17/17] Rename tests

---
 crates/burn-import/onnx-tests/tests/test_onnx.rs | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs
index bda1fbd93b..391b6bfdef 100644
--- a/crates/burn-import/onnx-tests/tests/test_onnx.rs
+++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ ... @@
     use burn::tensor::{Bool, Int, Shape, Tensor, TensorData};
-
     use float_cmp::ApproxEq;

     type Backend = burn_ndarray::NdArray<f32>;
@@ ... @@
     #[test]
-    fn constant_f32() {
+    fn add_constant_f32() {
         let device = Default::default();
         let model = constant_f32::Model::<Backend>::new(&device);
         let input = TensorData::zeros::<f32, _>(Shape::from([2, 3, 4]));
         let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32);
         let output = model.forward(input.into());
         assert_eq!(expected_output, output.to_data());
     }

     #[test]
-    fn constant_f64() {
+    fn add_constant_f64() {
         let device = Default::default();
         let model = constant_f64::Model::<Backend>::new(&device);
         let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
         let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f64);
         let output = model.forward(input.into());
         assert_eq!(expected_output, output.to_data());
     }

     #[test]
-    fn constant_i32() {
+    fn add_constant_i32() {
         let device = Default::default();
         let model = constant_i32::Model::<Backend>::new(&device);
         let input = TensorData::zeros::<i32, _>(Shape::from([2, 3, 4]));
         let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2i32);
         let output = model.forward(input.into());
         assert_eq!(expected_output, output.to_data());
     }

     #[test]
-    fn constant_i64() {
+    fn add_constant_i64() {
         let device = Default::default();
         let model = constant_i64::Model::<Backend>::new(&device);
         let input = TensorData::zeros::<i64, _>(Shape::from([2, 3, 4]));