Skip to content

Commit

Permalink
Merge pull request #22 from gomlx/reducemax
Browse files Browse the repository at this point in the history
Fix issue #21
  • Loading branch information
janpfeifer authored Jan 20, 2025
2 parents 0962204 + 063546c commit 8a21d61
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 9 deletions.
4 changes: 4 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Next

* Issue #21: use syscall.Dup3 instead of syscall.Dup2 for Arm64 compatibility.

# v0.5.0 - 2024/12/19 - Adding direct access to PJRT buffers for CPU.

* Added `install_linux_amd64_amazonlinux.sh` and pre-built libraries for amazonlinux (built using old glibc support).
Expand Down
2 changes: 1 addition & 1 deletion pjrt/dynamiclib.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func SuppressAbseilLoggingHack(fn func()) {
} else {
defer func() {
// Revert suppression: revert back newFd to 2
err := syscall.Dup2(newFd, 2)
err := syscall.Dup3(newFd, 2, 0)
if err != nil {
klog.Errorf("Failed sycall.Dup2 while reverting suppression of logging: %v", err)
}
Expand Down
77 changes: 77 additions & 0 deletions xlabuilder/reduce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,38 @@ import (
"github.com/gomlx/gopjrt/dtypes"
. "github.com/gomlx/gopjrt/xlabuilder"
"github.com/stretchr/testify/require"
"math"
"testing"
)

// TestMax checks that Max yields NaN when one operand is NaN, both when the
// operands are constants and when they are derived from parameters supplied at
// execution time. It accompanies the ReduceMax tests.
// See https://github.com/openxla/xla/issues/21461
func TestMax(t *testing.T) {
	client := getPJRTClient(t)

	// Case 1: both operands are constants embedded in the computation.
	constBuilder := New(fmt.Sprintf("%s-Max(NaN, 1) as Constant", t.Name()))
	nanOp := capture(Constant(constBuilder, NewScalarLiteral(math.NaN()))).Test(t)
	oneOp := capture(Constant(constBuilder, NewScalarLiteral(1.0))).Test(t)
	constMax := capture(Max(nanOp, oneOp)).Test(t)
	constComp := capture(constBuilder.Build(constMax)).Test(t)
	constExec := compile(t, client, constComp)
	constResult := execScalarOutput[float64](t, client, constExec)
	require.True(t, math.IsNaN(constResult))
	constBuilder.Destroy()

	// Case 2: operands are parameters passed through Sqrt before Max. The
	// program is executed with (-1.0, 1.0) and the result is expected to be NaN.
	paramBuilder := New(fmt.Sprintf("%s-Max(NaN, 1) as Parameter", t.Name()))
	x := capture(Parameter(paramBuilder, "x", 0, MakeShape(dtypes.Float64))).Test(t)
	y := capture(Parameter(paramBuilder, "y", 1, MakeShape(dtypes.Float64))).Test(t)
	sqrtX := capture(Sqrt(x)).Test(t)
	sqrtY := capture(Sqrt(y)).Test(t)
	paramMax := capture(Max(sqrtX, sqrtY)).Test(t)
	paramComp := capture(paramBuilder.Build(paramMax)).Test(t)
	paramExec := compile(t, client, paramComp)
	paramResult := execWithScalars(t, client, paramExec, -1.0, 1.0)
	require.True(t, math.IsNaN(paramResult))
	paramBuilder.Destroy()
}

func TestReduce(t *testing.T) {
client := getPJRTClient(t)

Expand All @@ -33,6 +62,35 @@ func TestReduce(t *testing.T) {
builder.Destroy()
}

{
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as constant", t.Name()))
literal := capture(NewArrayLiteral([]float32{float32(math.NaN()), 1}, 2)).Test(t)
input := capture(Constant(builder, literal)).Test(t)
output := capture(ReduceMax(input, 0)).Test(t)
comp := capture(builder.Build(output)).Test(t)
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
exec := compile(t, client, comp)
got := execWithScalars[float32](t, client, exec)
require.True(t, math.IsNaN(float64(got)))
builder.Destroy()
}

{
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name()))
input := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t)
output := capture(ReduceMax(input, 0)).Test(t)
comp := capture(builder.Build(output)).Test(t)
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
exec := compile(t, client, comp)
got, dims := execWithSlices(t, client, exec, []float32{float32(math.NaN()), 1})
require.Empty(t, dims)
fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0])
// TODO: re-enable this test when bug is fixed.
// See https://github.com/openxla/xla/issues/21461
// require.True(t, math.IsNaN(float64(got[0])))
builder.Destroy()
}

