Skip to content

Commit

Permalink
Merge pull request #1 from gomlx/while
Browse files Browse the repository at this point in the history
Added While operation.
  • Loading branch information
janpfeifer authored Jul 4, 2024
2 parents 597565e + b6979e8 commit b4c6882
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 48 deletions.
1 change: 1 addition & 0 deletions c/gomlx/xlabuilder/gen_op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum OpType {
BatchNormInferenceOp,
BatchNormGradOp,
RngBitGeneratorOp,
WhileOp,
AbsOp,
NegOp,
ExpOp,
Expand Down
7 changes: 7 additions & 0 deletions c/gomlx/xlabuilder/xlabuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ XlaStatus *XlaBuilderAddOp(XlaBuilder *builder, SerializedOp *serialized_op) {
op = xla::Fft(*inputs[0], fft_type, list_of_ints);
break;
}
case WhileOp: {
// Create select and scatter comps.
const xla::XlaComputation &condition_comp = *serialized_op->computation;
const xla::XlaComputation &body_comp = *serialized_op->second_computation;
op = xla::While(condition_comp, body_comp, *inputs[0]);
break;
}

// One-argument ops:
case AbsOp:
Expand Down
13 changes: 13 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@


# Next

* Added `While` op.
* Improved Mandelbrot example.

# v0.0.1 Initial Release

* `xlabuilder` with good coverage: all ops used by [GoMLX](github.com/gomlx/gomlx).
* `pjrt` with enough functionality coverage for [GoMLX](github.com/gomlx/gomlx) and to execute some Jax functions.
* Documentation for API, examples, one notebook (Mandelbrot) and installation details for CUDA.
* Prebuilt cpu pjrt plugin and C/C++ XlaBuilder libraries for `linux/x86-64`.
1 change: 1 addition & 0 deletions xlabuilder/gen_op_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const (
BatchNormInferenceOp
BatchNormGradOp
RngBitGeneratorOp
WhileOp
AbsOp
NegOp
ExpOp
Expand Down
1 change: 1 addition & 0 deletions xlabuilder/op_types.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ BatchNormTraining:special
BatchNormInference:special
BatchNormGrad:special
RngBitGenerator:special
While:special

// One-argument ops:
Abs:one
Expand Down
97 changes: 49 additions & 48 deletions xlabuilder/optype_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions xlabuilder/shape.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ func (s Shape) Size() int {
return size
}

// Clone makes a deep copy (including dimensions and tuples) of the given shape.
func (s Shape) Clone() (newS Shape) {
newS.DType = s.DType
if len(s.Dimensions) > 0 {
newS.Dimensions = slices.Clone(s.Dimensions)
}
if len(s.TupleShapes) > 0 {
newS.TupleShapes = make([]Shape, len(s.TupleShapes))
for ii, subS := range s.TupleShapes {
newS.TupleShapes[ii] = subS.Clone()
}
}
return newS
}

// TupleSize is an alias to len(Shape.TupleShapes).
func (s Shape) TupleSize() int {
return len(s.TupleShapes)
Expand Down
31 changes: 31 additions & 0 deletions xlabuilder/special_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,3 +1269,34 @@ func DecodeRngBitGenerator(op *Op) (state *Op, shape Shape) {
shape = op.ShapeArg
return
}

// While executes a loop in the computation.
//
// It takes as input:
//
// - initialState: usually a tuple, that includes all variables used by condition and body.
// - condition: a sub-computation (see XlaBuilder.CreateSubBuilder) takes the current state as input and outputs
// a bool (dtypes.PRED) whether the loop should keep iterating.
// - body: a sub-computation (see XlaBuilder.CreateSubBuilder) takes the current state as input and outputs
// an updated state.
//
// See details in https://openxla.org/xla/operation_semantics#while
func While(initialState *Op, condition, body *XlaComputation) (*Op, error) {
builder := initialState.builder
op := newOp(WhileOp, initialState)
op.ComputationArg = condition
op.SecondComputationArg = body
err := builder.addOp(op)
if err != nil {
return nil, err
}
return op, nil
}

// DecodeWhile retrieves the arguments for the While op.
func DecodeWhile(op *Op) (initialState *Op, condition, body *XlaComputation) {
initialState = op.OpInputs[0]
condition = op.ComputationArg
body = op.SecondComputationArg
return
}
54 changes: 54 additions & 0 deletions xlabuilder/special_ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,57 @@ func TestReverse(t *testing.T) {
builder.Destroy()
}
}

func TestWhile(t *testing.T) {
client := getPJRTClient(t)

builder := New(t.Name())
dtype := dtypes.Int64
x := capture(Parameter(builder, "x", 0, MakeShape(dtype))).Test(t)
value := capture(ScalarOne(builder, dtype)).Test(t)

// Calcualte factorial using a While loop.
// While loop:
// - initialState: (x, 1)
initialState := capture(Tuple(x, value)).Test(t)
var cond, body *XlaComputation
// - condition: x > 0
{
condBuilder := builder.CreateSubBuilder(t.Name() + "_condition")
tuple := capture(Parameter(condBuilder, "tuple", 0, initialState.Shape.Clone())).Test(t)
loopX := capture(GetTupleElement(tuple, 0)).Test(t)
zero := capture(ScalarZero(condBuilder, dtype)).Test(t)
output := capture(GreaterThan(loopX, zero)).Test(t)
cond = capture(condBuilder.Build(output)).Test(t)
}
// - body: value = value * x; x = x-1;
{
bodyBuilder := builder.CreateSubBuilder(t.Name() + "_body")
tuple := capture(Parameter(bodyBuilder, "tuple", 0, initialState.Shape.Clone())).Test(t)
loopX := capture(GetTupleElement(tuple, 0)).Test(t)
loopValue := capture(GetTupleElement(tuple, 1)).Test(t)
loopValue = capture(Mul(loopValue, loopX)).Test(t)
one := capture(ScalarOne(bodyBuilder, dtype)).Test(t)
loopX = capture(Sub(loopX, one)).Test(t)
output := capture(Tuple(loopX, loopValue)).Test(t)
body = capture(bodyBuilder.Build(output)).Test(t)
}
state := capture(While(initialState, cond, body)).Test(t)

gotInitialState, gotCond, gotBody := DecodeWhile(state)
require.Same(t, initialState, gotInitialState)
require.Same(t, cond, gotCond)
require.Same(t, body, gotBody)

output := capture(GetTupleElement(state, 1)).Test(t)
exec := compile(t, client, capture(builder.Build(output)).Test(t))

// 5! = 120
got := int(execWithScalars(t, client, exec, int64(5)))
require.Equal(t, 120, got)

// 7! = 5040
got = int(execWithScalars(t, client, exec, int64(7)))
require.Equal(t, 5040, got)
builder.Destroy()
}

0 comments on commit b4c6882

Please sign in to comment.