Skip to content

Commit

Permalink
Added dtype matching checks everywhere.
Browse files Browse the repository at this point in the history
  • Loading branch information
janpfeifer committed Aug 11, 2024
1 parent 8969569 commit bd3f615
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cmd/xlabuilder_codegen/go_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ func {{.Name}}(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of {{.Name}}(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp({{.Name}}Op, x0, x1)
err := builder.addOp(y)
Expand Down
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Use `github.com/dmarkham/enumer` instead of the usual `stringer` for dtypes.
* Fixed double free of C.XlaOp pointers for Identity ops.
* Added `DynamicSlice` and `DynamicSliceUpdate`.
* Added check for matching DTypes for the common ops taking 2 operands.

# v0.2.0 GoMLX integration fixes -- GoMLX more extensive tests caught several small issues in Gopjrt.

Expand Down
75 changes: 75 additions & 0 deletions xlabuilder/gen_simple_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ func Add(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Add(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(AddOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -281,6 +284,9 @@ func Mul(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Mul(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(MulOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -297,6 +303,9 @@ func Sub(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Sub(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(SubOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -313,6 +322,9 @@ func Div(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Div(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(DivOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -329,6 +341,9 @@ func Rem(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Rem(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(RemOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -344,6 +359,9 @@ func And(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of And(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(AndOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -359,6 +377,9 @@ func Or(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Or(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(OrOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -374,6 +395,9 @@ func Xor(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Xor(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(XorOp, x0, x1)
err := builder.addOp(y)
Expand Down Expand Up @@ -402,6 +426,9 @@ func Dot(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Dot(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(DotOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -417,6 +444,9 @@ func Min(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Min(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(MinOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -432,6 +462,9 @@ func Max(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Max(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(MaxOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -447,6 +480,9 @@ func Pow(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Pow(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(PowOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -467,6 +503,9 @@ func Complex(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Complex(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(ComplexOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -482,6 +521,9 @@ func Equal(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of Equal(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(EqualOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -497,6 +539,9 @@ func NotEqual(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of NotEqual(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(NotEqualOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -512,6 +557,9 @@ func GreaterOrEqual(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of GreaterOrEqual(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(GreaterOrEqualOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -527,6 +575,9 @@ func GreaterThan(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of GreaterThan(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(GreaterThanOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -542,6 +593,9 @@ func LessOrEqual(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of LessOrEqual(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(LessOrEqualOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -557,6 +611,9 @@ func LessThan(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of LessThan(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(LessThanOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -576,6 +633,9 @@ func EqualTotalOrder(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of EqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(EqualTotalOrderOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -595,6 +655,9 @@ func NotEqualTotalOrder(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of NotEqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(NotEqualTotalOrderOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -614,6 +677,9 @@ func GreaterOrEqualTotalOrder(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of GreaterOrEqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(GreaterOrEqualTotalOrderOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -633,6 +699,9 @@ func GreaterThanTotalOrder(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of GreaterThanTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(GreaterThanTotalOrderOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -652,6 +721,9 @@ func LessOrEqualTotalOrder(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of LessOrEqualTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(LessOrEqualTotalOrderOp, x0, x1)
err := builder.addOp(y)
Expand All @@ -671,6 +743,9 @@ func LessThanTotalOrder(x0, x1 *Op) (*Op, error) {
if x0.builder != x1.builder {
return nil, errors.New("arguments of LessThanTotalOrder(x0, x1) come from different XlaBuilder objects (or nil)")
}
if x0.Shape.DType != x1.Shape.DType {
return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
}
builder := x0.builder
y := newOp(LessThanTotalOrderOp, x0, x1)
err := builder.addOp(y)
Expand Down
16 changes: 16 additions & 0 deletions xlabuilder/special_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ func Where(condition, onTrue, onFalse *Op) (*Op, error) {
return nil, errors.New("trying to access XlaBuilder that is nil or already destroyed")
}
builder := condition.builder
if onTrue.Shape.DType != onFalse.Shape.DType {
return nil, errors.Errorf("dtype of onTrue (%s) and onFalse (%s) don't match", onTrue.Shape.DType, onFalse.Shape.DType)
}
op := newOp(WhereOp, condition, onTrue, onFalse)
err := builder.addOp(op)
if err != nil {
Expand Down Expand Up @@ -386,6 +389,13 @@ func Concatenate(axis int, operands ...*Op) (*Op, error) {
// Trivial solution.
return operands[0], nil
}
dtype := operands[0].Shape.DType
for ii, op := range operands {
if op.Shape.DType != dtype {
return nil, errors.Errorf("Concatenate operand 0 has dtype %s, by operand %d has dtype %s: dtypes must match",
dtype, ii, op.Shape.DType)
}
}
builder := operands[0].builder
op := newOp(ConcatenateOp, operands...)
op.IntArg = axis
Expand Down Expand Up @@ -508,6 +518,9 @@ func Pad(x, fillValue *Op, axesConfig ...PadAxis) (*Op, error) {
if rank == 0 {
return nil, errors.New("cannot use Pad() with scalar values")
}
if x.Shape.DType != fillValue.Shape.DType {
return nil, errors.Errorf("operand and fillValue dtypes (%s and %s) don't match for Pad()", x.Shape.DType, fillValue.Shape.DType)
}
op := newOp(PadOp, x, fillValue)
op.IntsArg = make([]int, 0, 3*rank)
for axis := 0; axis < rank; axis++ {
Expand Down Expand Up @@ -1331,6 +1344,9 @@ func DecodeDynamicSlice(op *Op) (operand *Op, startIndices []*Op, sliceDims []in
// See description in https://openxla.org/xla/operation_semantics#dynamicupdateslice
func DynamicUpdateSlice(operand, update *Op, startIndices []*Op) (*Op, error) {
builder := operand.builder
if operand.Shape.DType != update.Shape.DType {
return nil, errors.Errorf("operand and update dtypes (%s and %s) don't match for DynamicUpdateSlice", operand.Shape.DType, update.Shape.DType)
}
allOps := append([]*Op{operand, update}, startIndices...)
op := newOp(DynamicUpdateSliceOp, allOps...)
err := builder.addOp(op)
Expand Down

0 comments on commit bd3f615

Please sign in to comment.