From be761b71cc7e5b2dee024af8ac641700967b7d92 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 8 Jan 2025 11:15:10 -0500 Subject: [PATCH] Fix output float dtype --- crates/burn-fusion/src/ops/float.rs | 67 ++++++++++------------------- 1 file changed, 22 insertions(+), 45 deletions(-) diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 57cfbd4132..b3e2a80432 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -278,9 +278,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -323,7 +321,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = ClampOperationDescription { tensor: tensor.into_description(), @@ -375,9 +373,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), rhs: rhs.elem(), @@ -428,9 +424,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -481,9 +475,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -534,9 +526,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -567,9 +557,7 @@ impl FloatTensorOps for Fusion { shape[ndims - 2] = lhs.shape[ndims - 2]; shape[ndims - 1] = rhs.shape[ndims - 1]; - let out = lhs - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(shape, dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), rhs: rhs.into_description(), @@ -601,13 +589,12 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let mut out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let mut out = tensor.client.tensor_uninitialized(shape, dtype); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -641,9 +628,8 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(shape.dims, B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(shape.dims, dtype); let desc = ReshapeDescription { input: tensor.into_description(), @@ -1300,9 +1286,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let dtype = tensor.dtype; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1327,9 +1311,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1352,9 +1334,8 @@ impl FloatTensorOps for Fusion { unary_float_ops!(ProdOps, B::float_prod, reduce); let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1363,7 +1344,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Prod(desc.clone()), ), ProdOps::::new(desc), @@ -1376,11 +1357,10 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(ProdDimOps, B::float_prod_dim, usize, noconvert); let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1404,9 +1384,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let dtype = tensor.dtype; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1431,9 +1409,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1716,6 +1692,7 @@ impl FloatTensorOps for Fusion { } let tensor_first = tensors.first().unwrap(); + let dtype = tensor_first.dtype; let client = tensor_first.client.clone(); // Calculate the output shape @@ -1726,7 +1703,7 @@ impl FloatTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = client.tensor_uninitialized(shape, dtype); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(),