diff --git a/crates/burn-import/src/burn/node/pad.rs b/crates/burn-import/src/burn/node/pad.rs index eabe77d7f1..00dbd19d13 100644 --- a/crates/burn-import/src/burn/node/pad.rs +++ b/crates/burn-import/src/burn/node/pad.rs @@ -32,7 +32,7 @@ impl NodeCodegen for PadNode { let output = &self.output.name; let pads = self.config.pads.iter().map(|p| p.to_tokens()); - let constant_value_string = format!("{}_f32.elem()", self.config.constant_value); + let constant_value_string = format!("{}_f32", self.config.constant_value); let constant_value = TokenStream::from_str(&constant_value_string).unwrap(); quote! { @@ -42,10 +42,6 @@ impl NodeCodegen for PadNode { fn into_node(self) -> Node { Node::Pad(self) } - - fn register_imports(&self, imports: &mut crate::burn::BurnImports) { - imports.register("burn::tensor::ElementConversion"); - } } #[cfg(test)] @@ -71,7 +67,6 @@ mod tests { graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); let expected = quote! { - use burn::tensor::ElementConversion; use burn::{ module::Module, tensor::{backend::Backend, Tensor}, @@ -93,7 +88,7 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { - let output = input.pad((1, 2, 3, 4), -1_f32.elem()); + let output = input.pad((1, 2, 3, 4), -1_f32); output } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 1f04aeaac8..59dc44b7e6 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -1988,7 +1988,7 @@ where /// fn example>>() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); - /// let tensor = tensor.pad((1, 1, 1, 1), 0.0.into()); + /// let tensor = tensor.pad((1, 1, 1, 1), 0.0); /// println!("{tensor}"); /// // [ /// // [0.0, 0.0, 0.0, 0.0, 0.0], @@ -1998,7 +1998,11 @@ where /// // ] /// } /// ``` - pub fn pad(self, padding: (usize, usize, usize, usize), value: K::Elem) -> Tensor { + pub fn pad( + self, + padding: (usize, usize, usize, usize), + value: E, + ) -> Tensor { let (left, right, top, bottom) = padding; let mut padded_dims: [usize; D] = self.dims();