Skip to content

Commit

Permalink
Added BFloat16 alias.
Browse files Browse the repository at this point in the history
  • Loading branch information
janpfeifer committed Jul 6, 2024
1 parent 4ffb520 commit 9502c13
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
1 change: 1 addition & 0 deletions cmd/dtypes_codegen/enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var aliases = map[string]string{
"U32": "Uint32",
"U64": "Uint64",
"F16": "Float16",
"BF16": "BFloat16",
"F32": "Float32",
"F64": "Float64",
"C64": "Complex64",
Expand Down
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Next

* Moved some `dtypes` support functionality from GoMLX to Gopjrt.
* Added BFloat16 alias.

# v0.1.2 SuppressAbseilLoggingHack

Expand Down
6 changes: 3 additions & 3 deletions dtypes/dtype_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions dtypes/gen_dtype_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ const (
// Float64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F64).
Float64 DType = 12

// BF16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_BF16).
// BFloat16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_BF16).
// Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
// floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
// and 7 bits for the mantissa.
BF16 DType = 13
BFloat16 DType = 13

// Complex64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_C64).
// Complex values of fixed width.
Expand Down Expand Up @@ -152,6 +152,9 @@ const (
// F64 (or PJRT_Buffer_Type_F64) is the C enum name for Float64.
F64 = Float64

// BF16 (or PJRT_Buffer_Type_BF16) is the C enum name for BFloat16.
BF16 = BFloat16

// C64 (or PJRT_Buffer_Type_C64) is the C enum name for Complex64.
C64 = Complex64

Expand Down Expand Up @@ -192,7 +195,7 @@ func (dtype DType) PrimitiveType() protos.PrimitiveType {
return protos.PrimitiveType_F32
case Float64:
return protos.PrimitiveType_F64
case BF16:
case BFloat16:
return protos.PrimitiveType_BF16
case Complex64:
return protos.PrimitiveType_C64
Expand Down Expand Up @@ -257,7 +260,7 @@ func FromPrimitiveType(primitiveType protos.PrimitiveType) DType {
case protos.PrimitiveType_F64:
return Float64
case protos.PrimitiveType_BF16:
return BF16
return BFloat16
case protos.PrimitiveType_C64:
return Complex64
case protos.PrimitiveType_C128:
Expand Down

0 comments on commit 9502c13

Please sign in to comment.