// Test with ReduceSum and ReduceProduct
{
builder := New(fmt.Sprintf("%s-ReduceProduct-ReduceSum", t.Name()))
Expand Down Expand Up @@ -99,6 +157,25 @@ func TestReduce(t *testing.T) {
}
}

// TestReduceMaxBuggy builds a ReduceMax over axis 0 of a 2-element float32
// parameter, executes it with {NaN, 1} and prints the result.
//
// Only the output dimensions are asserted (expected to be empty); the NaN check
// on the value itself is disabled pending the upstream issue linked below.
func TestReduceMaxBuggy(t *testing.T) {
	client := getPJRTClient(t)

	builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name()))
	x := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t)
	reduced := capture(ReduceMax(x, 0)).Test(t)
	comp := capture(builder.Build(reduced)).Test(t)
	fmt.Printf("HLO:\n%s\n", comp.TextHLO())

	exec := compile(t, client, comp)
	values := []float32{float32(math.NaN()), 1}
	got, dims := execWithSlices(t, client, exec, values)
	require.Empty(t, dims)
	fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0])

	// TODO: re-enable this test when bug is fixed.
	// See https://github.com/openxla/xla/issues/21461
	//require.True(t, math.IsNaN(float64(got[0])))
	builder.Destroy()
}

func TestReduceWindow(t *testing.T) {
client := getPJRTClient(t)
builder := New(t.Name())
Expand Down
27 changes: 19 additions & 8 deletions xlabuilder/xlabuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/gomlx/gopjrt/dtypes"
"github.com/gomlx/gopjrt/pjrt"
. "github.com/gomlx/gopjrt/xlabuilder"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/klog/v2"
"os"
Expand Down Expand Up @@ -73,20 +74,30 @@ func compile(t *testing.T, client *pjrt.Client, comp *XlaComputation) (exec *pjr
// execWithScalars executes the program on the given scalar input value(s) and returns the output.
// The inputs and the output are all expected to be scalars.
// Any errors fail the test.
//
// NOTE(review): this span comes from a diff rendered without +/- markers, so the previous
// single-input lines and their variadic replacements appear back to back below — confirm
// against the repository before editing the code itself.
func execWithScalars[T dtypes.Supported](t *testing.T, client *pjrt.Client, exec *pjrt.LoadedExecutable, input T) T {
inputBuffer, err := pjrt.ScalarToBuffer(client, input)
require.NoErrorf(t, err, "Failed to create on-device buffer for input %v", input)
defer func() { require.NoError(t, inputBuffer.Destroy()) }()
func execWithScalars[T dtypes.Supported](t *testing.T, client *pjrt.Client, exec *pjrt.LoadedExecutable, inputs ...T) T {
inputBuffers := make([]*pjrt.Buffer, len(inputs))
// Registered before any buffer is created so that buffers already created are
// destroyed even if creating a later one fails; entries not yet filled are nil and skipped.
defer func() {
for _, buf := range inputBuffers {
if buf != nil {
// assert (rather than require) so that one failed Destroy does not
// prevent the remaining buffers from being destroyed.
assert.NoError(t, buf.Destroy())
}
}
}()
var err error
// Create one on-device buffer per scalar input, in order.
for ii, input := range inputs {
inputBuffers[ii], err = pjrt.ScalarToBuffer(client, input)
require.NoErrorf(t, err, "Failed to create on-device buffer for input %v", input)
}

outputBuffers, err := exec.Execute(inputBuffer).Done()
require.NoErrorf(t, err, "Failed to execute on input %v", input)
outputBuffers, err := exec.Execute(inputBuffers...).Done()
require.NoErrorf(t, err, "Failed to execute on inputs %v", inputs)
require.Len(t, outputBuffers, 1, "Expected only one output")
defer func() { require.NoError(t, outputBuffers[0].Destroy()) }()

// Transfer output on-device buffer to a "host" value (in Go).
output, err := pjrt.BufferToScalar[T](outputBuffers[0])
fmt.Printf(" > f(%v)=%v\n", input, output)
require.NoErrorf(t, err, "Failed to transfer results of %q execution on input %d", exec.Name, input)
fmt.Printf(" > f(%v)=%v\n", inputs, output)
require.NoErrorf(t, err, "Failed to transfer results of %q execution on inputs %v", exec.Name, inputs)
return output
}

Expand Down

0 comments on commit 8a21d61

Please sign in to comment.