Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added While operation. #1

Merged
merged 1 commit into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions c/gomlx/xlabuilder/gen_op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum OpType {
BatchNormInferenceOp,
BatchNormGradOp,
RngBitGeneratorOp,
WhileOp,
AbsOp,
NegOp,
ExpOp,
Expand Down
7 changes: 7 additions & 0 deletions c/gomlx/xlabuilder/xlabuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ XlaStatus *XlaBuilderAddOp(XlaBuilder *builder, SerializedOp *serialized_op) {
op = xla::Fft(*inputs[0], fft_type, list_of_ints);
break;
}
case WhileOp: {
// Create select and scatter comps.
const xla::XlaComputation &condition_comp = *serialized_op->computation;
const xla::XlaComputation &body_comp = *serialized_op->second_computation;
op = xla::While(condition_comp, body_comp, *inputs[0]);
break;
}

// One-argument ops:
case AbsOp:
Expand Down
13 changes: 13 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@


# Next

* Added `While` op.
* Improved Mandelbrot example.

# v0.0.1 Initial Release

* `xlabuilder` with good coverage: all ops used by [GoMLX](github.com/gomlx/gomlx).
* `pjrt` with enough functionality coverage for [GoMLX](github.com/gomlx/gomlx) and to execute some Jax functions.
* Documentation for API, examples, one notebook (Mandelbrot) and installation details for CUDA.
* Prebuilt cpu pjrt plugin and C/C++ XlaBuilder libraries for `linux/x86-64`.
1 change: 1 addition & 0 deletions xlabuilder/gen_op_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const (
BatchNormInferenceOp
BatchNormGradOp
RngBitGeneratorOp
WhileOp
AbsOp
NegOp
ExpOp
Expand Down
1 change: 1 addition & 0 deletions xlabuilder/op_types.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ BatchNormTraining:special
BatchNormInference:special
BatchNormGrad:special
RngBitGenerator:special
While:special

// One-argument ops:
Abs:one
Expand Down
97 changes: 49 additions & 48 deletions xlabuilder/optype_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions xlabuilder/shape.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ func (s Shape) Size() int {
return size
}

// Clone makes a deep copy (including dimensions and tuples) of the given shape.
func (s Shape) Clone() (newS Shape) {
newS.DType = s.DType
if len(s.Dimensions) > 0 {
newS.Dimensions = slices.Clone(s.Dimensions)
}
if len(s.TupleShapes) > 0 {
newS.TupleShapes = make([]Shape, len(s.TupleShapes))
for ii, subS := range s.TupleShapes {
newS.TupleShapes[ii] = subS.Clone()
}
}
return newS
}

// TupleSize is an alias to len(Shape.TupleShapes).
func (s Shape) TupleSize() int {
return len(s.TupleShapes)
Expand Down
31 changes: 31 additions & 0 deletions xlabuilder/special_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,3 +1269,34 @@ func DecodeRngBitGenerator(op *Op) (state *Op, shape Shape) {
shape = op.ShapeArg
return
}

// While executes a loop in the computation.
//
// It takes as input:
//
// - initialState: usually a tuple, that includes all variables used by condition and body.
// - condition: a sub-computation (see XlaBuilder.CreateSubBuilder) takes the current state as input and outputs
// a bool (dtypes.PRED) whether the loop should keep iterating.
// - body: a sub-computation (see XlaBuilder.CreateSubBuilder) takes the current state as input and outputs
// an updated state.
//
// See details in https://openxla.org/xla/operation_semantics#while
func While(initialState *Op, condition, body *XlaComputation) (*Op, error) {
builder := initialState.builder
op := newOp(WhileOp, initialState)
op.ComputationArg = condition
op.SecondComputationArg = body
err := builder.addOp(op)
if err != nil {
return nil, err
}
return op, nil
}

// DecodeWhile retrieves the arguments for the While op.
func DecodeWhile(op *Op) (initialState *Op, condition, body *XlaComputation) {
initialState = op.OpInputs[0]
condition = op.ComputationArg
body = op.SecondComputationArg
return
}
54 changes: 54 additions & 0 deletions xlabuilder/special_ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,57 @@ func TestReverse(t *testing.T) {
builder.Destroy()
}
}

func TestWhile(t *testing.T) {
client := getPJRTClient(t)

builder := New(t.Name())
dtype := dtypes.Int64
x := capture(Parameter(builder, "x", 0, MakeShape(dtype))).Test(t)
value := capture(ScalarOne(builder, dtype)).Test(t)

// Calcualte factorial using a While loop.
// While loop:
// - initialState: (x, 1)
initialState := capture(Tuple(x, value)).Test(t)
var cond, body *XlaComputation
// - condition: x > 0
{
condBuilder := builder.CreateSubBuilder(t.Name() + "_condition")
tuple := capture(Parameter(condBuilder, "tuple", 0, initialState.Shape.Clone())).Test(t)
loopX := capture(GetTupleElement(tuple, 0)).Test(t)
zero := capture(ScalarZero(condBuilder, dtype)).Test(t)
output := capture(GreaterThan(loopX, zero)).Test(t)
cond = capture(condBuilder.Build(output)).Test(t)
}
// - body: value = value * x; x = x-1;
{
bodyBuilder := builder.CreateSubBuilder(t.Name() + "_body")
tuple := capture(Parameter(bodyBuilder, "tuple", 0, initialState.Shape.Clone())).Test(t)
loopX := capture(GetTupleElement(tuple, 0)).Test(t)
loopValue := capture(GetTupleElement(tuple, 1)).Test(t)
loopValue = capture(Mul(loopValue, loopX)).Test(t)
one := capture(ScalarOne(bodyBuilder, dtype)).Test(t)
loopX = capture(Sub(loopX, one)).Test(t)
output := capture(Tuple(loopX, loopValue)).Test(t)
body = capture(bodyBuilder.Build(output)).Test(t)
}
state := capture(While(initialState, cond, body)).Test(t)

gotInitialState, gotCond, gotBody := DecodeWhile(state)
require.Same(t, initialState, gotInitialState)
require.Same(t, cond, gotCond)
require.Same(t, body, gotBody)

output := capture(GetTupleElement(state, 1)).Test(t)
exec := compile(t, client, capture(builder.Build(output)).Test(t))

// 5! = 120
got := int(execWithScalars(t, client, exec, int64(5)))
require.Equal(t, 120, got)

// 7! = 5040
got = int(execWithScalars(t, client, exec, int64(7)))
require.Equal(t, 5040, got)
builder.Destroy()
}
Loading