diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 2e52eee..53e5e03 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -2,6 +2,7 @@ * Moved some `dtypes` support functionality from GoMLX to Gopjrt. * Added BFloat16 alias. +* Renamed `FromGoType` to `FromGenericsType` and `FromType` to `FromGoType`, to maintain naming consistency. # v0.1.2 SuppressAbseilLoggingHack diff --git a/dtypes/dtypes.go b/dtypes/dtypes.go index 6149c57..f75e91d 100644 --- a/dtypes/dtypes.go +++ b/dtypes/dtypes.go @@ -12,8 +12,8 @@ import ( // Generate automatic C-to-Go boilerplate code for pjrt_c_api.h. //go:generate go run ../cmd/dtypes_codegen -// FromGoType returns the DType enum for the given type that this package knows about. -func FromGoType[T Supported]() DType { +// FromGenericsType returns the DType enum for the given type that this package knows about. +func FromGenericsType[T Supported]() DType { var t T switch (any(t)).(type) { case float64: @@ -57,9 +57,9 @@ func FromGoType[T Supported]() DType { return InvalidDType } -// FromType returns the DType for the given [reflect.Type]. +// FromGoType returns the DType for the given [reflect.Type]. // It panics for unknown DType values. -func FromType(t reflect.Type) DType { +func FromGoType(t reflect.Type) DType { if t == float16Type { return Float16 } @@ -112,7 +112,7 @@ func FromType(t reflect.Type) DType { // FromAny introspects the underlying type of any and return the corresponding DType. // Non-scalar types, or not supported types returns a InvalidType. func FromAny(value any) DType { - return FromType(reflect.TypeOf(value)) + return FromGoType(reflect.TypeOf(value)) } // Size returns the number of bytes for the given DType. diff --git a/pjrt/buffers.go b/pjrt/buffers.go index a29c491..8cbbbc7 100644 --- a/pjrt/buffers.go +++ b/pjrt/buffers.go @@ -218,7 +218,7 @@ func FlatDataToRawWithDimensions[T dtypes.Supported](flat []T, dimensions ...int exceptions.Panicf("FlatDataToRawWithDimensions given a flat slice of size %d that doesn't match dimensions %v (total size %d)", len(flat), dimensions, expectedSize) } - dtype := dtypes.FromGoType[T]() + dtype := dtypes.FromGenericsType[T]() if len(flat) == 0 { return nil, dtype, dimensions } @@ -230,7 +230,7 @@ func FlatDataToRawWithDimensions[T dtypes.Supported](flat []T, dimensions ...int // ScalarToRaw generates the raw values needed by BufferFromHostConfig.FromRawData to feed a simple scalar value. func ScalarToRaw[T dtypes.Supported](value T) ([]byte, dtypes.DType, []int) { - dtype := dtypes.FromGoType[T]() + dtype := dtypes.FromGenericsType[T]() rawSlice := unsafe.Slice((*byte)(unsafe.Pointer(&value)), int(unsafe.Sizeof(value))) return rawSlice, dtype, nil // empty dimensions for scalar } @@ -309,7 +309,7 @@ func ScalarToBuffer[T dtypes.Supported](client *Client, value T) (b *Buffer, err pinner.Pin(&value) defer pinner.Unpin() - dtype := dtypes.FromGoType[T]() + dtype := dtypes.FromGenericsType[T]() src := unsafe.Slice((*byte)(unsafe.Pointer(&value)), unsafe.Sizeof(value)) return client.BufferFromHost().FromRawData(src, dtype, nil).Done() } @@ -330,7 +330,7 @@ func BufferToArray[T dtypes.Supported](buffer *Buffer) (flatValues []T, dimensio if err != nil { return } - requestedDType := dtypes.FromGoType[T]() + requestedDType := dtypes.FromGenericsType[T]() if dtype != requestedDType { var dummy T err = errors.Errorf("called BufferToArray[%T](...), but underlying buffer has dtype %s", dummy, dtype) diff --git a/xlabuilder/literal.go b/xlabuilder/literal.go index 70633ca..01797f6 100644 --- a/xlabuilder/literal.go +++ b/xlabuilder/literal.go @@ -39,7 +39,7 @@ func NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int) *Literal { if len(dimensions) == 0 { dimensions = []int{len(flat)} } - shape := MakeShape(dtypes.FromGoType[T](), dimensions...) + shape := MakeShape(dtypes.FromGenericsType[T](), dimensions...) if shape.Size() != len(flat) { exceptions.Panicf("NewArrayLiteral got a slice of length %d, but the shape %s given has %d elements", len(flat), shape, shape.Size()) @@ -52,7 +52,7 @@ func NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int) *Literal { // NewScalarLiteral creates a scalar Literal initialized with the given value. func NewScalarLiteral[T dtypes.Supported](value T) *Literal { - shape := MakeShape(dtypes.FromGoType[T]()) + shape := MakeShape(dtypes.FromGenericsType[T]()) l := NewLiteralFromShape(shape) *(*T)(unsafe.Pointer(l.cLiteral.data)) = value return l @@ -88,7 +88,7 @@ func NewScalarLiteralFromFloat64(value float64, dtype dtypes.DType) *Literal { // It uses reflection to inspect the type. func NewScalarLiteralFromAny(value any) *Literal { valueOf := reflect.ValueOf(value) - dtype := dtypes.FromType(valueOf.Type()) + dtype := dtypes.FromGoType(valueOf.Type()) l := NewLiteralFromShape(MakeShape(dtype)) lValueOf := reflect.NewAt(dtype.GoType(), unsafe.Pointer(l.cLiteral.data)).Elem() lValueOf.Set(valueOf)