Fix output float dtype
laggui committed Jan 8, 2025
1 parent e2fa9c4 commit be761b7
Showing 1 changed file with 22 additions and 45 deletions.
crates/burn-fusion/src/ops/float.rs (67 changes: 22 additions & 45 deletions)
@@ -278,9 +278,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = lhs.stream;
         let dtype = lhs.dtype;
-        let out = lhs
-            .client
-            .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
 
         let desc = ScalarOperationDescription {
             lhs: lhs.into_description(),
@@ -323,7 +321,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         let dtype = tensor.dtype;
         let out = tensor
             .client
-            .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
+            .tensor_uninitialized(tensor.shape.clone(), dtype);
 
         let desc = ClampOperationDescription {
             tensor: tensor.into_description(),
@@ -375,9 +373,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = lhs.stream;
         let dtype = lhs.dtype;
-        let out = lhs
-            .client
-            .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
         let desc = ScalarOperationDescription {
             lhs: lhs.into_description(),
             rhs: rhs.elem(),
@@ -428,9 +424,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = lhs.stream;
         let dtype = lhs.dtype;
-        let out = lhs
-            .client
-            .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
 
         let desc = ScalarOperationDescription {
             lhs: lhs.into_description(),
@@ -481,9 +475,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = lhs.stream;
         let dtype = lhs.dtype;
-        let out = lhs
-            .client
-            .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
 
         let desc = ScalarOperationDescription {
             lhs: lhs.into_description(),
@@ -534,9 +526,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = lhs.stream;
         let dtype = lhs.dtype;
-        let out = lhs
-            .client
-            .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
 
         let desc = ScalarOperationDescription {
             lhs: lhs.into_description(),
@@ -567,9 +557,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         shape[ndims - 2] = lhs.shape[ndims - 2];
         shape[ndims - 1] = rhs.shape[ndims - 1];
 
-        let out = lhs
-            .client
-            .tensor_uninitialized(shape, B::FloatElem::dtype());
+        let out = lhs.client.tensor_uninitialized(shape, dtype);
         let desc = BinaryOperationDescription {
             lhs: lhs.into_description(),
             rhs: rhs.into_description(),
@@ -601,13 +589,12 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         }
 
         let stream = tensor.stream;
+        let dtype = tensor.dtype;
         let mut shape = tensor.shape.clone();
         shape[dim1] = tensor.shape[dim2];
         shape[dim2] = tensor.shape[dim1];
 
-        let mut out = tensor
-            .client
-            .tensor_uninitialized(shape, B::FloatElem::dtype());
+        let mut out = tensor.client.tensor_uninitialized(shape, dtype);
 
         let desc = SwapDimsDescription {
             input: tensor.into_description(),
@@ -641,9 +628,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         }
 
         let stream = tensor.stream;
-        let out = tensor
-            .client
-            .tensor_uninitialized(shape.dims, B::FloatElem::dtype());
+        let dtype = tensor.dtype;
+        let out = tensor.client.tensor_uninitialized(shape.dims, dtype);
 
         let desc = ReshapeDescription {
             input: tensor.into_description(),
@@ -1300,9 +1286,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = tensor.stream;
         let dtype = tensor.dtype;
-        let out = tensor
-            .client
-            .tensor_uninitialized(vec![1], B::FloatElem::dtype());
+        let out = tensor.client.tensor_uninitialized(vec![1], dtype);
 
         let desc = UnaryOperationDescription {
             input: tensor.into_description(),
@@ -1327,9 +1311,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         let dtype = tensor.dtype;
         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor
-            .client
-            .tensor_uninitialized(shape, B::FloatElem::dtype());
+        let out = tensor.client.tensor_uninitialized(shape, dtype);
 
         let desc = ScalarOperationDescription {
             lhs: tensor.into_description(),
@@ -1352,9 +1334,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         unary_float_ops!(ProdOps, B::float_prod, reduce);
 
         let stream = tensor.stream;
-        let out = tensor
-            .client
-            .tensor_uninitialized(vec![1], B::FloatElem::dtype());
+        let dtype = tensor.dtype;
+        let out = tensor.client.tensor_uninitialized(vec![1], dtype);
 
         let desc = UnaryOperationDescription {
             input: tensor.into_description(),
@@ -1363,7 +1344,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         out.client.register(
             vec![stream],
             OperationDescription::NumericFloat(
-                FloatElem::<Self>::dtype(),
+                dtype,
                 NumericOperationDescription::Prod(desc.clone()),
             ),
             ProdOps::<B>::new(desc),
@@ -1376,11 +1357,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         scalar_float_ops!(ProdDimOps, B::float_prod_dim, usize, noconvert);
 
         let stream = tensor.stream;
+        let dtype = tensor.dtype;
         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor
-            .client
-            .tensor_uninitialized(shape, B::FloatElem::dtype());
+        let out = tensor.client.tensor_uninitialized(shape, dtype);
 
         let desc = ScalarOperationDescription {
             lhs: tensor.into_description(),
@@ -1404,9 +1384,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = tensor.stream;
         let dtype = tensor.dtype;
-        let out = tensor
-            .client
-            .tensor_uninitialized(vec![1], B::FloatElem::dtype());
+        let out = tensor.client.tensor_uninitialized(vec![1], dtype);
 
         let desc = UnaryOperationDescription {
             input: tensor.into_description(),
@@ -1431,9 +1409,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         let dtype = tensor.dtype;
         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor
-            .client
-            .tensor_uninitialized(shape, B::FloatElem::dtype());
+        let out = tensor.client.tensor_uninitialized(shape, dtype);
 
         let desc = ScalarOperationDescription {
             lhs: tensor.into_description(),
@@ -1716,6 +1692,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         }
 
         let tensor_first = tensors.first().unwrap();
+        let dtype = tensor_first.dtype;
         let client = tensor_first.client.clone();
 
         // Calculate the output shape
@@ -1726,7 +1703,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
             shape[dim] += tensor.shape[dim];
         }
 
-        let out = client.tensor_uninitialized(shape, B::FloatElem::dtype());
+        let out = client.tensor_uninitialized(shape, dtype);
 
         let desc = CatOperationDescription {
             tensors: tensors.into_iter().map(|t| t.into_description()).collect(),
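Every hunk applies the same fix: the uninitialized output tensor was previously allocated with the backend's compile-time float element type, B::FloatElem::dtype(), and is now allocated with the runtime dtype taken from the input tensor (for float_cat, from the first tensor in the list). This way an operation on, say, an f16 fusion tensor no longer describes its output with the backend's default float type. Below is a minimal standalone sketch of that pattern; DType, Client, FusionTensor, BACKEND_FLOAT, and the two add_scalar_* functions are simplified stand-ins for illustration, not the real burn-fusion API.

// Minimal sketch of the dtype-propagation pattern in this commit.
// The types here are hypothetical stand-ins; only the before/after
// allocation pattern mirrors the diff above.

#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F16,
    F32,
}

#[derive(Clone)]
struct Client;

impl Client {
    // Allocates an uninitialized output handle tagged with an explicit dtype,
    // mirroring `tensor_uninitialized(shape, dtype)` in the diff.
    fn tensor_uninitialized(&self, shape: Vec<usize>, dtype: DType) -> FusionTensor {
        FusionTensor {
            shape,
            dtype,
            client: self.clone(),
        }
    }
}

struct FusionTensor {
    shape: Vec<usize>,
    dtype: DType,
    client: Client,
}

// Stand-in for the backend's static float element type, `B::FloatElem::dtype()`.
const BACKEND_FLOAT: DType = DType::F32;

// Before: the output is allocated with the backend's static float dtype.
fn add_scalar_before(lhs: &FusionTensor) -> FusionTensor {
    lhs.client.tensor_uninitialized(lhs.shape.clone(), BACKEND_FLOAT)
}

// After: the output is allocated with the input tensor's runtime dtype.
fn add_scalar_after(lhs: &FusionTensor) -> FusionTensor {
    let dtype = lhs.dtype;
    lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype)
}

fn main() {
    let lhs = FusionTensor {
        shape: vec![2, 3],
        dtype: DType::F16,
        client: Client,
    };
    // Old pattern: an f16 input yields an output tagged f32 ...
    assert_eq!(add_scalar_before(&lhs).dtype, DType::F32);
    // ... fixed pattern: the output keeps the input's dtype.
    assert_eq!(add_scalar_after(&lhs).dtype, DType::F16);
    println!("output dtype is propagated from the input tensor");
}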