From e14299542a3b8e91a181ecb992de49b5a2420e56 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 10 Jan 2025 12:53:51 -0500 Subject: [PATCH] Check only for same padding (doh) --- crates/burn-core/src/nn/conv/conv1d.rs | 4 +++- crates/burn-core/src/nn/conv/conv2d.rs | 4 +++- crates/burn-core/src/nn/conv/conv3d.rs | 4 +++- crates/burn-core/src/nn/conv/deform_conv2d.rs | 4 +++- crates/burn-core/src/nn/pool/avg_pool1d.rs | 4 +++- crates/burn-core/src/nn/pool/avg_pool2d.rs | 4 +++- crates/burn-core/src/nn/pool/max_pool1d.rs | 4 +++- crates/burn-core/src/nn/pool/max_pool2d.rs | 4 +++- 8 files changed, 24 insertions(+), 8 deletions(-) diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index a58fb956d4..c3f61a6b07 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -91,7 +91,9 @@ impl Conv1dConfig { /// Initialize a new [conv1d](Conv1d) module. pub fn init(&self, device: &B::Device) -> Conv1d { checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); - checks::check_same_padding_support(&[self.kernel_size]); + if self.padding == PaddingConfig1d::Same { + checks::check_same_padding_support(&[self.kernel_size]); + } let shape = [ self.channels_out, diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index 9eff037a57..72c00187be 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -72,7 +72,9 @@ impl Conv2dConfig { /// Initialize a new [conv2d](Conv2d) module. pub fn init(&self, device: &B::Device) -> Conv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - checks::check_same_padding_support(&self.kernel_size); + if self.padding == PaddingConfig2d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], diff --git a/crates/burn-core/src/nn/conv/conv3d.rs b/crates/burn-core/src/nn/conv/conv3d.rs index 2d490e2094..0b5d530c5a 100644 --- a/crates/burn-core/src/nn/conv/conv3d.rs +++ b/crates/burn-core/src/nn/conv/conv3d.rs @@ -68,7 +68,9 @@ impl Conv3dConfig { /// Initialize a new [conv3d](Conv3d) module. pub fn init(&self, device: &B::Device) -> Conv3d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - checks::check_same_padding_support(&self.kernel_size); + if self.padding == PaddingConfig3d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], diff --git a/crates/burn-core/src/nn/conv/deform_conv2d.rs b/crates/burn-core/src/nn/conv/deform_conv2d.rs index e8068e5df9..2baff11d07 100644 --- a/crates/burn-core/src/nn/conv/deform_conv2d.rs +++ b/crates/burn-core/src/nn/conv/deform_conv2d.rs @@ -77,7 +77,9 @@ impl DeformConv2dConfig { /// Initialize a new [DeformConv2d](DeformConv2d) module. pub fn init(&self, device: &B::Device) -> DeformConv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups); - checks::check_same_padding_support(&self.kernel_size); + if self.padding == PaddingConfig2d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs index 5419d78ef6..24ec8ff972 100644 --- a/crates/burn-core/src/nn/pool/avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs @@ -74,7 +74,9 @@ impl ModuleDisplay for AvgPool1d { impl AvgPool1dConfig { /// Initialize a new [avg pool 1d](AvgPool1d) module. pub fn init(&self) -> AvgPool1d { - check_same_padding_support(&[self.kernel_size]); + if self.padding == PaddingConfig1d::Same { + check_same_padding_support(&[self.kernel_size]); + } AvgPool1d { stride: self.stride, kernel_size: self.kernel_size, diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs index 2022c87f01..343d59922b 100644 --- a/crates/burn-core/src/nn/pool/avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs @@ -74,7 +74,9 @@ impl ModuleDisplay for AvgPool2d { impl AvgPool2dConfig { /// Initialize a new [avg pool 2d](AvgPool2d) module. pub fn init(&self) -> AvgPool2d { - check_same_padding_support(&self.kernel_size); + if self.padding == PaddingConfig2d::Same { + check_same_padding_support(&self.kernel_size); + } AvgPool2d { stride: self.strides, kernel_size: self.kernel_size, diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs index df655ac9df..71041e6155 100644 --- a/crates/burn-core/src/nn/pool/max_pool1d.rs +++ b/crates/burn-core/src/nn/pool/max_pool1d.rs @@ -66,7 +66,9 @@ impl ModuleDisplay for MaxPool1d { impl MaxPool1dConfig { /// Initialize a new [max pool 1d](MaxPool1d) module. pub fn init(&self) -> MaxPool1d { - check_same_padding_support(&[self.kernel_size]); + if self.padding == PaddingConfig1d::Same { + check_same_padding_support(&[self.kernel_size]); + } MaxPool1d { stride: self.stride, kernel_size: self.kernel_size, diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs index d7c2a8a585..3eb94f5db5 100644 --- a/crates/burn-core/src/nn/pool/max_pool2d.rs +++ b/crates/burn-core/src/nn/pool/max_pool2d.rs @@ -66,7 +66,9 @@ impl ModuleDisplay for MaxPool2d { impl MaxPool2dConfig { /// Initialize a new [max pool 2d](MaxPool2d) module. pub fn init(&self) -> MaxPool2d { - check_same_padding_support(&self.kernel_size); + if self.padding == PaddingConfig2d::Same { + check_same_padding_support(&self.kernel_size); + } MaxPool2d { stride: self.strides, kernel_size: self.kernel_size,