Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cumsum tensor op #2664

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
| `tensor.contains_nan()` | N/A |
| `tensor.cumsum(dim)` | `tensor.cumsum(dim)` |
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_sum(tensor)
}

fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cumsum(tensor, dim)
}

fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_sum_dim(tensor, dim)
}
Expand Down
32 changes: 32 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,38 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
#[derive(Debug)]
struct CumSum;

impl<B: Backend> Backward<B, 1> for CumSum {
type State = usize;

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let dim = ops.state;

unary::<B, _>(ops.parents, ops.node, grads, |grad| {
let cumsum = B::float_cumsum(grad.clone(), dim);
B::float_flip(cumsum.clone(), &[dim])
});
}
}

match CumSum
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(dim, B::float_cumsum(tensor.primitive, dim)),
OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)),
}
}

fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
#[derive(Debug)]
struct MeanDim;
Expand Down
22 changes: 22 additions & 0 deletions crates/burn-autodiff/src/tests/cumsum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#[burn_tensor_testgen::testgen(ad_cumsum)]
mod tests {
use super::*;
use burn_tensor::{loss, Tensor, TensorData};

#[test]
fn should_diff_cumsum() {
let device = Default::default();
let tensor_0 =
TestAutodiffTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device)
.require_grad();

let dim = 1;
let tensor_1 = tensor_0.clone().cumsum(dim);

let grads = tensor_1.backward();

let grad_0 = tensor_0.grad(&grads).unwrap();
let grad_0_expected = TensorData::from([[3., 2., 1.], [3., 2., 1.]]);
grad_0.into_data().assert_approx_eq(&grad_0_expected, 2);
}
}
2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod conv_transpose2d;
mod conv_transpose3d;
mod cos;
mod cross_entropy;
mod cumsum;
mod deform_conv2d;
mod div;
mod erf;
Expand Down Expand Up @@ -188,5 +189,6 @@ macro_rules! testgen_with_float_param {
burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_sort!();
burn_autodiff::testgen_ad_repeat_dim!();
burn_autodiff::testgen_ad_cumsum!();
};
}
2 changes: 2 additions & 0 deletions crates/burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ mod tests {
burn_tensor::testgen_round!();
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
burn_tensor::testgen_cumsum!();

// TODO: https://github.com/tracel-ai/burn/issues/1237
//
Expand Down Expand Up @@ -175,4 +176,5 @@ mod tests {
burn_autodiff::testgen_ad_round!();
burn_autodiff::testgen_ad_floor!();
burn_autodiff::testgen_ad_ceil!();
burn_autodiff::testgen_ad_cumsum!();
}
28 changes: 28 additions & 0 deletions crates/burn-candle/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,31 @@ pub fn mask_where_broadcasted(

CandleTensor::new(mask.tensor.where_cond(&value.tensor, &tensor).unwrap())
}

// Taken from: https://github.com/mokeyish/candle-ext/blob/main/src/cumsum.rs
fn cumsum_ext<D: candle_core::shape::Dim>(
input: &candle_core::Tensor,
dim: D,
) -> candle_core::Result<candle_core::Tensor> {
let dim = dim.to_index(input.shape(), "cumsum")?;
let dim_size = input.dim(dim)?;

let mut tensors = Vec::with_capacity(dim_size);

let mut a = input.clone();
for i in 0..dim_size {
if i > 0 {
a = a.narrow(dim, 1, dim_size - i)?;
let b = input.narrow(dim, 0, dim_size - i)?;
a = (a + b)?;
}
tensors.push(a.narrow(dim, 0, 1)?);
}
let cumsum = candle_core::Tensor::cat(&tensors, dim)?;
Ok(cumsum)
}

/// Cumulative sum (used for int tensors since the default candle implementation uses matmul).
pub fn cumsum(tensor: CandleTensor, dim: usize) -> CandleTensor {
CandleTensor::new(cumsum_ext(&tensor.tensor, dim).unwrap())
}
4 changes: 4 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
sign(tensor)
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
super::base::cumsum(tensor, dim)
}
}
4 changes: 4 additions & 0 deletions crates/burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,4 +481,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
}
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}
}
27 changes: 27 additions & 0 deletions crates/burn-fusion/src/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2263,4 +2263,31 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {

out
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
scalar_float_ops!(CumsumOps, B::float_cumsum, usize, noconvert);

let stream = tensor.stream;
let dtype = tensor.dtype;
let shape = tensor.shape.clone();
let out = tensor
.client
.tensor_uninitialized(shape, B::FloatElem::dtype());

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(
dtype,
NumericOperationDescription::CumSum(desc.clone()),
),
CumsumOps::<B>::new(desc),
);

