diff --git a/cmd/xlabuilder_codegen/go_ops.go b/cmd/xlabuilder_codegen/go_ops.go index e78dba3..215c058 100644 --- a/cmd/xlabuilder_codegen/go_ops.go +++ b/cmd/xlabuilder_codegen/go_ops.go @@ -43,6 +43,9 @@ func {{.Name}}(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of {{.Name}}(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp({{.Name}}Op, x0, x1) err := builder.addOp(y) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 7d3174f..69d7533 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -5,6 +5,7 @@ * Use `github.com/dmarkham/enumer` instead of the usual `stringer` for dtypes. * Fixed double free of C.XlaOp pointers for Identity ops. * Added `DynamicSlice` and `DynamicSliceUpdate`. +* Added check for matching DTypes for the common ops taking 2 operands. # v0.2.0 GoMLX integration fixes -- GoMLX more extensive tests caught several small issues in Gopjrt. diff --git a/xlabuilder/gen_simple_ops.go b/xlabuilder/gen_simple_ops.go index 6f61777..21e4c56 100644 --- a/xlabuilder/gen_simple_ops.go +++ b/xlabuilder/gen_simple_ops.go @@ -265,6 +265,9 @@ func Add(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Add(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(AddOp, x0, x1) err := builder.addOp(y) @@ -281,6 +284,9 @@ func Mul(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Mul(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(MulOp, x0, x1) err := builder.addOp(y) @@ -297,6 +303,9 @@ func Sub(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Sub(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(SubOp, x0, x1) err := builder.addOp(y) @@ -313,6 +322,9 @@ func Div(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Div(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(DivOp, x0, x1) err := builder.addOp(y) @@ -329,6 +341,9 @@ func Rem(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Rem(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(RemOp, x0, x1) err := builder.addOp(y) @@ -344,6 +359,9 @@ func And(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of And(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(AndOp, x0, x1) err := builder.addOp(y) @@ -359,6 +377,9 @@ func Or(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Or(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(OrOp, x0, x1) err := builder.addOp(y) @@ -374,6 +395,9 @@ func Xor(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Xor(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(XorOp, x0, x1) err := builder.addOp(y) @@ -402,6 +426,9 @@ func Dot(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Dot(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(DotOp, x0, x1) err := builder.addOp(y) @@ -417,6 +444,9 @@ func Min(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Min(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(MinOp, x0, x1) err := builder.addOp(y) @@ -432,6 +462,9 @@ func Max(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Max(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(MaxOp, x0, x1) err := builder.addOp(y) @@ -447,6 +480,9 @@ func Pow(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Pow(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(PowOp, x0, x1) err := builder.addOp(y) @@ -467,6 +503,9 @@ func Complex(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Complex(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(ComplexOp, x0, x1) err := builder.addOp(y) @@ -482,6 +521,9 @@ func Equal(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of Equal(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(EqualOp, x0, x1) err := builder.addOp(y) @@ -497,6 +539,9 @@ func NotEqual(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of NotEqual(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(NotEqualOp, x0, x1) err := builder.addOp(y) @@ -512,6 +557,9 @@ func GreaterOrEqual(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of GreaterOrEqual(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(GreaterOrEqualOp, x0, x1) err := builder.addOp(y) @@ -527,6 +575,9 @@ func GreaterThan(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of GreaterThan(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(GreaterThanOp, x0, x1) err := builder.addOp(y) @@ -542,6 +593,9 @@ func LessOrEqual(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of LessOrEqual(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(LessOrEqualOp, x0, x1) err := builder.addOp(y) @@ -557,6 +611,9 @@ func LessThan(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of LessThan(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(LessThanOp, x0, x1) err := builder.addOp(y) @@ -576,6 +633,9 @@ func EqualTotalOrder(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of EqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(EqualTotalOrderOp, x0, x1) err := builder.addOp(y) @@ -595,6 +655,9 @@ func NotEqualTotalOrder(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of NotEqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(NotEqualTotalOrderOp, x0, x1) err := builder.addOp(y) @@ -614,6 +677,9 @@ func GreaterOrEqualTotalOrder(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of GreaterOrEqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(GreaterOrEqualTotalOrderOp, x0, x1) err := builder.addOp(y) @@ -633,6 +699,9 @@ func GreaterThanTotalOrder(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of GreaterThanTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(GreaterThanTotalOrderOp, x0, x1) err := builder.addOp(y) @@ -652,6 +721,9 @@ func LessOrEqualTotalOrder(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of LessOrEqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(LessOrEqualTotalOrderOp, x0, x1) err := builder.addOp(y) @@ -671,6 +743,9 @@ func LessThanTotalOrder(x0, x1 *Op) (*Op, error) { if x0.builder != x1.builder { return nil, errors.New("arguments of LessThanTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)") } + if x0.Shape.DType != x1.Shape.DType { + return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType) + } builder := x0.builder y := newOp(LessThanTotalOrderOp, x0, x1) err := builder.addOp(y) diff --git a/xlabuilder/special_ops.go b/xlabuilder/special_ops.go index 7e1ef07..63cc215 100644 --- a/xlabuilder/special_ops.go +++ b/xlabuilder/special_ops.go @@ -203,6 +203,9 @@ func Where(condition, onTrue, onFalse *Op) (*Op, error) { return nil, errors.New("trying to access XlaBuilder that is nil or already destroyed") } builder := condition.builder + if onTrue.Shape.DType != onFalse.Shape.DType { + return nil, errors.Errorf("dtype of onTrue (%s) and onFalse (%s) don't match", onTrue.Shape.DType, onFalse.Shape.DType) + } op := newOp(WhereOp, condition, onTrue, onFalse) err := builder.addOp(op) if err != nil { @@ -386,6 +389,13 @@ func Concatenate(axis int, operands ...*Op) (*Op, error) { // Trivial solution. return operands[0], nil } + dtype := operands[0].Shape.DType + for ii, op := range operands { + if op.Shape.DType != dtype { + return nil, errors.Errorf("Concatenate operand 0 has dtype %s, by operand %d has dtype %s: dtypes must match", + dtype, ii, op.Shape.DType) + } + } builder := operands[0].builder op := newOp(ConcatenateOp, operands...) op.IntArg = axis @@ -508,6 +518,9 @@ func Pad(x, fillValue *Op, axesConfig ...PadAxis) (*Op, error) { if rank == 0 { return nil, errors.New("cannot use Pad() with scalar values") } + if x.Shape.DType != fillValue.Shape.DType { + return nil, errors.Errorf("operand and fillValue dtypes (%s and %s) don't match for Pad()", x.Shape.DType, fillValue.Shape.DType) + } op := newOp(PadOp, x, fillValue) op.IntsArg = make([]int, 0, 3*rank) for axis := 0; axis < rank; axis++ { @@ -1331,6 +1344,9 @@ func DecodeDynamicSlice(op *Op) (operand *Op, startIndices []*Op, sliceDims []in // See description in https://openxla.org/xla/operation_semantics#dynamicupdateslice func DynamicUpdateSlice(operand, update *Op, startIndices []*Op) (*Op, error) { builder := operand.builder + if operand.Shape.DType != update.Shape.DType { + return nil, errors.Errorf("operand and update dtypes (%s and %s) don't match for DynamicUpdateSlice", operand.Shape.DType, update.Shape.DType) + } allOps := append([]*Op{operand, update}, startIndices...) op := newOp(DynamicUpdateSliceOp, allOps...) err := builder.addOp(op)