Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature add new one hot function meeting multi-dimensions (ranks) #2613

Merged
merged 21 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.one_hot_fill(depth, on_value, off_value, axis)` | N/A |
laggui marked this conversation as resolved.
Show resolved Hide resolved
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
Expand Down Expand Up @@ -258,7 +259,7 @@ Those operations are only available for `Float` tensors.

| Burn API | PyTorch Equivalent |
| --------------------------------------------- | ---------------------------------- |
| `Tensor::one_hot(index, num_classes, device)` | N/A |
| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` |
laggui marked this conversation as resolved.
Show resolved Hide resolved
| `tensor.cast(dtype)` | `tensor.to(dtype)` |
| `tensor.ceil()` | `tensor.ceil()` |
| `tensor.cos()` | `tensor.cos()` |
Expand Down Expand Up @@ -296,7 +297,7 @@ Those operations are only available for `Int` tensors.
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
| `tensor.cartesian_grid(shape, device)` | N/A |
| `tensor.one_hot(num_classes)` | N/A |
| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` |

### Bool Operations

Expand Down
20 changes: 3 additions & 17 deletions crates/burn-tensor/src/tensor/api/check.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{backend::Backend, BasicOps, Int, Shape, Tensor};
use crate::{backend::Backend, BasicOps, Numeric, Shape, Tensor};
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
Expand Down Expand Up @@ -447,22 +447,8 @@ impl TensorCheck {
check
}

pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self {
let mut check = Self::Ok;
if index >= num_classes {
check = check.register(
"One Hot",
TensorError::new(format!(
"Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})",
)),
);
}

check
}

pub(crate) fn one_hot_tensor<B: Backend>(
index_tensor: Tensor<B, 1, Int>,
pub(crate) fn one_hot_tensor<B: Backend, const D: usize, K: Numeric<B>>(
index_tensor: Tensor<B, D, K>,
laggui marked this conversation as resolved.
Show resolved Hide resolved
num_classes: usize,
) -> Self {
let mut check = Self::Ok;
Expand Down
27 changes: 8 additions & 19 deletions crates/burn-tensor/src/tensor/api/float.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use alloc::vec::Vec;
use core::convert::TryInto;

use crate::check::TensorCheck;
use crate::quantization::{QuantizationParameters, QuantizationScheme};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, Shape, TensorData};
use crate::tensor::{Distribution, TensorData};
use crate::Tensor;
use crate::{check, FloatDType};
use crate::{Int, TensorPrimitive};
Expand Down Expand Up @@ -182,25 +179,17 @@ where
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// fn example<B: Backend>(){
/// let device = Default::default();
/// let one_hot = Tensor::<B, 1>::one_hot(2, 10, &device);
/// let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
/// let one_hot: Tensor<B, 4> = indices.one_hot(4);
/// println!("{}", one_hot.to_data());
/// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
/// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
/// }
/// ```
pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self {
check!(TensorCheck::one_hot_index(index, num_classes));

let mut dims = [1; D];
dims[D - 1] = num_classes;
let shape = Shape::new(dims);
let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect();
let tensor = Tensor::zeros(shape, device);
let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
ranges[D - 1] = index..index + 1;

tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device))
pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2> {
laggui marked this conversation as resolved.
Show resolved Hide resolved
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
self.one_hot_fill(num_classes, 1.0, 0.0, -1)
}
laggui marked this conversation as resolved.
Show resolved Hide resolved

/// Applies the matrix multiplication operation.
Expand Down
53 changes: 25 additions & 28 deletions crates/burn-tensor/src/tensor/api/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,6 @@ where
pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::int_arange_step(range, step, device))
}

/// Create a one hot tensor from an index tensor.
///
/// # Arguments
///
/// * `num_classes` - The number of classes to use in encoding.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Int};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let indices: Tensor<B, 1, Int> = Tensor::from_ints([0, 1, 2, 3], &device);
/// let one_hot = indices.one_hot(4);
/// println!("{}", one_hot.to_data());
/// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
/// }
/// ```
pub fn one_hot(self, num_classes: usize) -> Tensor<B, 2, Int> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
let [num_samples] = self.dims();
let indices = self.unsqueeze_dim(1);
let values = indices.ones_like();
Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values)
}
}

