Skip to content

Commit

Permalink
patch: address-requested-changes
Browse files Browse the repository at this point in the history
  • Loading branch information
quinton11 committed Dec 28, 2024
1 parent 7a2108c commit 78fb103
Show file tree
Hide file tree
Showing 22 changed files with 391 additions and 623 deletions.
14 changes: 7 additions & 7 deletions crates/burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,30 +373,30 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
sign(tensor)
}
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
todo!("bitwise_and is not implemented for Candle IntTensor");
unimplemented!("bitwise_and is not implemented for Candle IntTensor");
}

fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
todo!("bitwise_and_scalar is not implemented for Candle IntTensor");
unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor");
}

fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
todo!("bitwise_or is not implemented for Candle IntTensor");
unimplemented!("bitwise_or is not implemented for Candle IntTensor");
}

fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
todo!("bitwise_or_scalar is not implemented for Candle IntTensor");
unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor");
}

fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
todo!("bitwise_xor is not implemented for Candle IntTensor");
unimplemented!("bitwise_xor is not implemented for Candle IntTensor");
}

fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
todo!("bitwise_xor_scalar is not implemented for Candle IntTensor");
unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor");
}

fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
todo!("bitwise_not is not implemented for Candle IntTensor");
unimplemented!("bitwise_not is not implemented for Candle IntTensor");
}
}
33 changes: 18 additions & 15 deletions crates/burn-fusion/src/ops/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Int(repr::IntOperationDescription::IntoFloat(desc.clone())),
OperationDescription::Int(
IntElem::<Self>::dtype(),
repr::IntOperationDescription::IntoFloat(desc.clone()),
),
IntoFloatOps::<B>::new(desc),
);

