ONNX Tile operation (#2092)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error

* implementing tile onnx file

* temp

* working implementation and test

* working e2e test

* adding new supported onnx operation to the md file
mepatrick73 authored Aug 7, 2024
1 parent 6b61ad5 commit d770b1f
Showing 10 changed files with 222 additions and 4 deletions.
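ONNX `Tile` repeats a tensor along each axis according to a 1-D `repeats` input, and this commit maps it onto Burn's `repeat` tensor op. Below is a minimal sketch of the semantics being wired up (NdArray backend; the `repeat_dim`/`repeat` split follows the rename in the commit message above, and the exact signatures are an assumption rather than part of this diff):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // Same 2x2 input as the end-to-end test below.
    let x = Tensor::<NdArray, 2>::from_floats([[1., 2.], [3., 4.]], &device);

    // `repeat_dim(dim, times)`: the old single-axis `repeat`, renamed.
    let rows = x.clone().repeat_dim(0, 2); // shape [4, 2]

    // `repeat(&[..])`: per-axis counts, matching PyTorch's `Tensor.repeat`
    // and ONNX Tile's `repeats` input; the generated code below calls this.
    let tiled = x.repeat(&[2, 2]); // shape [4, 4]

    println!("{rows}\n{tiled}");
}
```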
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -191,7 +191,7 @@ represent the corresponding Burn Op.
| [Tanh][182] | ✅ | ✅ |
| [TfIdfVectorizer][183] | ❌ | ❌ |
| [ThresholdedRelu][184] | ❌ | ❌ |
-| [Tile][185] | ❌ | ❌ |
+| [Tile][185] | ✅ | ✅ |
| [TopK][186] | ❌ | ❌ |
| [Transpose][187] | ✅ | ✅ |
| [Trilu][188] | ❌ | ❌ |
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -93,6 +93,7 @@ fn main() {
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
18 changes: 18 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -102,6 +102,7 @@ include_models!(
    sum,
    sum_int,
    tanh,
    tile,
    transpose,
    unsqueeze,
    unsqueeze_opset11,
@@ -1712,6 +1713,23 @@ mod tests {
        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn tile() {
        let device = Default::default();
        let model: tile::Model<Backend> = tile::Model::new(&device);

        let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.]], &device);
        let output = model.forward(input).to_data();
        let expected = TensorData::from([
            [1.0f32, 2.0f32, 1.0f32, 2.0f32],
            [3.0f32, 4.0f32, 3.0f32, 4.0f32],
            [1.0f32, 2.0f32, 1.0f32, 2.0f32],
            [3.0f32, 4.0f32, 3.0f32, 4.0f32],
        ]);

        output.assert_eq(&expected, true);
    }

    #[test]
    fn unsqueeze() {
        let device = Default::default();
Binary file added crates/burn-import/onnx-tests/tests/tile/tile.onnx
67 changes: 67 additions & 0 deletions crates/burn-import/onnx-tests/tests/tile/tile.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

import onnx
import onnx.helper
import onnx.checker


def build_model():
    # Define the input tensor as a graph input
    input_tensor = onnx.helper.make_tensor_value_info(
        name="input_tensor",
        elem_type=onnx.TensorProto.FLOAT,
        shape=[2, 2]
    )

    output_tensor = onnx.helper.make_tensor_value_info(
        name="output_tensor",
        elem_type=onnx.TensorProto.FLOAT,
        shape=[4, 4]
    )

    # Define the shape tensor for tiling as an initializer
    shape_tensor = onnx.helper.make_tensor(
        name="shape_tensor",
        data_type=onnx.TensorProto.INT64,
        dims=[2],
        vals=[2, 2]
    )
    # Create the Tile node
    tile_node = onnx.helper.make_node(
        "Tile",
        inputs=["input_tensor", "shape_tensor"],
        outputs=["output_tensor"]
    )

    # Build the graph
    graph = onnx.helper.make_graph(
        nodes=[tile_node],
        name="main_graph",
        inputs=[input_tensor],
        outputs=[output_tensor],
        initializer=[shape_tensor]
    )

    # Build the model
    model = onnx.helper.make_model(
        graph,
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 16)]
    )

    return model


def main():
    onnx_model = build_model()

    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)

    file_name = "tile.onnx"
    onnx.save(onnx_model, file_name)
    onnx.checker.check_model(onnx_model)
    print(f"ONNX model saved as {file_name}")


if __name__ == "__main__":
    main()
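Note that the script supplies the repeat counts (`shape_tensor`) as an initializer rather than a runtime input: burn-import can then read the value statically when building the node configuration (see `tile_config` further down), which is what allows the generated `forward` to hard-code `repeat(&[2, 2])`.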
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
@@ -11,7 +11,7 @@ use super::{
    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
    prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
    range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
-    squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
+    squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -113,6 +113,7 @@ pub enum Node<PS: PrecisionSettings> {
    Slice(SliceNode),
    Squeeze(SqueezeNode),
    Sum(SumNode),
    Tile(TileNode),
    Unary(UnaryNode),
    Unsqueeze(UnsqueezeNode),
    Where(WhereNode),
@@ -160,6 +161,7 @@ macro_rules! match_all {
            Node::Slice(node) => $func(node),
            Node::Squeeze(node) => $func(node),
            Node::Sum(node) => $func(node),
            Node::Tile(node) => $func(node),
            Node::Unary(node) => $func(node),
            Node::Unsqueeze(node) => $func(node),
            Node::Where(node) => $func(node),
@@ -215,6 +217,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Slice(_) => "slice",
            Node::Squeeze(_) => "squeeze",
            Node::Sum(_) => "add",
            Node::Tile(_) => "tile",
            Node::Unary(unary) => unary.kind.as_str(),
            Node::Unsqueeze(_) => "unsqueeze",
            Node::Where(_) => "where",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -36,6 +36,7 @@ pub(crate) mod resize;
pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod tile;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
pub(crate) use base::*;
97 changes: 97 additions & 0 deletions crates/burn-import/src/burn/node/tile.rs
@@ -0,0 +1,97 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Config, Debug)]
pub struct TileConfig {
    pub repeats: Vec<usize>,
}

#[derive(Debug, Clone, new)]
pub struct TileNode {
    pub input: TensorType,
    pub output: TensorType,
    pub config: TileConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TileNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;

        let repeats = self.config.repeats.iter().map(|r| r.to_tokens());

        quote! {
            let #output = #input.repeat(&[#(#repeats),*]);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Tile(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{test::assert_tokens, tile::TileConfig, tile::TileNode},
        TensorType,
    };

    #[test]
    fn test_codegen_tile() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = TileConfig::new(vec![2, 3, 4]);
        graph.register(TileNode::new(
            TensorType::new_float("input", 3),
            TensorType::new_float("output", 3),
            config,
        ));
        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
                    let output = input.repeat(&[2, 3, 4]);
                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
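The `forward` implementation splices the configured repeat counts into the generated source as literal tokens via `to_tokens()`, so the emitted model carries no runtime Tile configuration; the codegen test above checks exactly that expansion.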
22 changes: 21 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
@@ -7,7 +7,7 @@ use burn::nn::{
    PaddingConfig2d, PaddingConfig3d,
};

-use crate::burn::node::pad::PadConfig;
+use crate::burn::node::{pad::PadConfig, tile::TileConfig};
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};

/// Create a Conv1dConfig from the attributes of the node
@@ -745,6 +745,26 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
    )
}

/// Create a TileConfig from the attributes of the node
pub fn tile_config(node: &Node) -> TileConfig {
    // ONNX Tile carries its repeat counts as a second input; read its static
    // value when present so the counts can be baked into the generated code.
    let repeat = node
        .inputs
        .get(1)
        .map(|input| {
            if let Some(data) = &input.value {
                data.clone()
                    .into_i64s()
                    .iter()
                    .map(|&x| x as usize)
                    .collect()
            } else {
                vec![]
            }
        })
        .unwrap_or_default();
    TileConfig::new(repeat)
}

/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
    fn get_pads(node: &Node) -> Vec<usize> {
13 changes: 12 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
@@ -50,6 +50,7 @@ use crate::{
            slice::SliceNode,
            squeeze::SqueezeNode,
            sum::SumNode,
            tile::TileNode,
            unary::UnaryNode,
            unsqueeze::UnsqueezeNode,
        },
@@ -66,7 +67,8 @@ use super::op_configuration::{
    hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config,
    max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config,
    reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config,
-    shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config,
+    shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config,
+    unsqueeze_config,
};
use onnx_ir::{
    convert_constant_value,
@@ -335,6 +337,7 @@ impl ParsedOnnxGraph {
                NodeType::Sign => graph.register(Self::sign_conversion(node)),
                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)),
                NodeType::Tile => graph.register(Self::tile_conversion(node)),
                NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)),
                NodeType::ConstantOfShape => {
                    graph.register(Self::constant_of_shape_conversion(node))
@@ -1167,6 +1170,14 @@ impl ParsedOnnxGraph {

        SqueezeNode::new(input, output, axes)
    }

    fn tile_conversion(node: Node) -> TileNode {
        let input = TensorType::from(node.inputs.first().unwrap());
        let output = TensorType::from(node.outputs.first().unwrap());
        let config = tile_config(&node);

        TileNode::new(input, output, config)
    }
}

/// Extract data from node states and convert it to `TensorData`.
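With the conversion registered, a Tile-bearing ONNX file is consumed like any other burn-import model. Here is a minimal sketch of a downstream crate, assuming the usual `ModelGen` build-script flow from the Burn book; the file paths and the `model` module name are illustrative, not part of this commit:

```rust
// build.rs (hypothetical consumer crate): generate Rust source from the ONNX file.
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("src/model/tile.onnx") // illustrative path to the ONNX file
        .out_dir("model/")
        .run_from_script();
}
```

```rust
// src/main.rs (hypothetical consumer crate): embed and run the generated model.
use burn::backend::NdArray;
use burn::tensor::Tensor;

mod model {
    include!(concat!(env!("OUT_DIR"), "/model/tile.rs"));
}

fn main() {
    let device = Default::default();
    let model: model::Model<NdArray> = model::Model::new(&device);
    let input = Tensor::<NdArray, 2>::from_floats([[1., 2.], [3., 4.]], &device);
    // Tiled 2x2 -> 4x4, as asserted in the e2e test above.
    let output = model.forward(input);
    println!("{output}");
}
```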
