Skip to content

Commit

Permalink
Renamed FromGoType to FromGenericsType and FromType to `FromGoT…
Browse files Browse the repository at this point in the history
…ype`, to maintain naming consistency.
  • Loading branch information
janpfeifer committed Jul 6, 2024
1 parent 9502c13 commit bbd150e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

* Moved some `dtypes` support functionality from GoMLX to Gopjrt.
* Added BFloat16 alias.
* Renamed `FromGoType` to `FromGenericsType` and `FromType` to `FromGoType`, to maintain naming consistency.

# v0.1.2 SuppressAbseilLoggingHack

Expand Down
10 changes: 5 additions & 5 deletions dtypes/dtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
// Generate automatic C-to-Go boilerplate code for pjrt_c_api.h.
//go:generate go run ../cmd/dtypes_codegen

// FromGoType returns the DType enum for the given type that this package knows about.
func FromGoType[T Supported]() DType {
// FromGenericsType returns the DType enum for the given type that this package knows about.
func FromGenericsType[T Supported]() DType {
var t T
switch (any(t)).(type) {
case float64:
Expand Down Expand Up @@ -57,9 +57,9 @@ func FromGoType[T Supported]() DType {
return InvalidDType
}

// FromType returns the DType for the given [reflect.Type].
// FromGoType returns the DType for the given [reflect.Type].
// It panics for unknown DType values.
func FromType(t reflect.Type) DType {
func FromGoType(t reflect.Type) DType {
if t == float16Type {
return Float16
}
Expand Down Expand Up @@ -112,7 +112,7 @@ func FromType(t reflect.Type) DType {
// FromAny introspects the underlying type of any and return the corresponding DType.
// Non-scalar types, or not supported types returns a InvalidType.
func FromAny(value any) DType {
return FromType(reflect.TypeOf(value))
return FromGoType(reflect.TypeOf(value))
}

// Size returns the number of bytes for the given DType.
Expand Down
8 changes: 4 additions & 4 deletions pjrt/buffers.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func FlatDataToRawWithDimensions[T dtypes.Supported](flat []T, dimensions ...int
exceptions.Panicf("FlatDataToRawWithDimensions given a flat slice of size %d that doesn't match dimensions %v (total size %d)",
len(flat), dimensions, expectedSize)
}
dtype := dtypes.FromGoType[T]()
dtype := dtypes.FromGenericsType[T]()
if len(flat) == 0 {
return nil, dtype, dimensions
}
Expand All @@ -230,7 +230,7 @@ func FlatDataToRawWithDimensions[T dtypes.Supported](flat []T, dimensions ...int

// ScalarToRaw generates the raw values needed by BufferFromHostConfig.FromRawData to feed a simple scalar value.
func ScalarToRaw[T dtypes.Supported](value T) ([]byte, dtypes.DType, []int) {
dtype := dtypes.FromGoType[T]()
dtype := dtypes.FromGenericsType[T]()
rawSlice := unsafe.Slice((*byte)(unsafe.Pointer(&value)), int(unsafe.Sizeof(value)))
return rawSlice, dtype, nil // empty dimensions for scalar
}
Expand Down Expand Up @@ -309,7 +309,7 @@ func ScalarToBuffer[T dtypes.Supported](client *Client, value T) (b *Buffer, err
pinner.Pin(&value)
defer pinner.Unpin()

dtype := dtypes.FromGoType[T]()
dtype := dtypes.FromGenericsType[T]()
src := unsafe.Slice((*byte)(unsafe.Pointer(&value)), unsafe.Sizeof(value))
return client.BufferFromHost().FromRawData(src, dtype, nil).Done()
}
Expand All @@ -330,7 +330,7 @@ func BufferToArray[T dtypes.Supported](buffer *Buffer) (flatValues []T, dimensio
if err != nil {
return
}
requestedDType := dtypes.FromGoType[T]()
requestedDType := dtypes.FromGenericsType[T]()
if dtype != requestedDType {
var dummy T
err = errors.Errorf("called BufferToArray[%T](...), but underlying buffer has dtype %s", dummy, dtype)
Expand Down
6 changes: 3 additions & 3 deletions xlabuilder/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int) *Literal {
if len(dimensions) == 0 {
dimensions = []int{len(flat)}
}
shape := MakeShape(dtypes.FromGoType[T](), dimensions...)
shape := MakeShape(dtypes.FromGenericsType[T](), dimensions...)
if shape.Size() != len(flat) {
exceptions.Panicf("NewArrayLiteral got a slice of length %d, but the shape %s given has %d elements",
len(flat), shape, shape.Size())
Expand All @@ -52,7 +52,7 @@ func NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int) *Literal {

// NewScalarLiteral creates a scalar Literal initialized with the given value.
func NewScalarLiteral[T dtypes.Supported](value T) *Literal {
shape := MakeShape(dtypes.FromGoType[T]())
shape := MakeShape(dtypes.FromGenericsType[T]())
l := NewLiteralFromShape(shape)
*(*T)(unsafe.Pointer(l.cLiteral.data)) = value
return l
Expand Down Expand Up @@ -88,7 +88,7 @@ func NewScalarLiteralFromFloat64(value float64, dtype dtypes.DType) *Literal {
// It uses reflection to inspect the type.
func NewScalarLiteralFromAny(value any) *Literal {
valueOf := reflect.ValueOf(value)
dtype := dtypes.FromType(valueOf.Type())
dtype := dtypes.FromGoType(valueOf.Type())
l := NewLiteralFromShape(MakeShape(dtype))
lValueOf := reflect.NewAt(dtype.GoType(), unsafe.Pointer(l.cLiteral.data)).Elem()
lValueOf.Set(valueOf)
Expand Down

0 comments on commit bbd150e

Please sign in to comment.