impl<const D: usize, B> Tensor<B, D, Int>
Expand Down Expand Up @@ -129,4 +101,29 @@ where
) -> Tensor<B, D2, Int> {
cartesian_grid::<B, S, D, D2>(shape, device)
}

/// Create a one hot tensor from an index tensor.
///
/// # Arguments
///
/// * `num_classes` - The number of classes to use in encoding.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Int};
///
/// fn example<B: Backend>(){
/// let device = B::Device::default();
/// let indices: Tensor<B, 1, Int> = Tensor::from_ints([0, 1, 2, 3], &device);
/// let one_hot: Tensor<B, 4, Int> = indices.one_hot(4);
/// println!("{}", one_hot.to_data());
/// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
/// }
/// ```
pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, Int> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
self.one_hot_fill(num_classes, 1.0, 0.0, -1)
}
laggui marked this conversation as resolved.
Show resolved Hide resolved
}
80 changes: 80 additions & 0 deletions crates/burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,86 @@ where
padded_tensor.slice_assign(ranges, self)
}

/// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.
///
/// # Arguments
///
/// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
/// * `on_value`: The value to assign for active positions (corresponding to indices).
/// * `off_value`: The value to assign for inactive positions.
/// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
///
/// # Returns
///
/// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Float};
/// fn example<B: Backend<FloatElem: From<f32>>>() {
/// let device = B::Device::default();
/// let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
/// // One-hot encoding
/// let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
/// println!("{tensor}");
/// // [[[5.0, 0.0, 0.0],
/// // [0.0, 0.0, 5.0]],
/// // [[0.0, 5.0, 0.0],
/// // [0.0, 0.0, 5.0]]]
/// }
/// ```
pub fn one_hot_fill<K2: Numeric<B>, const D2: usize>(
self,
num_classes: usize,
on_value: f32,
off_value: f32,
axis: i64,
) -> Tensor<B, D2, K2> {
laggui marked this conversation as resolved.
Show resolved Hide resolved
// Initialize shape from the current tensor dimensions and prepare for modification
let mut shape = self.shape().dims::<D>().to_vec();
let device = self.device();
let rank = self.dims().len();

// Adjust negative axis to a positive index
let axis = if axis < 0 {
axis + rank as i64 + 1
} else {
axis
};

// Ensure axis is within valid range
if axis < 0 || axis > rank as i64 {
panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
}
// Convert the input tensor to integer indices
let indices: Tensor<B, D, Int> =
Tensor::from_data(self.to_data().convert::<i64>(), &device);
// Insert the new dimension for the one-hot representation
shape.insert(axis as usize, num_classes);
// Adjust indices to valid range and handle invalid indices
let adjusted_indices = indices
.clone()
.mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
.add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
check!(TensorCheck::one_hot_tensor(
adjusted_indices.clone(),
num_classes
));
// Unsqueeze the indices tensor along the specified axis
let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);

// Initialize the output tensor with the off_value
let output = Tensor::full(shape.clone(), off_value, &device);

// Prepare scatter tensor for on_value and off_value adjustments
let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
- Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());

// Scatter on_value at the appropriate indices to create the one-hot representation
output.scatter(axis as usize, indices_unsqueezed, scatter_on_values)
}

/// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
///
/// # Returns
Expand Down
106 changes: 69 additions & 37 deletions crates/burn-tensor/src/tests/ops/one_hot.rs
Original file line number Diff line number Diff line change
@@ -1,74 +1,106 @@
#[burn_tensor_testgen::testgen(one_hot)]
mod tests {
use super::*;
use burn_tensor::{Int, TensorData};
use burn_tensor::{
as_type,
backend::Backend,
tests::{Float as _, Int as _},
Float, Int, Numeric, Shape, Tensor, TensorData,
};

#[test]
fn float_should_support_one_hot() {
let device = Default::default();

let tensor = TestTensor::<1>::one_hot(0, 5, &device);
let expected = TensorData::from([1., 0., 0., 0., 0.]);
tensor.into_data().assert_eq(&expected, false);

let tensor = TestTensor::<1>::one_hot(1, 5, &device);
let expected = TensorData::from([0., 1., 0., 0., 0.]);
tensor.into_data().assert_eq(&expected, false);

let tensor = TestTensor::<1>::one_hot(4, 5, &device);
let expected = TensorData::from([0., 0., 0., 0., 1.]);
tensor.into_data().assert_eq(&expected, false);

let tensor = TestTensor::<1>::one_hot(1, 2, &device);
let expected = TensorData::from([0., 1.]);
tensor.into_data().assert_eq(&expected, false);
let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]);
let one_hot_tensor: Tensor<TestBackend, 2, Float> = tensor.one_hot(5);
let expected = TensorData::from([
[1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
]);
one_hot_tensor.into_data().assert_eq(&expected, false);
}

