Skip to content

Commit

Permalink
Added xlabuilder.Shape.Memory and xlabuilder.NewArrayLiteralFromAny.
Browse files Browse the repository at this point in the history
  • Loading branch information
janpfeifer committed Jul 11, 2024
1 parent b1e4d05 commit f62ba33
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Store client link with Buffer. Added `Buffer.Client` method.
* Added `Buffer.Device` and `Client.NumForDevice`.
* Properly setting client options for `pjrt.NewClient`. Added test for reading/writing `C.PJRT_NamedValues`.
* Added `xlabuilder.Shape.Memory` and `xlabuilder.NewArrayLiteralFromAny`.

# v0.1.2 SuppressAbseilLoggingHack

Expand Down
35 changes: 35 additions & 0 deletions xlabuilder/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,41 @@ func NewScalarLiteralFromAny(value any) *Literal {
return l
}

// NewArrayLiteralFromAny creates a scalar Literal with the given dynamically typed flat values and its underlying dimensions.
// If dimensions is omitted, it's assumed to be a 1D tensor with dimension matching the length of flat.
// It uses reflection to inspect the type.
func NewArrayLiteralFromAny(flatAny any, dimensions ...int) *Literal {
flatV := reflect.ValueOf(flatAny)
if flatV.Kind() != reflect.Slice {
exceptions.Panicf("NewArrayLiteralFromAny expects a slice, got %T instead", flatAny)
}
dtype := dtypes.FromGoType(flatV.Type().Elem())
if dtype == dtypes.InvalidDType {
exceptions.Panicf("NewArrayLiteralFromAny expects a slice of valid DTypes, got %T instead", flatAny)
}
if len(dimensions) == 0 {
dimensions = []int{flatV.Len()}
}
shape := MakeShape(dtype, dimensions...)
if shape.Size() != flatV.Len() {
exceptions.Panicf("NewArrayLiteralFromAny got a slice of length %d, but the shape %s given has %d elements",
flatV.Len(), shape, shape.Size())
}

// Copy data as bytes -- to avoid using complex reflected slices.
flatPtr := flatV.Index(0).Addr().UnsafePointer()
var pinner runtime.Pinner
pinner.Pin(flatPtr)
defer pinner.Unpin()
flatData := unsafe.Slice((*byte)(flatPtr), shape.Memory())

l := NewLiteralFromShape(shape)
lData := unsafe.Slice((*byte)(unsafe.Pointer(l.cLiteral.data)), shape.Memory())

copy(lData, flatData)
return l
}

// newLiteral creates the literal and registers the finalizer.
func newLiteral(cLiteral *C.Literal, shape Shape) *Literal {
l := &Literal{cLiteral: cLiteral, shape: shape}
Expand Down
8 changes: 8 additions & 0 deletions xlabuilder/literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func TestLiterals(t *testing.T) {
require.NotPanics(t, func() { _ = NewScalarLiteral[complex128](complex(1.0, 0.0)) })
require.NotPanics(t, func() { NewScalarLiteral[int8](0).Destroy() })
require.NotPanics(t, func() { NewArrayLiteral([]float32{1, 2, 3, 4, 5, 6}, 2, 3).Destroy() })
require.NotPanics(t, func() { NewArrayLiteralFromAny([]float64{1, 2, 3, 4, 5, 6}, 2, 3).Destroy() })

// Check that various literals get correcly interpreted in PRJT.
client := getPJRTClient(t)
Expand All @@ -34,4 +35,11 @@ func TestLiterals(t *testing.T) {
output = capture(Constant(builder, NewScalarLiteralFromAny(float16.Fromfloat32(15e-3)))).Test(t)
exec = compile(t, client, capture(builder.Build(output)).Test(t))
require.Equal(t, float16.Fromfloat32(15e-3), execScalarOutput[float16.Float16](t, client, exec))

builder = New(t.Name())
output = capture(Constant(builder, NewArrayLiteralFromAny([]float64{1, 3, 5, 7, 11, 13}, 3, 2))).Test(t)
exec = compile(t, client, capture(builder.Build(output)).Test(t))
gotFlat, gotDims := execArrayOutput[float64](t, client, exec)
require.Equal(t, []int{3, 2}, gotDims)
require.Equal(t, []float64{1, 3, 5, 7, 11, 13}, gotFlat)
}
6 changes: 6 additions & 0 deletions xlabuilder/shape.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ func (s Shape) Size() int {
return size
}

// Memory returns the memory used to store an array of the given shape, the same as the size in bytes.
// Careful, so far all types in Go and on device seem to use the same sizes, but future type this is not guaranteed.
func (s Shape) Memory() uintptr {
return s.DType.Memory() * uintptr(s.Size())
}

// Clone makes a deep copy (including dimensions and tuples) of the given shape.
func (s Shape) Clone() (newS Shape) {
newS.DType = s.DType
Expand Down

0 comments on commit f62ba33

Please sign in to comment.