From f62ba33942b2ea199a7c4e6c4fe44f991242e743 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 11 Jul 2024 09:36:07 +0200 Subject: [PATCH] Added `xlabuilder.Shape.Memory` and `xlabuilder.NewArrayLiteralFromAny`. --- docs/CHANGELOG.md | 1 + xlabuilder/literal.go | 35 +++++++++++++++++++++++++++++++++++ xlabuilder/literal_test.go | 8 ++++++++ xlabuilder/shape.go | 6 ++++++ 4 files changed, 50 insertions(+) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 1514173..90d0638 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -13,6 +13,7 @@ * Store client link with Buffer. Added `Buffer.Client` method. * Added `Buffer.Device` and `Client.NumForDevice`. * Properly setting client options for `pjrt.NewClient`. Added test for reading/writing `C.PJRT_NamedValues`. +* Added `xlabuilder.Shape.Memory` and `xlabuilder.NewArrayLiteralFromAny`. # v0.1.2 SuppressAbseilLoggingHack diff --git a/xlabuilder/literal.go b/xlabuilder/literal.go index 01797f6..7743d5f 100644 --- a/xlabuilder/literal.go +++ b/xlabuilder/literal.go @@ -95,6 +95,41 @@ func NewScalarLiteralFromAny(value any) *Literal { return l } +// NewArrayLiteralFromAny creates a scalar Literal with the given dynamically typed flat values and its underlying dimensions. +// If dimensions is omitted, it's assumed to be a 1D tensor with dimension matching the length of flat. +// It uses reflection to inspect the type. +func NewArrayLiteralFromAny(flatAny any, dimensions ...int) *Literal { + flatV := reflect.ValueOf(flatAny) + if flatV.Kind() != reflect.Slice { + exceptions.Panicf("NewArrayLiteralFromAny expects a slice, got %T instead", flatAny) + } + dtype := dtypes.FromGoType(flatV.Type().Elem()) + if dtype == dtypes.InvalidDType { + exceptions.Panicf("NewArrayLiteralFromAny expects a slice of valid DTypes, got %T instead", flatAny) + } + if len(dimensions) == 0 { + dimensions = []int{flatV.Len()} + } + shape := MakeShape(dtype, dimensions...) + if shape.Size() != flatV.Len() { + exceptions.Panicf("NewArrayLiteralFromAny got a slice of length %d, but the shape %s given has %d elements", + flatV.Len(), shape, shape.Size()) + } + + // Copy data as bytes -- to avoid using complex reflected slices. + flatPtr := flatV.Index(0).Addr().UnsafePointer() + var pinner runtime.Pinner + pinner.Pin(flatPtr) + defer pinner.Unpin() + flatData := unsafe.Slice((*byte)(flatPtr), shape.Memory()) + + l := NewLiteralFromShape(shape) + lData := unsafe.Slice((*byte)(unsafe.Pointer(l.cLiteral.data)), shape.Memory()) + + copy(lData, flatData) + return l +} + // newLiteral creates the literal and registers the finalizer. func newLiteral(cLiteral *C.Literal, shape Shape) *Literal { l := &Literal{cLiteral: cLiteral, shape: shape} diff --git a/xlabuilder/literal_test.go b/xlabuilder/literal_test.go index c67a61a..2555f29 100644 --- a/xlabuilder/literal_test.go +++ b/xlabuilder/literal_test.go @@ -17,6 +17,7 @@ func TestLiterals(t *testing.T) { require.NotPanics(t, func() { _ = NewScalarLiteral[complex128](complex(1.0, 0.0)) }) require.NotPanics(t, func() { NewScalarLiteral[int8](0).Destroy() }) require.NotPanics(t, func() { NewArrayLiteral([]float32{1, 2, 3, 4, 5, 6}, 2, 3).Destroy() }) + require.NotPanics(t, func() { NewArrayLiteralFromAny([]float64{1, 2, 3, 4, 5, 6}, 2, 3).Destroy() }) // Check that various literals get correcly interpreted in PRJT. client := getPJRTClient(t) @@ -34,4 +35,11 @@ func TestLiterals(t *testing.T) { output = capture(Constant(builder, NewScalarLiteralFromAny(float16.Fromfloat32(15e-3)))).Test(t) exec = compile(t, client, capture(builder.Build(output)).Test(t)) require.Equal(t, float16.Fromfloat32(15e-3), execScalarOutput[float16.Float16](t, client, exec)) + + builder = New(t.Name()) + output = capture(Constant(builder, NewArrayLiteralFromAny([]float64{1, 3, 5, 7, 11, 13}, 3, 2))).Test(t) + exec = compile(t, client, capture(builder.Build(output)).Test(t)) + gotFlat, gotDims := execArrayOutput[float64](t, client, exec) + require.Equal(t, []int{3, 2}, gotDims) + require.Equal(t, []float64{1, 3, 5, 7, 11, 13}, gotFlat) } diff --git a/xlabuilder/shape.go b/xlabuilder/shape.go index 3ed4bd1..611b9b0 100644 --- a/xlabuilder/shape.go +++ b/xlabuilder/shape.go @@ -61,6 +61,12 @@ func (s Shape) Size() int { return size } +// Memory returns the memory used to store an array of the given shape, the same as the size in bytes. +// Careful, so far all types in Go and on device seem to use the same sizes, but future type this is not guaranteed. +func (s Shape) Memory() uintptr { + return s.DType.Memory() * uintptr(s.Size()) +} + // Clone makes a deep copy (including dimensions and tuples) of the given shape. func (s Shape) Clone() (newS Shape) { newS.DType = s.DType