Softmin (#2358)
NoahSchiro authored Oct 15, 2024
1 parent 8f8cd37 commit 3d77efc
Showing 5 changed files with 31 additions and 0 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -330,6 +330,7 @@ strategies.
| `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` |
| `activation::silu(tensor)` | `nn.functional.silu(tensor)` |
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
| `activation::softmin(tensor, dim)` | `nn.functional.softmin(tensor, dim)` |
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |

13 changes: 13 additions & 0 deletions crates/burn-tensor/src/tensor/activation/base.rs
@@ -78,6 +78,19 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) ->
    tensor.div(tensor_tmp)
}

/// Applies the softmin function on the input tensor along the given dimension.
///
/// `softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j))`
///
/// # Notes
///
/// The dimension argument `dim` specifies the dimension along which the function is computed.
/// It must be in the range `0` to `D-1`.
pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmin", dim));
    softmax(tensor.neg(), dim)
}

/// Applies the softplus function
///
/// `softplus(x_i) = log(1 + exp(\beta x_i)) / \beta`
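For context beyond the diff, a minimal usage sketch of the new function (illustrative, not part of the commit; the `rows_softmin` helper is an assumption, and `from_data` is used as in current `burn_tensor`):

```rust
use burn_tensor::{activation, backend::Backend, Tensor};

// Illustrative helper: applies softmin row-wise, so each row of the result
// sums to 1 and the *smallest* input entry receives the largest weight.
fn rows_softmin<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    let input = Tensor::<B, 2>::from_data([[1.0, 7.0], [13.0, -3.0]], device);
    // Equivalent to softmax(input.neg(), 1), as in the implementation above.
    activation::softmin(input, 1)
}
```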
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/activation/mod.rs
@@ -8,5 +8,6 @@ pub(crate) mod relu;
pub(crate) mod sigmoid;
pub(crate) mod silu;
pub(crate) mod softmax;
pub(crate) mod softmin;
pub(crate) mod softplus;
pub(crate) mod tanh_activation;
15 changes: 15 additions & 0 deletions crates/burn-tensor/src/tests/activation/softmin.rs
@@ -0,0 +1,15 @@
#[burn_tensor_testgen::testgen(softmin)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Tensor, TensorData};

    #[test]
    fn test_softmin_d2() {
        let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);

        let output = activation::softmin(tensor, 1);
        let expected = TensorData::from([[9.9753e-01, 2.4726e-03], [1.1254e-07, 1.0000e+00]]);

        output.into_data().assert_approx_eq(&expected, 4);
    }
}
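The expected values follow directly from the softmin formula; as a sanity check, a standalone hand computation of the first row (not part of the commit):

```rust
// Verifying the first row by hand: softmin([1.0, 7.0]) along dim 1,
// where softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j)).
fn main() {
    let (a, b) = ((-1.0f64).exp(), (-7.0f64).exp());
    let sum = a + b;
    // Prints approximately 9.9753e-1 and 2.4726e-3, matching `expected`.
    println!("{:e} {:e}", a / sum, b / sum);
}
```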
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/mod.rs
@@ -15,6 +15,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_relu!();
        burn_tensor::testgen_leaky_relu!();
        burn_tensor::testgen_softmax!();
        burn_tensor::testgen_softmin!();
        burn_tensor::testgen_softplus!();
        burn_tensor::testgen_sigmoid!();
        burn_tensor::testgen_log_sigmoid!();
