Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #21 #22

Merged
merged 3 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Next

* Issue #21: use syscall.Dup3 instead of syscall.Dup2 for Arm64 compatibility.

# v0.5.0 - 2024/12/19 - Adding direct access to PJRT buffers for CPU.

* Added `install_linux_amd64_amazonlinux.sh` and pre-built libraries for amazonlinux (built using old glibc support).
Expand Down
2 changes: 1 addition & 1 deletion pjrt/dynamiclib.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func SuppressAbseilLoggingHack(fn func()) {
} else {
defer func() {
// Revert suppression: revert back newFd to 2
err := syscall.Dup2(newFd, 2)
err := syscall.Dup3(newFd, 2, 0)
if err != nil {
klog.Errorf("Failed sycall.Dup2 while reverting suppression of logging: %v", err)
}
Expand Down
77 changes: 77 additions & 0 deletions xlabuilder/reduce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,38 @@ import (
"github.com/gomlx/gopjrt/dtypes"
. "github.com/gomlx/gopjrt/xlabuilder"
"github.com/stretchr/testify/require"
"math"
"testing"
)

// TestMax tests the Max function, as part of the ReduceMax test.
// See https://github.com/openxla/xla/issues/21461
func TestMax(t *testing.T) {
client := getPJRTClient(t)
{
builder := New(fmt.Sprintf("%s-Max(NaN, 1) as Constant", t.Name()))
input0 := capture(Constant(builder, NewScalarLiteral(math.NaN()))).Test(t)
input1 := capture(Constant(builder, NewScalarLiteral(1.0))).Test(t)
output := capture(Max(input0, input1)).Test(t)
exec := compile(t, client, capture(builder.Build(output)).Test(t))
got := execScalarOutput[float64](t, client, exec)
require.True(t, math.IsNaN(got))
builder.Destroy()
}
{
builder := New(fmt.Sprintf("%s-Max(NaN, 1) as Parameter", t.Name()))
input0 := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float64))).Test(t)
input1 := capture(Parameter(builder, "y", 1, MakeShape(dtypes.Float64))).Test(t)
input0 = capture(Sqrt(input0)).Test(t)
input1 = capture(Sqrt(input1)).Test(t)
output := capture(Max(input0, input1)).Test(t)
exec := compile(t, client, capture(builder.Build(output)).Test(t))
got := execWithScalars(t, client, exec, -1.0, 1.0)
require.True(t, math.IsNaN(got))
builder.Destroy()
}
}

func TestReduce(t *testing.T) {
client := getPJRTClient(t)

Expand All @@ -33,6 +62,35 @@ func TestReduce(t *testing.T) {
builder.Destroy()
}

{
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as constant", t.Name()))
literal := capture(NewArrayLiteral([]float32{float32(math.NaN()), 1}, 2)).Test(t)
input := capture(Constant(builder, literal)).Test(t)
output := capture(ReduceMax(input, 0)).Test(t)
comp := capture(builder.Build(output)).Test(t)
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
exec := compile(t, client, comp)
got := execWithScalars[float32](t, client, exec)
require.True(t, math.IsNaN(float64(got)))
builder.Destroy()
}

{
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name()))
input := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t)
output := capture(ReduceMax(input, 0)).Test(t)
comp := capture(builder.Build(output)).Test(t)
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
exec := compile(t, client, comp)
got, dims := execWithSlices(t, client, exec, []float32{float32(math.NaN()), 1})
require.Empty(t, dims)
fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0])
// TODO: re-enable this test when bug is fixed.
// See https://github.com/openxla/xla/issues/21461
// require.True(t, math.IsNaN(float64(got[0])))
builder.Destroy()
}

// Test with ReduceSum and ReduceProduct
{
builder := New(fmt.Sprintf("%s-ReduceProduct-ReduceSum", t.Name()))
Expand Down Expand Up @@ -99,6 +157,25 @@ func TestReduce(t *testing.T) {
}
}

func TestReduceMaxBuggy(t *testing.T) {
client := getPJRTClient(t)
{
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name()))
input := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t)
output := capture(ReduceMax(input, 0)).Test(t)
comp := capture(builder.Build(output)).Test(t)
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
exec := compile(t, client, comp)
got, dims := execWithSlices(t, client, exec, []float32{float32(math.NaN()), 1})
require.Empty(t, dims)
fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0])
// TODO: re-enable this test when bug is fixed.
// See https://github.com/openxla/xla/issues/21461
//require.True(t, math.IsNaN(float64(got[0])))
builder.Destroy()
}
}

func TestReduceWindow(t *testing.T) {
client := getPJRTClient(t)
builder := New(t.Name())
Expand Down
27 changes: 19 additions & 8 deletions xlabuilder/xlabuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/gomlx/gopjrt/dtypes"
"github.com/gomlx/gopjrt/pjrt"
. "github.com/gomlx/gopjrt/xlabuilder"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/klog/v2"
"os"
Expand Down Expand Up @@ -73,20 +74,30 @@ func compile(t *testing.T, client *pjrt.Client, comp *XlaComputation) (exec *pjr
// execWithScalars executes the program on the input value given, and return the output.
// Both input and output expected to be a scalar.
// Any errors fail the test.
func execWithScalars[T dtypes.Supported](t *testing.T, client *pjrt.Client, exec *pjrt.LoadedExecutable, input T) T {
inputBuffer, err := pjrt.ScalarToBuffer(client, input)
require.NoErrorf(t, err, "Failed to create on-device buffer for input %v", input)
defer func() { require.NoError(t, inputBuffer.Destroy()) }()
func execWithScalars[T dtypes.Supported](t *testing.T, client *pjrt.Client, exec *pjrt.LoadedExecutable, inputs ...T) T {
inputBuffers := make([]*pjrt.Buffer, len(inputs))
defer func() {
for _, buf := range inputBuffers {
if buf != nil {
assert.NoError(t, buf.Destroy())
}
}
}()
var err error
for ii, input := range inputs {
inputBuffers[ii], err = pjrt.ScalarToBuffer(client, input)
require.NoErrorf(t, err, "Failed to create on-device buffer for input %v", input)
}

outputBuffers, err := exec.Execute(inputBuffer).Done()
require.NoErrorf(t, err, "Failed to execute on input %v", input)
outputBuffers, err := exec.Execute(inputBuffers...).Done()
require.NoErrorf(t, err, "Failed to execute on inputs %v", inputs)
require.Len(t, outputBuffers, 1, "Expected only one output")
defer func() { require.NoError(t, outputBuffers[0].Destroy()) }()

// Transfer output on-device buffer to a "host" value (in Go).
output, err := pjrt.BufferToScalar[T](outputBuffers[0])
fmt.Printf(" > f(%v)=%v\n", input, output)
require.NoErrorf(t, err, "Failed to transfer results of %q execution on input %d", exec.Name, input)
fmt.Printf(" > f(%v)=%v\n", inputs, output)
require.NoErrorf(t, err, "Failed to transfer results of %q execution on inputs %v", exec.Name, inputs)
return output
}

Expand Down