Expand Down Expand Up @@ -1837,9 +1840,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::NumericInt(
repr::OperationDescription::Int(
IntElem::<Self>::dtype(),
NumericOperationDescription::BitwiseAnd(desc.clone()),
IntOperationDescription::BitwiseAnd(desc.clone()),
),
BitwiseAndOps::<B>::new(desc),
);
Expand All @@ -1862,9 +1865,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(
repr::OperationDescription::Int(
IntElem::<Self>::dtype(),
NumericOperationDescription::BitwiseAndScalar(desc.clone()),
IntOperationDescription::BitwiseAndScalar(desc.clone()),
),
BitwiseAndOps::<B>::new(desc),
);
Expand All @@ -1889,9 +1892,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::NumericInt(
repr::OperationDescription::Int(
IntElem::<Self>::dtype(),
NumericOperationDescription::BitwiseOr(desc.clone()),
IntOperationDescription::BitwiseOr(desc.clone()),
),
BitwiseOrOps::<B>::new(desc),
);
Expand All @@ -1914,9 +1917,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(
repr::OperationDescription::Int(
IntElem::<Self>::dtype(),
NumericOperationDescription::BitwiseOrScalar(desc.clone()),
IntOperationDescription::BitwiseOrScalar(desc.clone()),
),
BitwiseOrOps::<B>::new(desc),
);
Expand All @@ -1941,9 +1944,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::NumericInt(
repr::OperationDescription::Int(
IntElem::<Self>::dtype(),
NumericOperationDescription::BitwiseXor(desc.clone()),
IntOperationDescription::BitwiseXor(desc.clone()),
),
BitwiseXorOps::<B>::new(desc),
);
Expand All @@ -1966,9 +1969,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(
repr::OperationDescription::Int(
IntElem::<Self>::dtype(),
NumericOperationDescription::BitwiseXorScalar(desc.clone()),
IntOperationDescription::BitwiseXorScalar(desc.clone()),
),
BitwiseXorOps::<B>::new(desc),
);
Expand All @@ -1990,9 +1993,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(
repr::OperationDescription::Int(
IntElem::<Self>::dtype(),
NumericOperationDescription::BitwiseNot(desc.clone()),
IntOperationDescription::BitwiseNot(desc.clone()),
),
BitwiseNotOps::<B>::new(desc),
);
Expand Down
108 changes: 57 additions & 51 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ impl RelativeOps for OperationDescription {
OperationDescription::Bool(ops) => {
OperationDescription::Bool(ops.to_relative(converter))
}
OperationDescription::Int(ops) => OperationDescription::Int(ops.to_relative(converter)),
OperationDescription::Int(dtype, ops) => OperationDescription::Int(
*dtype,
ops.to_relative(converter, |converter, e| converter.relative_int(e, dtype)),
),
OperationDescription::Float(dtype, ops) => OperationDescription::Float(
*dtype,
RelativeOpsScalar::<f32>::to_relative(ops, converter, |converter, e| {
Expand Down Expand Up @@ -607,15 +610,66 @@ impl RelativeOps for BoolOperationDescription {
}
}

impl RelativeOps for IntOperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl<E: Element> RelativeOpsScalar<E> for IntOperationDescription<E> {
fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
where
F: Fn(&mut OperationConverter, &E) -> E,
{
match self {
IntOperationDescription::IntoFloat(desc) => {
IntOperationDescription::IntoFloat(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
IntOperationDescription::BitwiseAnd(desc) => {
IntOperationDescription::BitwiseAnd(BinaryOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
IntOperationDescription::BitwiseAndScalar(desc) => {
IntOperationDescription::BitwiseAndScalar(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
IntOperationDescription::BitwiseOr(desc) => {
IntOperationDescription::BitwiseOr(BinaryOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
IntOperationDescription::BitwiseOrScalar(desc) => {
IntOperationDescription::BitwiseOrScalar(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
IntOperationDescription::BitwiseXor(desc) => {
IntOperationDescription::BitwiseXor(BinaryOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
IntOperationDescription::BitwiseXorScalar(desc) => {
IntOperationDescription::BitwiseXorScalar(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
IntOperationDescription::BitwiseNot(desc) => {
IntOperationDescription::BitwiseNot(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
}
}
}
Expand Down Expand Up @@ -961,54 +1015,6 @@ impl<E: Element> RelativeOpsScalar<E> for NumericOperationDescription<E> {
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::BitwiseAnd(desc) => {
NumericOperationDescription::BitwiseAnd(BinaryOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::BitwiseAndScalar(desc) => {
NumericOperationDescription::BitwiseAndScalar(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::BitwiseOr(desc) => {
NumericOperationDescription::BitwiseOr(BinaryOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::BitwiseOrScalar(desc) => {
NumericOperationDescription::BitwiseOrScalar(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::BitwiseXor(desc) => {
NumericOperationDescription::BitwiseXor(BinaryOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::BitwiseXorScalar(desc) => {
NumericOperationDescription::BitwiseXorScalar(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::BitwiseNot(desc) => {
NumericOperationDescription::BitwiseNot(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl<N: Float> BinaryOp<N> for PowOp {
impl<N: Numeric> BinaryOp<N> for BitwiseAndOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
//lhs + rhs
lhs
lhs + rhs
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/burn-jit/src/ops/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,19 +287,19 @@ where
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
// not implemented
//todo!()
numeric::bitwise_and(lhs, rhs)
numeric::bitwise_and::<R, I>(lhs, rhs)
}

fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
// not implemented
//todo!()
numeric::bitwise_and_scalar(lhs, rhs)
numeric::bitwise_and_scalar::<R, I>(lhs, rhs)
}

fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
// not implemented
//todo!()
numeric::bitwise_or(lhs, rhs)
numeric::bitwise_or::<R, I>(lhs, rhs)
}

fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
Expand All @@ -311,7 +311,7 @@ where
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
// not implemented
//todo!()
numeric::bitwise_xor(lhs, rhs)
numeric::bitwise_xor::<R, I>(lhs, rhs)
}

fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
Expand All @@ -323,6 +323,6 @@ where
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
// not implemented
//todo!()
numeric::bitwise_not(tensor)
numeric::bitwise_not::<R, I>(tensor)
}
}
Loading

0 comments on commit 78fb103

Please sign in to comment.