diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 79598df..c7f5a43 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,3 +1,7 @@ +# Next + +* Issue #21: use syscall.Dup3 instead of syscall.Dup2 for Arm64 compatibility. + # v0.5.0 - 2024/12/19 - Adding direct access to PJRT buffers for CPU. * Added `install_linux_amd64_amazonlinux.sh` and pre-built libraries for amazonlinux (built using old glibc support). diff --git a/pjrt/dynamiclib.go b/pjrt/dynamiclib.go index e196858..a5c12fe 100644 --- a/pjrt/dynamiclib.go +++ b/pjrt/dynamiclib.go @@ -266,7 +266,7 @@ func SuppressAbseilLoggingHack(fn func()) { } else { defer func() { // Revert suppression: revert back newFd to 2 - err := syscall.Dup2(newFd, 2) + err := syscall.Dup3(newFd, 2, 0) if err != nil { klog.Errorf("Failed sycall.Dup2 while reverting suppression of logging: %v", err) } diff --git a/xlabuilder/reduce_test.go b/xlabuilder/reduce_test.go index 335c2cc..958a067 100644 --- a/xlabuilder/reduce_test.go +++ b/xlabuilder/reduce_test.go @@ -5,9 +5,38 @@ import ( "github.com/gomlx/gopjrt/dtypes" . "github.com/gomlx/gopjrt/xlabuilder" "github.com/stretchr/testify/require" + "math" "testing" ) +// TestMax tests the Max function, as part of the ReduceMax test. +// See https://github.com/openxla/xla/issues/21461 +func TestMax(t *testing.T) { + client := getPJRTClient(t) + { + builder := New(fmt.Sprintf("%s-Max(NaN, 1) as Constant", t.Name())) + input0 := capture(Constant(builder, NewScalarLiteral(math.NaN()))).Test(t) + input1 := capture(Constant(builder, NewScalarLiteral(1.0))).Test(t) + output := capture(Max(input0, input1)).Test(t) + exec := compile(t, client, capture(builder.Build(output)).Test(t)) + got := execScalarOutput[float64](t, client, exec) + require.True(t, math.IsNaN(got)) + builder.Destroy() + } + { + builder := New(fmt.Sprintf("%s-Max(NaN, 1) as Parameter", t.Name())) + input0 := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float64))).Test(t) + input1 := capture(Parameter(builder, "y", 1, MakeShape(dtypes.Float64))).Test(t) + input0 = capture(Sqrt(input0)).Test(t) + input1 = capture(Sqrt(input1)).Test(t) + output := capture(Max(input0, input1)).Test(t) + exec := compile(t, client, capture(builder.Build(output)).Test(t)) + got := execWithScalars(t, client, exec, -1.0, 1.0) + require.True(t, math.IsNaN(got)) + builder.Destroy() + } +} + func TestReduce(t *testing.T) { client := getPJRTClient(t) @@ -33,6 +62,35 @@ func TestReduce(t *testing.T) { builder.Destroy() } + { + builder := New(fmt.Sprintf("%s-ReduceMax with NaN as constant", t.Name())) + literal := capture(NewArrayLiteral([]float32{float32(math.NaN()), 1}, 2)).Test(t) + input := capture(Constant(builder, literal)).Test(t) + output := capture(ReduceMax(input, 0)).Test(t) + comp := capture(builder.Build(output)).Test(t) + fmt.Printf("HLO:\n%s\n", comp.TextHLO()) + exec := compile(t, client, comp) + got := execWithScalars[float32](t, client, exec) + require.True(t, math.IsNaN(float64(got))) + builder.Destroy() + } + + { + builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name())) + input := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t) + output := capture(ReduceMax(input, 0)).Test(t) + comp := capture(builder.Build(output)).Test(t) + fmt.Printf("HLO:\n%s\n", comp.TextHLO()) + exec := compile(t, client, comp) + got, dims := execWithSlices(t, client, exec, []float32{float32(math.NaN()), 1}) + require.Empty(t, dims) + fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0]) + // TODO: re-enable this test when bug is fixed. + // See https://github.com/openxla/xla/issues/21461 + // require.True(t, math.IsNaN(float64(got[0]))) + builder.Destroy() + } + // Test with ReduceSum and ReduceProduct { builder := New(fmt.Sprintf("%s-ReduceProduct-ReduceSum", t.Name())) @@ -99,6 +157,25 @@ func TestReduce(t *testing.T) { } } +func TestReduceMaxBuggy(t *testing.T) { + client := getPJRTClient(t) + { + builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name())) + input := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t) + output := capture(ReduceMax(input, 0)).Test(t) + comp := capture(builder.Build(output)).Test(t) + fmt.Printf("HLO:\n%s\n", comp.TextHLO()) + exec := compile(t, client, comp) + got, dims := execWithSlices(t, client, exec, []float32{float32(math.NaN()), 1}) + require.Empty(t, dims) + fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0]) + // TODO: re-enable this test when bug is fixed. + // See https://github.com/openxla/xla/issues/21461 + //require.True(t, math.IsNaN(float64(got[0]))) + builder.Destroy() + } +} + func TestReduceWindow(t *testing.T) { client := getPJRTClient(t) builder := New(t.Name()) diff --git a/xlabuilder/xlabuilder_test.go b/xlabuilder/xlabuilder_test.go index 31f5cef..c768fa7 100644 --- a/xlabuilder/xlabuilder_test.go +++ b/xlabuilder/xlabuilder_test.go @@ -6,6 +6,7 @@ import ( "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/gopjrt/pjrt" . "github.com/gomlx/gopjrt/xlabuilder" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/klog/v2" "os" @@ -73,20 +74,30 @@ func compile(t *testing.T, client *pjrt.Client, comp *XlaComputation) (exec *pjr // execWithScalars executes the program on the input value given, and return the output. // Both input and output expected to be a scalar. // Any errors fail the test. -func execWithScalars[T dtypes.Supported](t *testing.T, client *pjrt.Client, exec *pjrt.LoadedExecutable, input T) T { - inputBuffer, err := pjrt.ScalarToBuffer(client, input) - require.NoErrorf(t, err, "Failed to create on-device buffer for input %v", input) - defer func() { require.NoError(t, inputBuffer.Destroy()) }() +func execWithScalars[T dtypes.Supported](t *testing.T, client *pjrt.Client, exec *pjrt.LoadedExecutable, inputs ...T) T { + inputBuffers := make([]*pjrt.Buffer, len(inputs)) + defer func() { + for _, buf := range inputBuffers { + if buf != nil { + assert.NoError(t, buf.Destroy()) + } + } + }() + var err error + for ii, input := range inputs { + inputBuffers[ii], err = pjrt.ScalarToBuffer(client, input) + require.NoErrorf(t, err, "Failed to create on-device buffer for input %v", input) + } - outputBuffers, err := exec.Execute(inputBuffer).Done() - require.NoErrorf(t, err, "Failed to execute on input %v", input) + outputBuffers, err := exec.Execute(inputBuffers...).Done() + require.NoErrorf(t, err, "Failed to execute on inputs %v", inputs) require.Len(t, outputBuffers, 1, "Expected only one output") defer func() { require.NoError(t, outputBuffers[0].Destroy()) }() // Transfer output on-device buffer to a "host" value (in Go). output, err := pjrt.BufferToScalar[T](outputBuffers[0]) - fmt.Printf(" > f(%v)=%v\n", input, output) - require.NoErrorf(t, err, "Failed to transfer results of %q execution on input %d", exec.Name, input) + fmt.Printf(" > f(%v)=%v\n", inputs, output) + require.NoErrorf(t, err, "Failed to transfer results of %q execution on inputs %v", exec.Name, inputs) return output }