diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index fb429ffd0f..fc4833c3c4 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -196,6 +196,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` | | `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` | | `tensor.contains_nan()` | N/A | +| `tensor.cumsum(dim)` | `tensor.cumsum(dim)` | | `tensor.div(other)` or `tensor / other` | `tensor / other` | | `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` | | `tensor.equal_elem(other)` | `tensor.eq(other)` | diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 4aad98bb46..ffcc522051 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -127,6 +127,10 @@ impl IntTensorOps for Autodiff { B::int_sum(tensor) } + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_cumsum(tensor, dim) + } + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { B::int_sum_dim(tensor, dim) } diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index 12dc8cf90d..4cc8718bdb 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -1488,6 +1488,39 @@ impl FloatTensorOps for Autodiff } } + fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[derive(Debug)] + struct CumSum; + + impl Backward for CumSum { + type State = usize; + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + let dim = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + // Input position i receives every output gradient j >= i, i.e. the reversed (suffix) cumulative sum: flip -> cumsum -> flip. + let summed = B::float_cumsum(B::float_flip(grad, &[dim]), dim); + B::float_flip(summed, &[dim]) + }); + } + } + + match CumSum + .prepare::([tensor.node]) + .compute_bound() 
.stateful() + { + OpsKind::Tracked(prep) => prep.finish(dim, B::float_cumsum(tensor.primitive, dim)), + OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)), + } + } + fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(Debug)] struct MeanDim; diff --git a/crates/burn-autodiff/src/tests/cumsum.rs b/crates/burn-autodiff/src/tests/cumsum.rs new file mode 100644 index 0000000000..6b9c7749de --- /dev/null +++ b/crates/burn-autodiff/src/tests/cumsum.rs @@ -0,0 +1,25 @@ +#[burn_tensor_testgen::testgen(ad_cumsum)] +mod tests { + use super::*; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_diff_cumsum() { + let device = Default::default(); + let tensor_0 = + TestAutodiffTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device) + .require_grad(); + + let dim = 1; + // Weight the output so the upstream gradient is non-uniform; a reversed-order backward cannot pass by symmetry. + let weights = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); + let tensor_1 = tensor_0.clone().cumsum(dim) * weights; + + let grads = tensor_1.backward(); + + let grad_0 = tensor_0.grad(&grads).unwrap(); + // Expected gradient: suffix sums of the weights along dim 1. + let grad_0_expected = TensorData::from([[6., 5., 3.], [15., 11., 6.]]); + grad_0.into_data().assert_approx_eq(&grad_0_expected, 2); + } +} diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index 438adfb98e..ff5e6d6efc 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -22,6 +22,7 @@ mod conv_transpose2d; mod conv_transpose3d; mod cos; mod cross_entropy; +mod cumsum; mod deform_conv2d; mod div; mod erf; @@ -188,5 +189,6 @@ macro_rules! 
testgen_with_float_param { burn_autodiff::testgen_ad_expand!(); burn_autodiff::testgen_ad_sort!(); burn_autodiff::testgen_ad_repeat_dim!(); + burn_autodiff::testgen_ad_cumsum!(); }; } diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs index 78923fad03..53cc1dc985 100644 --- a/crates/burn-candle/src/lib.rs +++ b/crates/burn-candle/src/lib.rs @@ -95,6 +95,7 @@ mod tests { burn_tensor::testgen_round!(); burn_tensor::testgen_floor!(); burn_tensor::testgen_ceil!(); + burn_tensor::testgen_cumsum!(); // TODO: https://github.com/tracel-ai/burn/issues/1237 // @@ -175,4 +176,5 @@ mod tests { burn_autodiff::testgen_ad_round!(); burn_autodiff::testgen_ad_floor!(); burn_autodiff::testgen_ad_ceil!(); + burn_autodiff::testgen_ad_cumsum!(); } diff --git a/crates/burn-candle/src/ops/base.rs b/crates/burn-candle/src/ops/base.rs index bd817a1809..5c4c416b9f 100644 --- a/crates/burn-candle/src/ops/base.rs +++ b/crates/burn-candle/src/ops/base.rs @@ -145,3 +145,31 @@ pub fn mask_where_broadcasted( CandleTensor::new(mask.tensor.where_cond(&value.tensor, &tensor).unwrap()) } + +// Taken from: https://github.com/mokeyish/candle-ext/blob/main/src/cumsum.rs +fn cumsum_ext( + input: &candle_core::Tensor, + dim: D, +) -> candle_core::Result { + let dim = dim.to_index(input.shape(), "cumsum")?; + let dim_size = input.dim(dim)?; + + let mut tensors = Vec::with_capacity(dim_size); + + let mut a = input.clone(); + for i in 0..dim_size { + if i > 0 { + a = a.narrow(dim, 1, dim_size - i)?; + let b = input.narrow(dim, 0, dim_size - i)?; + a = (a + b)?; + } + tensors.push(a.narrow(dim, 0, 1)?); + } + let cumsum = candle_core::Tensor::cat(&tensors, dim)?; + Ok(cumsum) +} + +/// Cumulative sum (used for int tensors since the default candle implementation uses matmul). 
+pub fn cumsum(tensor: CandleTensor, dim: usize) -> CandleTensor { + CandleTensor::new(cumsum_ext(&tensor.tensor, dim).unwrap()) +} diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 4ae0c53de7..6e3586506b 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -372,4 +372,8 @@ impl IntTensorOps for Candle) -> IntTensor { sign(tensor) } + + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { + super::base::cumsum(tensor, dim) + } } diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index 1457a5a0e6..162b8513c3 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -481,4 +481,8 @@ impl FloatTensorOps for Candle CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap()) } } + + fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { + CandleTensor::new(tensor.tensor.cumsum(dim).unwrap()) + } } diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 57cfbd4132..41d691272f 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -2263,4 +2263,31 @@ impl FloatTensorOps for Fusion { out } + + fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(CumsumOps, B::float_cumsum, usize, noconvert); + + let stream = tensor.stream; + let dtype = tensor.dtype; + let shape = tensor.shape.clone(); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::CumSum(desc.clone()), + ), + CumsumOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 
bdb47df02c..d3a2492a03 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -1819,4 +1819,31 @@ impl IntTensorOps for Fusion { out } + + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(CumsumOps, B::int_cumsum, usize, noconvert); + + let stream = tensor.stream; + let dtype = tensor.dtype; + let shape = tensor.shape.clone(); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::NumericInt( + dtype, + NumericOperationDescription::CumSum(desc.clone()), + ), + CumsumOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index d5e1ee9e38..477cf737f2 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -961,6 +961,13 @@ impl RelativeOpsScalar for NumericOperationDescription { out: desc.out.to_relative(converter), }) } + NumericOperationDescription::CumSum(desc) => { + NumericOperationDescription::CumSum(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 2dc8a4a6f2..8cf10292d7 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -665,4 +665,8 @@ where _ => unimplemented!("Unsupported floating point type cast"), } } + + fn float_cumsum(_tensor: FloatTensor, _dim: usize) -> FloatTensor { + todo!() + } } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 25bb92521f..b0772f11cc 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -283,4 +283,8 @@ where fn 
int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { kernel::flip::(tensor, axes) } + + fn int_cumsum(_tensor: IntTensor, _dim: usize) -> IntTensor { + todo!() + } } diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index b364b7bd20..8a5ebea7b8 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -262,6 +262,16 @@ where NdArrayTensor::from_data(data) } + pub fn cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + let mut array = tensor.array.into_owned(); + array.accumulate_axis_inplace(Axis(dim), |&prev, curr| { + *curr += prev; + }); + let array = array.into_shared(); + + NdArrayTensor { array } + } + pub fn mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { let ndims = tensor.shape().num_dims(); match ndims { diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 9009b5c4a8..87d6e46a3f 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -351,4 +351,8 @@ impl IntTensorOps fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::expand(tensor, shape) } + + fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::cumsum(tensor, dim) + } } diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index 97085f6ce4..4b34959944 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -575,4 +575,8 @@ impl FloatTensorO _ => panic!("Invalid cast types"), } } + + fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { + execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim)) + } } diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index dda01990e0..e66343ec2a 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -1491,4 
+1491,23 @@ impl FloatTensorOps for BackendRouter { out } + + fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::CumSum(desc), + )); + + out + } } diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index db81602d4f..c0d3c13218 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -1173,4 +1173,23 @@ impl IntTensorOps for BackendRouter { out } + + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::CumSum(desc), + )); + + out + } } diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 04f93a4769..0c9238e5e7 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -573,6 +573,9 @@ impl RunnerClient for Runner { NumericOperationDescription::Powf(desc) => { binary_float_ops!(handles, desc, B::float_powf) } + NumericOperationDescription::CumSum(desc) => { + scalar_float_dim_ops!(handles, desc, B::float_cumsum) + } }, OperationDescription::NumericInt(_dtype, op) => match op { NumericOperationDescription::Add(desc) => { @@ -764,6 +767,9 @@ impl RunnerClient for Runner { let output = B::int_powf(lhs, rhs); handles.register_int_tensor::(&desc.out.id, output); } + NumericOperationDescription::CumSum(desc) 
=> { + scalar_int_dim_ops!(handles, desc, B::int_cumsum) + } }, OperationDescription::Bool(op) => match op { BoolOperationDescription::IntoFloat(desc) => { diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 7b04207871..e1a7bf7d09 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -299,6 +299,13 @@ impl TchOps { TchTensor::new(tensor) } + pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor { + TchTensor::from_existing( + tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()), + tensor.storage, + ) + } + pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchTensor::from_existing( tensor diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 0da31fe430..b2cd14f326 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -416,4 +416,8 @@ impl IntTensorOps for LibTorch { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { TchOps::argsort(tensor, dim, descending) } + + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { + TchOps::cumsum(tensor, dim) + } } diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index 22d63b9d64..fdca7f3116 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -479,4 +479,8 @@ impl FloatTensorOps for LibTorch { TchTensor::new(tensor.tensor.to_kind(kind)) } } + + fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::cumsum(tensor, dim) + } } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 001b9d6e83..75b46f2220 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -401,19 +401,21 @@ pub enum NumericOperationDescription { /// Float => [sum dim](crate::ops::FloatTensorOps::float_sum_dim). /// Int => [sum dim](crate::ops::IntTensorOps::int_sum_dim). 
SumDim(ScalarOperationDescription), - + /// Operation corresponding to: + /// + /// Float => [cumsum](crate::ops::FloatTensorOps::float_cumsum). + /// Int => [cumsum](crate::ops::IntTensorOps::int_cumsum). + CumSum(ScalarOperationDescription), /// Operation corresponding to: /// /// Float => [prod](crate::ops::FloatTensorOps::float_prod). /// Int => [prod](crate::ops::IntTensorOps::int_prod). Prod(UnaryOperationDescription), - /// Operation corresponding to: /// /// Float => [prod dim](crate::ops::FloatTensorOps::float_prod_dim). /// Int => [prod dim](crate::ops::IntTensorOps::int_prod_dim). ProdDim(ScalarOperationDescription), - /// Operation corresponding to: /// /// Float => [equal elem](crate::ops::FloatTensorOps::float_equal_elem). @@ -1503,6 +1505,7 @@ impl NumericOperationDescription { NumericOperationDescription::Powf(desc) => { vec![&desc.lhs, &desc.rhs, &desc.out] } + NumericOperationDescription::CumSum(desc) => vec![&desc.lhs, &desc.out], } } } @@ -1761,6 +1764,7 @@ impl core::hash::Hash for NumericOperationDescription { NumericOperationDescription::Clamp(desc) => desc.hash(state), NumericOperationDescription::IntRandom(desc) => desc.hash(state), NumericOperationDescription::Powf(desc) => desc.hash(state), + NumericOperationDescription::CumSum(desc) => desc.hash(state), } } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 59dc44b7e6..54d801c125 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -486,6 +486,13 @@ where Tensor::new(K::sum(self.primitive)) } + /// Aggregate all elements along the given *dimension* or *axis* with the + /// cumulative sum operation. + pub fn cumsum(self, dim: usize) -> Tensor { + check!(TensorCheck::aggregate_dim::("CumSum", dim)); + Tensor::new(K::cumsum(self.primitive, dim)) + } + /// Aggregate all elements along the given *dimension* or *axis* /// in the tensor with the mean operation. 
/// @@ -2464,6 +2471,27 @@ where /// which is more high-level and designed for public use. fn sum(tensor: Self::Primitive) -> Self::Primitive; + /// Computes the cumulative sum of all the elements of the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension along which to sum. + /// + /// # Returns + /// + /// The cumulative sum of all the elements of the tensor along the specified dimension. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::cumsum](Tensor::cumsum) function, + /// which is more high-level and designed for public use. + fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive; + /// Sums all the elements of the tensor along a dimension. 
/// /// # Arguments @@ -3611,6 +3639,10 @@ impl Numeric for Int { ) -> >::Primitive { B::int_argsort(tensor, dim, descending) } + + fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_cumsum(tensor, dim) + } } impl Numeric for Float { @@ -4125,6 +4157,13 @@ impl Numeric for Float { TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending), } } + + fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_cumsum(tensor, dim)), + } + } } impl core::ops::Add for Tensor diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index abdd2e54ba..89cbf74ebc 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -671,6 +671,18 @@ pub trait IntTensorOps { /// The sum of all elements in the tensor. fn int_sum(tensor: IntTensor) -> IntTensor; + /// Computes the cumulative sum of all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension to sum along. + /// + /// # Returns + /// + /// The cumulative sum of all elements in the tensor along the dimension. + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor; + /// Sums all elements in the tensor along a dimension. /// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 781ed7c6eb..0d951c9bb4 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -681,6 +681,24 @@ pub trait QTensorOps { ) } + /// Computes the cumulative sum of all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. 
+ /// * `dim` - The dimension to sum along. + /// + /// # Returns + /// + /// The cumulative sum of all elements in the tensor along the dimension. + fn q_cumsum(tensor: QuantizedTensor, dim: usize) -> QuantizedTensor { + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_cumsum(tensor, dim), + tensor + ) + } + /// Sum of all elements in a tensor along a dimension. /// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index a3e7500419..11d1015f6c 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -702,6 +702,18 @@ pub trait FloatTensorOps { /// A scalar tensor with the sum of all elements in `tensor`. fn float_sum(tensor: FloatTensor) -> FloatTensor; + /// Computes the cumulative sum of all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension to sum along. + /// + /// # Returns + /// + /// The cumulative sum of all elements in the tensor along the dimension. + fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor; + /// Sum of all elements in a tensor along a dimension. /// /// # Arguments diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 8aa41ee24d..2ce0a2cc1d 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -164,6 +164,7 @@ macro_rules! testgen_quantization { burn_tensor::testgen_q_tanh!(); burn_tensor::testgen_q_topk!(); burn_tensor::testgen_q_transpose!(); + burn_tensor::testgen_q_cumsum!(); }; } @@ -273,6 +274,7 @@ macro_rules! 
testgen_with_float_param { burn_tensor::testgen_select!(); burn_tensor::testgen_split!(); burn_tensor::testgen_prod!(); + burn_tensor::testgen_cumsum!(); // test stats burn_tensor::testgen_var!(); diff --git a/crates/burn-tensor/src/tests/ops/cumsum.rs b/crates/burn-tensor/src/tests/ops/cumsum.rs new file mode 100644 index 0000000000..5f771abede --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/cumsum.rs @@ -0,0 +1,39 @@ +#[burn_tensor_testgen::testgen(cumsum)] +mod tests { + use super::*; + use burn_tensor::{backend::Backend, Int, Tensor, TensorData}; + + #[test] + fn should_support_cumsum_ops() { + let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let device = Default::default(); + let tensor = TestTensor::<2>::from_data(data, &device); + + let output = tensor.clone().cumsum(0); + let expected = TensorData::from([[0.0, 1.0, 2.0], [3.0, 5.0, 7.0]]); + + output.into_data().assert_eq(&expected, false); + + let output = tensor.cumsum(1); + let expected = TensorData::from([[0.0, 1.0, 3.0], [3.0, 7.0, 12.0]]); + + output.into_data().assert_eq(&expected, false); + } + + #[test] + fn should_support_cumsum_ops_int() { + let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); + let device = Default::default(); + let tensor = TestTensorInt::<2>::from_data(data, &device); + + let output = tensor.clone().cumsum(0); + let expected = TensorData::from([[0, 1, 2], [3, 5, 7]]); + + output.into_data().assert_eq(&expected, false); + + let output = tensor.cumsum(1); + let expected = TensorData::from([[0, 1, 3], [3, 7, 12]]); + + output.into_data().assert_eq(&expected, false); + } +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index b1096e0216..e101dc7777 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -17,6 +17,7 @@ mod clamp; mod close; mod cos; mod create_like; +mod cumsum; mod div; mod erf; mod exp; diff --git 
a/crates/burn-tensor/src/tests/quantization/ops/cumsum.rs b/crates/burn-tensor/src/tests/quantization/ops/cumsum.rs new file mode 100644 index 0000000000..c95fa774cd --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/cumsum.rs @@ -0,0 +1,19 @@ +#[burn_tensor_testgen::testgen(q_cumsum)] +mod tests { + use super::*; + use burn_tensor::TensorData; + + #[test] + fn test_should_support_cumsum_ops() { + let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + + let output = tensor.cumsum(0); + let expected = TensorData::from([[0.0, 1.0, 2.0], [3.0, 5.0, 7.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/mod.rs b/crates/burn-tensor/src/tests/quantization/ops/mod.rs index 083a1b871b..39ddce0336 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/mod.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/mod.rs @@ -9,6 +9,7 @@ mod ceil; mod chunk; mod clamp; mod cos; +mod cumsum; mod div; mod erf; mod exp;