out
}
}
27 changes: 27 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1819,4 +1819,31 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

out
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
scalar_int_ops!(CumsumOps, B::int_cumsum, usize, noconvert);

let stream = tensor.stream;
let dtype = tensor.dtype;
let shape = tensor.shape.clone();
let out = tensor
.client
.tensor_uninitialized(shape, B::FloatElem::dtype());

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(
dtype,
NumericOperationDescription::CumSum(desc.clone()),
),
CumsumOps::<B>::new(desc),
);

out
}
}
7 changes: 7 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,13 @@ impl<E: Element> RelativeOpsScalar<E> for NumericOperationDescription<E> {
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::CumSum(desc) => {
NumericOperationDescription::CumSum(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs,
out: desc.out.to_relative(converter),
})
}
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-jit/src/ops/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,4 +665,8 @@ where
_ => unimplemented!("Unsupported floating point type cast"),
}
}

fn float_cumsum(_tensor: FloatTensor<Self>, _dim: usize) -> FloatTensor<Self> {
todo!()
}
}
4 changes: 4 additions & 0 deletions crates/burn-jit/src/ops/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,4 +283,8 @@ where
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
kernel::flip::<R, I, BT>(tensor, axes)
}

fn int_cumsum(_tensor: IntTensor<Self>, _dim: usize) -> IntTensor<Self> {
todo!()
}
}
10 changes: 10 additions & 0 deletions crates/burn-ndarray/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ where
NdArrayTensor::from_data(data)
}

pub fn cumsum(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
let mut array = tensor.array.into_owned();
array.accumulate_axis_inplace(Axis(dim), |&prev, curr| {
*curr += prev;
});
let array = array.into_shared();

NdArrayTensor { array }
}

pub fn mean_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
let ndims = tensor.shape().num_dims();
match ndims {
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-ndarray/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,8 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps
fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
NdArrayOps::expand(tensor, shape)
}

fn int_cumsum(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::cumsum(tensor, dim)
}
}
4 changes: 4 additions & 0 deletions crates/burn-ndarray/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,4 +575,8 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorO
_ => panic!("Invalid cast types"),
}
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
}
}
19 changes: 19 additions & 0 deletions crates/burn-router/src/ops/op_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1491,4 +1491,23 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {

out
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};

client.register(OperationDescription::NumericFloat(
dtype,
NumericOperationDescription::CumSum(desc),
));

out
}
}
19 changes: 19 additions & 0 deletions crates/burn-router/src/ops/op_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1173,4 +1173,23 @@ impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {

out
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};

client.register(OperationDescription::NumericInt(
dtype,
NumericOperationDescription::CumSum(desc),
));

out
}
}
6 changes: 6 additions & 0 deletions crates/burn-router/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,9 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
NumericOperationDescription::Powf(desc) => {
binary_float_ops!(handles, desc, B::float_powf)
}
NumericOperationDescription::CumSum(desc) => {
scalar_float_dim_ops!(handles, desc, B::float_cumsum)
}
},
OperationDescription::NumericInt(_dtype, op) => match op {
NumericOperationDescription::Add(desc) => {
Expand Down Expand Up @@ -764,6 +767,9 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
let output = B::int_powf(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
}
NumericOperationDescription::CumSum(desc) => {
scalar_int_dim_ops!(handles, desc, B::int_cumsum)
}
},
OperationDescription::Bool(op) => match op {
BoolOperationDescription::IntoFloat(desc) => {
Expand Down
7 changes: 7 additions & 0 deletions crates/burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,13 @@ impl TchOps {
TchTensor::new(tensor)
}

pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()),
tensor.storage,
)
}

pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,4 +416,8 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
TchOps::argsort(tensor, dim, descending)
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
TchOps::cumsum(tensor, dim)
}
}
4 changes: 4 additions & 0 deletions crates/burn-tch/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,4 +479,8 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
TchTensor::new(tensor.tensor.to_kind(kind))
}
}

fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cumsum(tensor, dim)
}
}
Loading
Loading