Skip to content

Commit

Permalink
Check only for same padding (doh)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Jan 10, 2025
1 parent 6efdd91 commit e142995
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 8 deletions.
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/conv/conv1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ impl Conv1dConfig {
/// Initialize a new [conv1d](Conv1d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
checks::check_same_padding_support(&[self.kernel_size]);
if self.padding == PaddingConfig1d::Same {
checks::check_same_padding_support(&[self.kernel_size]);
}

let shape = [
self.channels_out,
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/conv/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ impl Conv2dConfig {
/// Initialize a new [conv2d](Conv2d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
checks::check_same_padding_support(&self.kernel_size);
if self.padding == PaddingConfig2d::Same {
checks::check_same_padding_support(&self.kernel_size);
}

let shape = [
self.channels[1],
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/conv/conv3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ impl Conv3dConfig {
/// Initialize a new [conv3d](Conv3d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv3d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
checks::check_same_padding_support(&self.kernel_size);
if self.padding == PaddingConfig3d::Same {
checks::check_same_padding_support(&self.kernel_size);
}

let shape = [
self.channels[1],
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/conv/deform_conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ impl DeformConv2dConfig {
/// Initialize a new [DeformConv2d](DeformConv2d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> DeformConv2d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);
checks::check_same_padding_support(&self.kernel_size);
if self.padding == PaddingConfig2d::Same {
checks::check_same_padding_support(&self.kernel_size);
}

let shape = [
self.channels[1],
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/pool/avg_pool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ impl ModuleDisplay for AvgPool1d {
impl AvgPool1dConfig {
/// Initialize a new [avg pool 1d](AvgPool1d) module.
pub fn init(&self) -> AvgPool1d {
check_same_padding_support(&[self.kernel_size]);
if self.padding == PaddingConfig1d::Same {
check_same_padding_support(&[self.kernel_size]);
}
AvgPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/pool/avg_pool2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ impl ModuleDisplay for AvgPool2d {
impl AvgPool2dConfig {
/// Initialize a new [avg pool 2d](AvgPool2d) module.
pub fn init(&self) -> AvgPool2d {
check_same_padding_support(&self.kernel_size);
if self.padding == PaddingConfig2d::Same {
check_same_padding_support(&self.kernel_size);
}
AvgPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/pool/max_pool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ impl ModuleDisplay for MaxPool1d {
impl MaxPool1dConfig {
/// Initialize a new [max pool 1d](MaxPool1d) module.
pub fn init(&self) -> MaxPool1d {
check_same_padding_support(&[self.kernel_size]);
if self.padding == PaddingConfig1d::Same {
check_same_padding_support(&[self.kernel_size]);
}
MaxPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/nn/pool/max_pool2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ impl ModuleDisplay for MaxPool2d {
impl MaxPool2dConfig {
/// Initialize a new [max pool 2d](MaxPool2d) module.
pub fn init(&self) -> MaxPool2d {
check_same_padding_support(&self.kernel_size);
if self.padding == PaddingConfig2d::Same {
check_same_padding_support(&self.kernel_size);
}
MaxPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
Expand Down

0 comments on commit e142995

Please sign in to comment.