Skip to content

Commit

Permalink
Added dtypes functionality from gomlx.
Browse files Browse the repository at this point in the history
  • Loading branch information
janpfeifer committed Jul 6, 2024
1 parent 40c228a commit 4ffb520
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Next

* Moved some `dtypes` support functionality from GoMLX to Gopjrt.

# v0.1.2 SuppressAbseilLoggingHack

* Improved SuppressAbseilLoggingHack to supress only during the execution of a function.
Expand Down
85 changes: 78 additions & 7 deletions dtypes/dtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@ import (
// Generate automatic C-to-Go boilerplate code for pjrt_c_api.h.
//go:generate go run ../cmd/dtypes_codegen

// Supported lists the Go types that `gopjrt` knows how to convert -- there are more types that can be manually
// converted.
// Used as traits for generics.
type Supported interface {
bool | float16.Float16 | float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | complex64 | complex128
}

// FromGoType returns the DType enum for the given type that this package knows about.
func FromGoType[T Supported]() DType {
var t T
Expand Down Expand Up @@ -298,3 +291,81 @@ func (dtype DType) SmallestNonZeroValueForDType() any {
panic(errors.Errorf("SmallestNonZeroValueForDType not defined for dtype %s", dtype))
}
}

// IsFloat returns whether dtype is a supported float -- float types not yet supported will return false.
// It returns false for complex numbers.
func (dtype DType) IsFloat() bool {
return dtype == Float32 || dtype == Float64 || dtype == Float16 || dtype == BFloat16
}

// IsFloat16 returns whether dtype is a supported float with 16 bits: [Float16] or [BFloat16].
func (dtype DType) IsFloat16() bool {
return dtype == Float16 || dtype == BFloat16
}

// IsComplex returns whether dtype is a supported complex number type.
func (dtype DType) IsComplex() bool {
return dtype == Complex64 || dtype == Complex128
}

// RealDType returns the real component of complex dtypes.
// For float dtypes, it returns itself.
//
// It returns InvalidDType for other non-(complex or float) dtypes.
func (dtype DType) RealDType() DType {
if dtype.IsFloat() {
return dtype
}
switch dtype {
case Complex64:
return Float32
case Complex128:
return Float64
default:
// RealDType is not defined for other dtypes.
return InvalidDType
}
}

// IsInt returns whether dtype is a supported integer type -- float types not yet supported will return false.
func (dtype DType) IsInt() bool {
return dtype == Int64 || dtype == Int32 || dtype == Int16 || dtype == Int8 ||
dtype == Uint8 || dtype == Uint16 || dtype == Uint32 || dtype == Uint64
}

// IsSupported returns whether dtype is supported by `gopjrt`.
func (dtype DType) IsSupported() bool {
return dtype == Bool || dtype == Float16 || dtype == Float32 || dtype == Float64 || dtype == Int64 || dtype == Int32 || dtype == Int16 || dtype == Int8 || dtype == Uint32 || dtype == Uint16 || dtype == Uint8 || dtype == Complex64 || dtype == Complex128
}

// Supported lists the Go types that `gopjrt` knows how to convert -- there are more types that can be manually
// converted.
// Used as traits for generics.
//
// Notice Go's `int` type is not portable, since it may translate to dtypes Int32 or Int64 depending
// on the platform.
type Supported interface {
bool | float32 | float64 | float16.Float16 | int | int32 | int64 | uint8 | uint32 | uint64 | complex64 | complex128
}

// Number represents the Go numeric types that are supported by graph package.
// Used as traits for generics.
//
// Notice that "int" becomes int64 in the implementation.
// Since it needs a 1:1 mapping, it gets converted back to int64.
// It includes complex numbers.
type Number interface {
float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64 | complex64 | complex128
}

// NumberNotComplex represents the Go numeric types that are supported by graph package except the complex numbers.
// Used as a Generics constraint.
// See Number for details.
type NumberNotComplex interface {
float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64
}

// GoFloat represent a continuous Go numeric type, supported by GoMLX.
type GoFloat interface {
float32 | float64
}

0 comments on commit 4ffb520

Please sign in to comment.