#[test]
#[should_panic]
fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() {
let device = Default::default();
let tensor = TestTensor::<1>::one_hot(1, 1, &device);
let tensor = TestTensor::<1>::from([5.0]);
let result: Tensor<TestBackend, 2> = tensor.one_hot(5);
}

#[test]
#[should_panic]
fn float_one_hot_should_panic_when_number_of_classes_is_zero() {
let device = Default::default();
let tensor = TestTensor::<1>::one_hot(0, 0, &device);
let tensor = TestTensor::<1>::from([0.0]);
let result: Tensor<TestBackend, 2> = tensor.one_hot(0);
}

#[test]
fn int_should_support_one_hot() {
let device = Default::default();

let index_tensor = TestTensorInt::<1>::arange(0..5, &device);
let one_hot_tensor = index_tensor.one_hot(5);
let expected = TestTensorInt::eye(5, &device).into_data();
let tensor = TestTensorInt::<1>::from([0, 1, 4]);
let one_hot_tensor: Tensor<TestBackend, 2, Int> = tensor.one_hot(5);
let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]);
one_hot_tensor.into_data().assert_eq(&expected, false);
}

#[test]
#[should_panic]
fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() {
let device = Default::default();
let index_tensor = TestTensorInt::<1>::arange(0..6, &device);
let one_hot_tensor = index_tensor.one_hot(5);
let tensor = TestTensorInt::<1>::from([5]);
let result: Tensor<TestBackend, 2, Int> = tensor.one_hot(5);
}

#[test]
#[should_panic]
fn int_one_hot_should_panic_when_number_of_classes_is_zero() {
let device = Default::default();
let index_tensor = TestTensorInt::<1>::arange(0..3, &device);
let one_hot_tensor = index_tensor.one_hot(0);
let tensor = TestTensorInt::<1>::from([2]);
let result: Tensor<TestBackend, 2, Int> = tensor.one_hot(0);
}

#[test]
fn one_hot_fill_with_positive_axis_and_indices() {
let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]);
let expected = TensorData::from(as_type!(IntType: [
[[1, 1], [3, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 3]],
[[1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]
]));

let one_hot_tensor: Tensor<TestBackend, 3, Int> = tensor.one_hot_fill(10, 3.0, 1.0, 1);

one_hot_tensor.into_data().assert_eq(&expected, true);
}

#[test]
fn one_hot_fill_with_negative_axis_and_indices() {
let tensor = TestTensorInt::<2>::from([[0, 2], [1, -1]]);
laggui marked this conversation as resolved.
Show resolved Hide resolved
let expected = TensorData::from(as_type!(FloatType: [
[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]],
[[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]
]));

let one_hot_tensor: Tensor<TestBackend, 3> = tensor.one_hot_fill(3, 5.0, 0.0, -1);

one_hot_tensor.into_data().assert_eq(&expected, true);
}

#[test]
fn one_hot_fill_with_negative_indices() {
let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]);
let expected = TensorData::from(as_type!(FloatType: [
[3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
]));

let one_hot_tensor: Tensor<TestBackend, 2> = tensor.one_hot_fill(10, 3.0, 1.0, 1);

one_hot_tensor.into_data().assert_eq(&expected, true);
}

#[should_panic]
fn int_one_hot_should_panic_when_number_of_classes_is_1() {
let device = Default::default();
let index_tensor = TestTensorInt::<1>::arange(0..3, &device);
let one_hot_tensor = index_tensor.one_hot(1);
#[test]
fn one_hot_fill_should_panic_when_axis_out_range_of_rank() {
let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]);

let one_hot_tensor: Tensor<TestBackend, 3, Float> = tensor.one_hot_fill(2, 5.0, 0.0, 3);
}
}
Loading