Change pad value w/ ElementConversion (#2653)
laggui authored Jan 3, 2025
1 parent d9418ad commit 2338912
Showing 2 changed files with 8 additions and 9 deletions.
9 changes: 2 additions & 7 deletions crates/burn-import/src/burn/node/pad.rs
@@ -32,7 +32,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PadNode {
         let output = &self.output.name;

         let pads = self.config.pads.iter().map(|p| p.to_tokens());
-        let constant_value_string = format!("{}_f32.elem()", self.config.constant_value);
+        let constant_value_string = format!("{}_f32", self.config.constant_value);
         let constant_value = TokenStream::from_str(&constant_value_string).unwrap();

         quote! {
@@ -42,10 +42,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PadNode {
     fn into_node(self) -> Node<PS> {
         Node::Pad(self)
     }
-
-    fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
-        imports.register("burn::tensor::ElementConversion");
-    }
 }

 #[cfg(test)]
@@ -71,7 +67,6 @@ mod tests {
         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

         let expected = quote! {
-            use burn::tensor::ElementConversion;
             use burn::{
                 module::Module,
                 tensor::{backend::Backend, Tensor},
@@ -93,7 +88,7 @@ mod tests {
                 }
                 #[allow(clippy::let_and_return, clippy::approx_constant)]
                 pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
-                    let output = input.pad((1, 2, 3, 4), -1_f32.elem());
+                    let output = input.pad((1, 2, 3, 4), -1_f32);
                     output
                 }
             }
8 changes: 6 additions & 2 deletions crates/burn-tensor/src/tensor/api/numeric.rs
@@ -1988,7 +1988,7 @@ where
     /// fn example<B: Backend<FloatElem: From<f32>>>() {
     ///     let device = B::Device::default();
     ///     let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
-    ///     let tensor = tensor.pad((1, 1, 1, 1), 0.0.into());
+    ///     let tensor = tensor.pad((1, 1, 1, 1), 0.0);
     ///     println!("{tensor}");
     ///     // [
     ///     //   [0.0, 0.0, 0.0, 0.0, 0.0],
@@ -1998,7 +1998,11 @@
     ///     // ]
     /// }
     /// ```
-    pub fn pad(self, padding: (usize, usize, usize, usize), value: K::Elem) -> Tensor<B, D, K> {
+    pub fn pad<E: ElementConversion>(
+        self,
+        padding: (usize, usize, usize, usize),
+        value: E,
+    ) -> Tensor<B, D, K> {
         let (left, right, top, bottom) = padding;

         let mut padded_dims: [usize; D] = self.dims();
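For call sites, this change means a plain numeric literal can be passed to `pad` directly, with the conversion to the tensor's element type now handled inside the method (the elided body is assumed to go through `burn::tensor::ElementConversion`). A minimal call-site sketch mirroring the updated doc example above; the helper name `pad_example` is hypothetical:

use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper showing the new call-site ergonomics.
fn pad_example<B: Backend<FloatElem: From<f32>>>() {
    let device = B::Device::default();
    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    // A plain `0.0` is accepted; no `.into()` or `.elem()` is needed at the call site.
    let tensor = tensor.pad((1, 1, 1, 1), 0.0);
    println!("{tensor}");
}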
