diff --git a/c/gomlx/xlabuilder/gen_op_types.h b/c/gomlx/xlabuilder/gen_op_types.h index d51dda4..563a6bb 100644 --- a/c/gomlx/xlabuilder/gen_op_types.h +++ b/c/gomlx/xlabuilder/gen_op_types.h @@ -40,6 +40,7 @@ enum OpType { BatchNormInferenceOp, BatchNormGradOp, RngBitGeneratorOp, + WhileOp, AbsOp, NegOp, ExpOp, diff --git a/c/gomlx/xlabuilder/xlabuilder.cpp b/c/gomlx/xlabuilder/xlabuilder.cpp index 5197138..8d7de3c 100644 --- a/c/gomlx/xlabuilder/xlabuilder.cpp +++ b/c/gomlx/xlabuilder/xlabuilder.cpp @@ -437,6 +437,13 @@ XlaStatus *XlaBuilderAddOp(XlaBuilder *builder, SerializedOp *serialized_op) { op = xla::Fft(*inputs[0], fft_type, list_of_ints); break; } + case WhileOp: { + // Create select and scatter comps. + const xla::XlaComputation &condition_comp = *serialized_op->computation; + const xla::XlaComputation &body_comp = *serialized_op->second_computation; + op = xla::While(condition_comp, body_comp, *inputs[0]); + break; + } // One-argument ops: case AbsOp: diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md new file mode 100644 index 0000000..f99898c --- /dev/null +++ b/docs/CHANGELOG.md @@ -0,0 +1,13 @@ + + +# Next + +* Added `While` op. +* Improved Mandelbrot example. + +# v0.0.1 Initial Release + +* `xlabuilder` with good coverage: all ops used by [GoMLX](github.com/gomlx/gomlx). +* `pjrt` with enough functionality coverage for [GoMLX](github.com/gomlx/gomlx) and to execute some Jax functions. +* Documentation for API, examples, one notebook (Mandelbrot) and installation details for CUDA. +* Prebuilt cpu pjrt plugin and C/C++ XlaBuilder libraries for `linux/x86-64`. \ No newline at end of file diff --git a/xlabuilder/gen_op_types.go b/xlabuilder/gen_op_types.go index 69ab5fc..bf6b714 100644 --- a/xlabuilder/gen_op_types.go +++ b/xlabuilder/gen_op_types.go @@ -37,6 +37,7 @@ const ( BatchNormInferenceOp BatchNormGradOp RngBitGeneratorOp + WhileOp AbsOp NegOp ExpOp diff --git a/xlabuilder/op_types.txt b/xlabuilder/op_types.txt index ea49155..c0035c2 100644 --- a/xlabuilder/op_types.txt +++ b/xlabuilder/op_types.txt @@ -45,6 +45,7 @@ BatchNormTraining:special BatchNormInference:special BatchNormGrad:special RngBitGenerator:special +While:special // One-argument ops: Abs:one diff --git a/xlabuilder/optype_string.go b/xlabuilder/optype_string.go index b956528..3204cbe 100644 --- a/xlabuilder/optype_string.go +++ b/xlabuilder/optype_string.go @@ -39,57 +39,58 @@ func _() { _ = x[BatchNormInferenceOp-28] _ = x[BatchNormGradOp-29] _ = x[RngBitGeneratorOp-30] - _ = x[AbsOp-31] - _ = x[NegOp-32] - _ = x[ExpOp-33] - _ = x[Expm1Op-34] - _ = x[FloorOp-35] - _ = x[CeilOp-36] - _ = x[RoundOp-37] - _ = x[LogOp-38] - _ = x[Log1pOp-39] - _ = x[LogicalNotOp-40] - _ = x[LogisticOp-41] - _ = x[SignOp-42] - _ = x[ClzOp-43] - _ = x[CosOp-44] - _ = x[SinOp-45] - _ = x[TanhOp-46] - _ = x[SqrtOp-47] - _ = x[RsqrtOp-48] - _ = x[ImagOp-49] - _ = x[RealOp-50] - _ = x[ConjOp-51] - _ = x[AddOp-52] - _ = x[MulOp-53] - _ = x[SubOp-54] - _ = x[DivOp-55] - _ = x[RemOp-56] - _ = x[AndOp-57] - _ = x[OrOp-58] - _ = x[XorOp-59] - _ = x[DotOp-60] - _ = x[MinOp-61] - _ = x[MaxOp-62] - _ = x[PowOp-63] - _ = x[ComplexOp-64] - _ = x[EqualOp-65] - _ = x[NotEqualOp-66] - _ = x[GreaterOrEqualOp-67] - _ = x[GreaterThanOp-68] - _ = x[LessOrEqualOp-69] - _ = x[LessThanOp-70] - _ = x[EqualTotalOrderOp-71] - _ = x[NotEqualTotalOrderOp-72] - _ = x[GreaterOrEqualTotalOrderOp-73] - _ = x[GreaterThanTotalOrderOp-74] - _ = x[LessOrEqualTotalOrderOp-75] - _ = x[LessThanTotalOrderOp-76] + _ = x[WhileOp-31] + _ = x[AbsOp-32] + _ = x[NegOp-33] + _ = x[ExpOp-34] + _ = x[Expm1Op-35] + _ = x[FloorOp-36] + _ = x[CeilOp-37] + _ = x[RoundOp-38] + _ = x[LogOp-39] + _ = x[Log1pOp-40] + _ = x[LogicalNotOp-41] + _ = x[LogisticOp-42] + _ = x[SignOp-43] + _ = x[ClzOp-44] + _ = x[CosOp-45] + _ = x[SinOp-46] + _ = x[TanhOp-47] + _ = x[SqrtOp-48] + _ = x[RsqrtOp-49] + _ = x[ImagOp-50] + _ = x[RealOp-51] + _ = x[ConjOp-52] + _ = x[AddOp-53] + _ = x[MulOp-54] + _ = x[SubOp-55] + _ = x[DivOp-56] + _ = x[RemOp-57] + _ = x[AndOp-58] + _ = x[OrOp-59] + _ = x[XorOp-60] + _ = x[DotOp-61] + _ = x[MinOp-62] + _ = x[MaxOp-63] + _ = x[PowOp-64] + _ = x[ComplexOp-65] + _ = x[EqualOp-66] + _ = x[NotEqualOp-67] + _ = x[GreaterOrEqualOp-68] + _ = x[GreaterThanOp-69] + _ = x[LessOrEqualOp-70] + _ = x[LessThanOp-71] + _ = x[EqualTotalOrderOp-72] + _ = x[NotEqualTotalOrderOp-73] + _ = x[GreaterOrEqualTotalOrderOp-74] + _ = x[GreaterThanTotalOrderOp-75] + _ = x[LessOrEqualTotalOrderOp-76] + _ = x[LessThanTotalOrderOp-77] } -const _OpType_name = "InvalidOpParameterOpIotaOpConstantOpIdentityOpConvertDTypeOpWhereOpTupleOpGetTupleElementOpReshapeOpBroadcastOpBroadcastInDimOpTransposeOpCallOpReduceOpReduceWindowOpConcatenateOpSliceOpArgMinMaxOpPadOpGatherOpScatterOpSelectAndScatterOpConvGeneralDilatedOpReverseOpDotGeneralOpFftOpBatchNormTrainingOpBatchNormInferenceOpBatchNormGradOpRngBitGeneratorOpAbsOpNegOpExpOpExpm1OpFloorOpCeilOpRoundOpLogOpLog1pOpLogicalNotOpLogisticOpSignOpClzOpCosOpSinOpTanhOpSqrtOpRsqrtOpImagOpRealOpConjOpAddOpMulOpSubOpDivOpRemOpAndOpOrOpXorOpDotOpMinOpMaxOpPowOpComplexOpEqualOpNotEqualOpGreaterOrEqualOpGreaterThanOpLessOrEqualOpLessThanOpEqualTotalOrderOpNotEqualTotalOrderOpGreaterOrEqualTotalOrderOpGreaterThanTotalOrderOpLessOrEqualTotalOrderOpLessThanTotalOrderOp" +const _OpType_name = "InvalidOpParameterOpIotaOpConstantOpIdentityOpConvertDTypeOpWhereOpTupleOpGetTupleElementOpReshapeOpBroadcastOpBroadcastInDimOpTransposeOpCallOpReduceOpReduceWindowOpConcatenateOpSliceOpArgMinMaxOpPadOpGatherOpScatterOpSelectAndScatterOpConvGeneralDilatedOpReverseOpDotGeneralOpFftOpBatchNormTrainingOpBatchNormInferenceOpBatchNormGradOpRngBitGeneratorOpWhileOpAbsOpNegOpExpOpExpm1OpFloorOpCeilOpRoundOpLogOpLog1pOpLogicalNotOpLogisticOpSignOpClzOpCosOpSinOpTanhOpSqrtOpRsqrtOpImagOpRealOpConjOpAddOpMulOpSubOpDivOpRemOpAndOpOrOpXorOpDotOpMinOpMaxOpPowOpComplexOpEqualOpNotEqualOpGreaterOrEqualOpGreaterThanOpLessOrEqualOpLessThanOpEqualTotalOrderOpNotEqualTotalOrderOpGreaterOrEqualTotalOrderOpGreaterThanTotalOrderOpLessOrEqualTotalOrderOpLessThanTotalOrderOp" -var _OpType_index = [...]uint16{0, 9, 20, 26, 36, 46, 60, 67, 74, 91, 100, 111, 127, 138, 144, 152, 166, 179, 186, 197, 202, 210, 219, 237, 257, 266, 278, 283, 302, 322, 337, 354, 359, 364, 369, 376, 383, 389, 396, 401, 408, 420, 430, 436, 441, 446, 451, 457, 463, 470, 476, 482, 488, 493, 498, 503, 508, 513, 518, 522, 527, 532, 537, 542, 547, 556, 563, 573, 589, 602, 615, 625, 642, 662, 688, 711, 734, 754} +var _OpType_index = [...]uint16{0, 9, 20, 26, 36, 46, 60, 67, 74, 91, 100, 111, 127, 138, 144, 152, 166, 179, 186, 197, 202, 210, 219, 237, 257, 266, 278, 283, 302, 322, 337, 354, 361, 366, 371, 376, 383, 390, 396, 403, 408, 415, 427, 437, 443, 448, 453, 458, 464, 470, 477, 483, 489, 495, 500, 505, 510, 515, 520, 525, 529, 534, 539, 544, 549, 554, 563, 570, 580, 596, 609, 622, 632, 649, 669, 695, 718, 741, 761} func (i OpType) String() string { if i < 0 || i >= OpType(len(_OpType_index)-1) { diff --git a/xlabuilder/shape.go b/xlabuilder/shape.go index 4363743..3ed4bd1 100644 --- a/xlabuilder/shape.go +++ b/xlabuilder/shape.go @@ -61,6 +61,21 @@ func (s Shape) Size() int { return size } +// Clone makes a deep copy (including dimensions and tuples) of the given shape. +func (s Shape) Clone() (newS Shape) { + newS.DType = s.DType + if len(s.Dimensions) > 0 { + newS.Dimensions = slices.Clone(s.Dimensions) + } + if len(s.TupleShapes) > 0 { + newS.TupleShapes = make([]Shape, len(s.TupleShapes)) + for ii, subS := range s.TupleShapes { + newS.TupleShapes[ii] = subS.Clone() + } + } + return newS +} + // TupleSize is an alias to len(Shape.TupleShapes). func (s Shape) TupleSize() int { return len(s.TupleShapes) diff --git a/xlabuilder/special_ops.go b/xlabuilder/special_ops.go index 4318720..bc27dd0 100644 --- a/xlabuilder/special_ops.go +++ b/xlabuilder/special_ops.go @@ -1269,3 +1269,34 @@ func DecodeRngBitGenerator(op *Op) (state *Op, shape Shape) { shape = op.ShapeArg return } + +// While executes a loop in the computation. +// +// It takes as input: +// +// - initialState: usually a tuple, that includes all variables used by condition and body. +// - condition: a sub-computation (see XlaBuilder.CreateSubBuilder) takes the current state as input and outputs +// a bool (dtypes.PRED) whether the loop should keep iterating. +// - body: a sub-computation (see XlaBuilder.CreateSubBuilder) takes the current state as input and outputs +// an updated state. +// +// See details in https://openxla.org/xla/operation_semantics#while +func While(initialState *Op, condition, body *XlaComputation) (*Op, error) { + builder := initialState.builder + op := newOp(WhileOp, initialState) + op.ComputationArg = condition + op.SecondComputationArg = body + err := builder.addOp(op) + if err != nil { + return nil, err + } + return op, nil +} + +// DecodeWhile retrieves the arguments for the While op. +func DecodeWhile(op *Op) (initialState *Op, condition, body *XlaComputation) { + initialState = op.OpInputs[0] + condition = op.ComputationArg + body = op.SecondComputationArg + return +} diff --git a/xlabuilder/special_ops_test.go b/xlabuilder/special_ops_test.go index d572b01..db31c1f 100644 --- a/xlabuilder/special_ops_test.go +++ b/xlabuilder/special_ops_test.go @@ -563,3 +563,57 @@ func TestReverse(t *testing.T) { builder.Destroy() } } + +func TestWhile(t *testing.T) { + client := getPJRTClient(t) + + builder := New(t.Name()) + dtype := dtypes.Int64 + x := capture(Parameter(builder, "x", 0, MakeShape(dtype))).Test(t) + value := capture(ScalarOne(builder, dtype)).Test(t) + + // Calcualte factorial using a While loop. + // While loop: + // - initialState: (x, 1) + initialState := capture(Tuple(x, value)).Test(t) + var cond, body *XlaComputation + // - condition: x > 0 + { + condBuilder := builder.CreateSubBuilder(t.Name() + "_condition") + tuple := capture(Parameter(condBuilder, "tuple", 0, initialState.Shape.Clone())).Test(t) + loopX := capture(GetTupleElement(tuple, 0)).Test(t) + zero := capture(ScalarZero(condBuilder, dtype)).Test(t) + output := capture(GreaterThan(loopX, zero)).Test(t) + cond = capture(condBuilder.Build(output)).Test(t) + } + // - body: value = value * x; x = x-1; + { + bodyBuilder := builder.CreateSubBuilder(t.Name() + "_body") + tuple := capture(Parameter(bodyBuilder, "tuple", 0, initialState.Shape.Clone())).Test(t) + loopX := capture(GetTupleElement(tuple, 0)).Test(t) + loopValue := capture(GetTupleElement(tuple, 1)).Test(t) + loopValue = capture(Mul(loopValue, loopX)).Test(t) + one := capture(ScalarOne(bodyBuilder, dtype)).Test(t) + loopX = capture(Sub(loopX, one)).Test(t) + output := capture(Tuple(loopX, loopValue)).Test(t) + body = capture(bodyBuilder.Build(output)).Test(t) + } + state := capture(While(initialState, cond, body)).Test(t) + + gotInitialState, gotCond, gotBody := DecodeWhile(state) + require.Same(t, initialState, gotInitialState) + require.Same(t, cond, gotCond) + require.Same(t, body, gotBody) + + output := capture(GetTupleElement(state, 1)).Test(t) + exec := compile(t, client, capture(builder.Build(output)).Test(t)) + + // 5! = 120 + got := int(execWithScalars(t, client, exec, int64(5))) + require.Equal(t, 120, got) + + // 7! = 5040 + got = int(execWithScalars(t, client, exec, int64(7))) + require.Equal(t, 5040, got) + builder.Destroy() +}