Add cumulative sum tensor operation #1722

Closed · wants to merge 11 commits
7 changes: 7 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -321,6 +321,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}

fn int_cumsum_dim<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}

fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
}
7 changes: 7 additions & 0 deletions crates/burn-candle/src/ops/tensor.rs
@@ -373,6 +373,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}

fn float_cumsum_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}

fn float_mean_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
13 changes: 12 additions & 1 deletion crates/burn-ndarray/src/ops/int_tensor.rs
@@ -7,7 +7,7 @@ use burn_tensor::{Distribution, Reader};

use burn_tensor::ElementConversion;
use core::ops::Range;
-use ndarray::IntoDimension;
+use ndarray::{Axis, IntoDimension};

// Current crate
use crate::element::ExpElement;
@@ -286,6 +286,17 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
NdArrayMathOps::sum_dim(tensor, dim)
}

fn int_cumsum_dim<const D: usize>(
tensor: NdArrayTensor<i64, D>,
dim: usize,
) -> NdArrayTensor<i64, D> {
let mut array = tensor.array.clone().into_owned();
Member: See comment for float_cumsum


array.accumulate_axis_inplace(Axis(dim), |&prev, curr| *curr += prev);

NdArrayTensor::new(array.to_shared())
}

fn int_prod<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, 1> {
NdArrayMathOps::prod(tensor)
}
13 changes: 12 additions & 1 deletion crates/burn-ndarray/src/ops/tensor.rs
@@ -1,7 +1,7 @@
// Language
use alloc::vec::Vec;
use core::ops::Range;
-use ndarray::IntoDimension;
+use ndarray::{Axis, IntoDimension};

// Current crate
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
@@ -338,6 +338,17 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
NdArrayMathOps::sum_dim(tensor, dim)
}

fn float_cumsum_dim<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<E, D> {
let mut array = tensor.array.clone().into_owned();
Author: I believe the underlying array struct of the tensor needs to be cloned, since NdArray's method for accumulating elements along an axis (accumulate_axis_inplace) modifies the array's data in place.

Member: Well, float_cumsum takes ownership of the tensor, so I don't think the clone is required here.


array.accumulate_axis_inplace(Axis(dim), |&prev, curr| *curr += prev);

NdArrayTensor::new(array.to_shared())
}
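
Illustrating the reviewer's point above, here is a minimal sketch of the clone-free variant. It assumes, as in burn-ndarray, that the tensor's array field is an ndarray ArcArray, so into_owned() copies the buffer only when it is actually shared:

fn float_cumsum_dim<const D: usize>(
    tensor: NdArrayTensor<E, D>,
    dim: usize,
) -> NdArrayTensor<E, D> {
    // `tensor` is consumed by value, so no explicit `.clone()` is needed;
    // `into_owned()` copies only if other handles still share the buffer.
    let mut array = tensor.array.into_owned();
    array.accumulate_axis_inplace(Axis(dim), |&prev, curr| *curr += prev);
    NdArrayTensor::new(array.into_shared())
}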

fn float_argmax<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
4 changes: 4 additions & 0 deletions crates/burn-tch/src/ops/base.rs
@@ -335,6 +335,10 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
)
}

pub fn cumsum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchTensor::from_existing(tensor.tensor.cumsum(dim as i64, E::KIND), tensor.storage)
}

pub fn prod<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
let tensor = tensor.tensor.prod(E::KIND);
TchTensor::new(tensor)
4 changes: 4 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
@@ -270,6 +270,10 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
TchOps::sum_dim(tensor, dim)
}

fn int_cumsum_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::cumsum_dim(tensor, dim)
}

fn int_prod<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
TchOps::prod(tensor)
}
4 changes: 4 additions & 0 deletions crates/burn-tch/src/ops/tensor.rs
@@ -331,6 +331,10 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
TchOps::sum_dim(tensor, dim)
}

fn float_cumsum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::cumsum_dim(tensor, dim)
}

fn float_mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::mean_dim(tensor, dim)
}
16 changes: 16 additions & 0 deletions crates/burn-tensor/src/tensor/api/check.rs
@@ -803,6 +803,22 @@ impl TensorCheck {
check
}

/// Checks that the dimension is valid for a running (accumulating) operation such as cumulative sum
pub(crate) fn running_dim<const D: usize>(ops: &str, dim: usize) -> Self {
let mut check = Self::Ok;

if dim >= D {
check = check.register(
ops,
TensorError::new(format!(
"Can't perform a running calculation on a tensor with ({D}) dimensions on axis ({dim})"
)),
);
}

check
}

Member, on lines +806 to +821: You could use the existing TensorCheck::dim_ops instead

pub(crate) fn sort_dim<const D: usize>(ops: &str, dim: usize) -> Self {
let mut check = Self::Ok;

36 changes: 36 additions & 0 deletions crates/burn-tensor/src/tensor/api/numeric.rs
@@ -142,6 +142,13 @@ where
Self::new(K::sum_dim(self.primitive, dim))
}

/// Performs a cumulative sum of the elements along the given *dimension* or *axis*
/// in the tensor.
pub fn cumsum_dim(self, dim: usize) -> Self {
check!(TensorCheck::running_dim::<D>("Cumsum", dim));
Self::new(K::cumsum_dim(self.primitive, dim))
}
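
A quick usage sketch (the input values, the backend parameter B, and the device variable are illustrative):

// Cumulative sum along dim 1:
// [[1.0, 2.0, 3.0],      [[1.0, 3.0,  6.0],
//  [4.0, 5.0, 6.0]]  ->   [4.0, 9.0, 15.0]]
let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let cumsum = tensor.cumsum_dim(1);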

/// Aggregate all elements along the given *dimension* or *axis*
/// in the tensor with the product operation.
pub fn prod(self) -> Tensor<B, 1, K> {
@@ -1173,6 +1180,27 @@ where
/// which is more high-level and designed for public use.
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;

/// Performs cumulative sum across all the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to perform cumulative sum on.
/// * `dim` - The dimension along which to perform cumulative sum.
///
/// # Returns
///
/// The cumulative sum of all the elements of the tensor along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For performing cumulative sum across all the elements of a tensor along a dimension, users should prefer the [Tensor::cumsum_dim](Tensor::cumsum_dim) function,
/// which is more high-level and designed for public use.
fn cumsum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;

/// Computes the product of all the elements of the tensor.
///
/// # Arguments
@@ -2176,6 +2204,10 @@ impl<B: Backend> Numeric<B> for Int {
B::int_sum_dim(tensor, dim)
}

fn cumsum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::int_cumsum_dim(tensor, dim)
}

fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::int_prod(tensor)
}
@@ -2521,6 +2553,10 @@ impl<B: Backend> Numeric<B> for Float {
B::float_sum_dim(tensor, dim)
}

fn cumsum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::float_cumsum_dim(tensor, dim)
}

fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::float_prod(tensor)
}
12 changes: 12 additions & 0 deletions crates/burn-tensor/src/tensor/ops/int_tensor.rs
@@ -770,6 +770,18 @@ pub trait IntTensorOps<B: Backend> {
/// The sum of all elements in the tensor along the dimension.
fn int_sum_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;

/// Cumulative Sum of all elements in a tensor along a dimension.
Member: Let's keep the capitalization at "Cumulative sum"

///
/// # Arguments
///
/// * `tensor` - The tensor to perform cumulative sum on.
/// * `dim` - The dimension along which to perform cumulative sum.
///
/// # Returns
///
/// A tensor with the cumulative sum of all elements in `tensor` along `dim`.
fn int_cumsum_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;

/// Computes the product of all elements in the tensor.
///
/// # Arguments
13 changes: 13 additions & 0 deletions crates/burn-tensor/src/tensor/ops/tensor.rs
@@ -842,6 +842,19 @@ pub trait FloatTensorOps<B: Backend> {
/// A tensor with the sum of all elements in `tensor` along `dim`.
fn float_sum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;

/// Cumulative Sum of all elements in a tensor along a dimension.
Member: Same thing regarding capitalization

///
/// # Arguments
///
/// * `tensor` - The tensor to perform cumulative sum on.
/// * `dim` - The dimension along which to perform cumulative sum.
///
/// # Returns
///
/// A tensor with the cumulative sum of all elements in `tensor` along `dim`.
fn float_cumsum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize)
-> FloatTensor<B, D>;

/// Product of all elements in a tensor.
///
/// # Arguments
Expand Down