From 9502c13719768e23fe11e9fb6cca7fc4361bd61d Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sat, 6 Jul 2024 08:37:59 +0200 Subject: [PATCH] Added BFloat16 alias. --- cmd/dtypes_codegen/enums.go | 1 + docs/CHANGELOG.md | 1 + dtypes/dtype_string.go | 6 +++--- dtypes/gen_dtype_enum.go | 11 +++++++---- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cmd/dtypes_codegen/enums.go b/cmd/dtypes_codegen/enums.go index 692d12b..1e011ba 100644 --- a/cmd/dtypes_codegen/enums.go +++ b/cmd/dtypes_codegen/enums.go @@ -29,6 +29,7 @@ var aliases = map[string]string{ "U32": "Uint32", "U64": "Uint64", "F16": "Float16", + "BF16": "BFloat16", "F32": "Float32", "F64": "Float64", "C64": "Complex64", diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 15aef5c..2e52eee 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,6 +1,7 @@ # Next * Moved some `dtypes` support functionality from GoMLX to Gopjrt. +* Added BFloat16 alias. # v0.1.2 SuppressAbseilLoggingHack diff --git a/dtypes/dtype_string.go b/dtypes/dtype_string.go index 73f1438..8785522 100644 --- a/dtypes/dtype_string.go +++ b/dtypes/dtype_string.go @@ -21,7 +21,7 @@ func _() { _ = x[Float16-10] _ = x[Float32-11] _ = x[Float64-12] - _ = x[BF16-13] + _ = x[BFloat16-13] _ = x[Complex64-14] _ = x[Complex128-15] _ = x[F8E5M2-16] @@ -36,9 +36,9 @@ func _() { _ = x[U2-25] } -const _DType_name = "InvalidDTypeBoolInt8Int16Int32Int64Uint8Uint16Uint32Uint64Float16Float32Float64BF16Complex64Complex128F8E5M2F8E4M3FNF8E4M3B11FNUZF8E5M2FNUZF8E4M3FNUZS4U4TOKENS2U2" +const _DType_name = "InvalidDTypeBoolInt8Int16Int32Int64Uint8Uint16Uint32Uint64Float16Float32Float64BFloat16Complex64Complex128F8E5M2F8E4M3FNF8E4M3B11FNUZF8E5M2FNUZF8E4M3FNUZS4U4TOKENS2U2" -var _DType_index = [...]uint8{0, 12, 16, 20, 25, 30, 35, 40, 46, 52, 58, 65, 72, 79, 83, 92, 102, 108, 116, 129, 139, 149, 151, 153, 158, 160, 162} +var _DType_index = [...]uint8{0, 12, 16, 20, 25, 30, 35, 40, 46, 52, 58, 65, 72, 79, 87, 96, 106, 112, 120, 133, 143, 153, 155, 157, 162, 164, 166} func (i DType) String() string { if i < 0 || i >= DType(len(_DType_index)-1) { diff --git a/dtypes/gen_dtype_enum.go b/dtypes/gen_dtype_enum.go index 517a3d6..a6d894b 100644 --- a/dtypes/gen_dtype_enum.go +++ b/dtypes/gen_dtype_enum.go @@ -61,11 +61,11 @@ const ( // Float64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F64). Float64 DType = 12 - // BF16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_BF16). + // BFloat16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_BF16). // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent // and 7 bits for the mantissa. - BF16 DType = 13 + BFloat16 DType = 13 // Complex64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_C64). // Complex values of fixed width. @@ -152,6 +152,9 @@ const ( // F64 (or PJRT_Buffer_Type_F64) is the C enum name for Float64. F64 = Float64 + // BF16 (or PJRT_Buffer_Type_BF16) is the C enum name for BFloat16. + BF16 = BFloat16 + // C64 (or PJRT_Buffer_Type_C64) is the C enum name for Complex64. C64 = Complex64 @@ -192,7 +195,7 @@ func (dtype DType) PrimitiveType() protos.PrimitiveType { return protos.PrimitiveType_F32 case Float64: return protos.PrimitiveType_F64 - case BF16: + case BFloat16: return protos.PrimitiveType_BF16 case Complex64: return protos.PrimitiveType_C64 @@ -257,7 +260,7 @@ func FromPrimitiveType(primitiveType protos.PrimitiveType) DType { case protos.PrimitiveType_F64: return Float64 case protos.PrimitiveType_BF16: - return BF16 + return BFloat16 case protos.PrimitiveType_C64: return Complex64 case protos.PrimitiveType_C128: