From ab04fe4042cfbc8dda0e7f524cbe59f5d16517df Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sat, 20 Jul 2024 05:21:48 +0200 Subject: [PATCH] Use `github.com/dmarkham/enumer` instead of the usual `stringer` for dtypes. --- cmd/dtypes_codegen/enums.go | 2 +- docs/CHANGELOG.md | 1 + dtypes/dtype_enumer.go | 225 ++++++++++++++++++++++++++++++++++++ dtypes/dtype_string.go | 48 -------- 4 files changed, 227 insertions(+), 49 deletions(-) create mode 100644 dtypes/dtype_enumer.go delete mode 100644 dtypes/dtype_string.go diff --git a/cmd/dtypes_codegen/enums.go b/cmd/dtypes_codegen/enums.go index 1e011ba..afb0881 100644 --- a/cmd/dtypes_codegen/enums.go +++ b/cmd/dtypes_codegen/enums.go @@ -166,5 +166,5 @@ func generateEnums(contents string) { must.M(enumsFromCTemplate.Execute(f, allValues)) must.M(exec.Command("gofmt", "-w", DTypeEnumGoFileName).Run()) fmt.Printf("Generated %q based on pjrt_c_api.h\n", DTypeEnumGoFileName) - must.M(exec.Command("stringer", "-type=DType", DTypeEnumGoFileName).Run()) + must.M(exec.Command("enumer", "-type=DType", "-yaml", "-json", "-text", "-values", DTypeEnumGoFileName).Run()) } diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 765ca4d..22cd298 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -2,6 +2,7 @@ * Execute.NonDonatable -> Execute.DonateNone * Added Execute.SetDonate +* Use `github.com/dmarkham/enumer` instead of the usual `stringer` for dtypes. # v0.2.0 GoMLX integration fixes -- GoMLX more extensive tests caught several small issues in Gopjrt. diff --git a/dtypes/dtype_enumer.go b/dtypes/dtype_enumer.go new file mode 100644 index 0000000..ca09ce4 --- /dev/null +++ b/dtypes/dtype_enumer.go @@ -0,0 +1,225 @@ +// Code generated by "enumer -type=DType -yaml -json -text -values gen_dtype_enum.go"; DO NOT EDIT. + +package dtypes + +import ( + "encoding/json" + "fmt" + "strings" +) + +const _DTypeName = "InvalidDTypeBoolInt8Int16Int32Int64Uint8Uint16Uint32Uint64Float16Float32Float64BFloat16Complex64Complex128F8E5M2F8E4M3FNF8E4M3B11FNUZF8E5M2FNUZF8E4M3FNUZS4U4TOKENS2U2" + +var _DTypeIndex = [...]uint8{0, 12, 16, 20, 25, 30, 35, 40, 46, 52, 58, 65, 72, 79, 87, 96, 106, 112, 120, 133, 143, 153, 155, 157, 162, 164, 166} + +const _DTypeLowerName = "invaliddtypeboolint8int16int32int64uint8uint16uint32uint64float16float32float64bfloat16complex64complex128f8e5m2f8e4m3fnf8e4m3b11fnuzf8e5m2fnuzf8e4m3fnuzs4u4tokens2u2" + +func (i DType) String() string { + if i < 0 || i >= DType(len(_DTypeIndex)-1) { + return fmt.Sprintf("DType(%d)", i) + } + return _DTypeName[_DTypeIndex[i]:_DTypeIndex[i+1]] +} + +func (DType) Values() []string { + return DTypeStrings() +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _DTypeNoOp() { + var x [1]struct{} + _ = x[InvalidDType-(0)] + _ = x[Bool-(1)] + _ = x[Int8-(2)] + _ = x[Int16-(3)] + _ = x[Int32-(4)] + _ = x[Int64-(5)] + _ = x[Uint8-(6)] + _ = x[Uint16-(7)] + _ = x[Uint32-(8)] + _ = x[Uint64-(9)] + _ = x[Float16-(10)] + _ = x[Float32-(11)] + _ = x[Float64-(12)] + _ = x[BFloat16-(13)] + _ = x[Complex64-(14)] + _ = x[Complex128-(15)] + _ = x[F8E5M2-(16)] + _ = x[F8E4M3FN-(17)] + _ = x[F8E4M3B11FNUZ-(18)] + _ = x[F8E5M2FNUZ-(19)] + _ = x[F8E4M3FNUZ-(20)] + _ = x[S4-(21)] + _ = x[U4-(22)] + _ = x[TOKEN-(23)] + _ = x[S2-(24)] + _ = x[U2-(25)] +} + +var _DTypeValues = []DType{InvalidDType, Bool, Int8, Int16, Int32, Int64, Uint8, Uint16, Uint32, Uint64, Float16, Float32, Float64, BFloat16, Complex64, Complex128, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, S4, U4, TOKEN, S2, U2} + +var _DTypeNameToValueMap = map[string]DType{ + _DTypeName[0:12]: InvalidDType, + _DTypeLowerName[0:12]: InvalidDType, + _DTypeName[12:16]: Bool, + _DTypeLowerName[12:16]: Bool, + _DTypeName[16:20]: Int8, + _DTypeLowerName[16:20]: Int8, + _DTypeName[20:25]: Int16, + _DTypeLowerName[20:25]: Int16, + _DTypeName[25:30]: Int32, + _DTypeLowerName[25:30]: Int32, + _DTypeName[30:35]: Int64, + _DTypeLowerName[30:35]: Int64, + _DTypeName[35:40]: Uint8, + _DTypeLowerName[35:40]: Uint8, + _DTypeName[40:46]: Uint16, + _DTypeLowerName[40:46]: Uint16, + _DTypeName[46:52]: Uint32, + _DTypeLowerName[46:52]: Uint32, + _DTypeName[52:58]: Uint64, + _DTypeLowerName[52:58]: Uint64, + _DTypeName[58:65]: Float16, + _DTypeLowerName[58:65]: Float16, + _DTypeName[65:72]: Float32, + _DTypeLowerName[65:72]: Float32, + _DTypeName[72:79]: Float64, + _DTypeLowerName[72:79]: Float64, + _DTypeName[79:87]: BFloat16, + _DTypeLowerName[79:87]: BFloat16, + _DTypeName[87:96]: Complex64, + _DTypeLowerName[87:96]: Complex64, + _DTypeName[96:106]: Complex128, + _DTypeLowerName[96:106]: Complex128, + _DTypeName[106:112]: F8E5M2, + _DTypeLowerName[106:112]: F8E5M2, + _DTypeName[112:120]: F8E4M3FN, + _DTypeLowerName[112:120]: F8E4M3FN, + _DTypeName[120:133]: F8E4M3B11FNUZ, + _DTypeLowerName[120:133]: F8E4M3B11FNUZ, + _DTypeName[133:143]: F8E5M2FNUZ, + _DTypeLowerName[133:143]: F8E5M2FNUZ, + _DTypeName[143:153]: F8E4M3FNUZ, + _DTypeLowerName[143:153]: F8E4M3FNUZ, + _DTypeName[153:155]: S4, + _DTypeLowerName[153:155]: S4, + _DTypeName[155:157]: U4, + _DTypeLowerName[155:157]: U4, + _DTypeName[157:162]: TOKEN, + _DTypeLowerName[157:162]: TOKEN, + _DTypeName[162:164]: S2, + _DTypeLowerName[162:164]: S2, + _DTypeName[164:166]: U2, + _DTypeLowerName[164:166]: U2, +} + +var _DTypeNames = []string{ + _DTypeName[0:12], + _DTypeName[12:16], + _DTypeName[16:20], + _DTypeName[20:25], + _DTypeName[25:30], + _DTypeName[30:35], + _DTypeName[35:40], + _DTypeName[40:46], + _DTypeName[46:52], + _DTypeName[52:58], + _DTypeName[58:65], + _DTypeName[65:72], + _DTypeName[72:79], + _DTypeName[79:87], + _DTypeName[87:96], + _DTypeName[96:106], + _DTypeName[106:112], + _DTypeName[112:120], + _DTypeName[120:133], + _DTypeName[133:143], + _DTypeName[143:153], + _DTypeName[153:155], + _DTypeName[155:157], + _DTypeName[157:162], + _DTypeName[162:164], + _DTypeName[164:166], +} + +// DTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func DTypeString(s string) (DType, error) { + if val, ok := _DTypeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _DTypeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to DType values", s) +} + +// DTypeValues returns all values of the enum +func DTypeValues() []DType { + return _DTypeValues +} + +// DTypeStrings returns a slice of all String values of the enum +func DTypeStrings() []string { + strs := make([]string, len(_DTypeNames)) + copy(strs, _DTypeNames) + return strs +} + +// IsADType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i DType) IsADType() bool { + for _, v := range _DTypeValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for DType +func (i DType) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for DType +func (i *DType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("DType should be a string, got %s", data) + } + + var err error + *i, err = DTypeString(s) + return err +} + +// MarshalText implements the encoding.TextMarshaler interface for DType +func (i DType) MarshalText() ([]byte, error) { + return []byte(i.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for DType +func (i *DType) UnmarshalText(text []byte) error { + var err error + *i, err = DTypeString(string(text)) + return err +} + +// MarshalYAML implements a YAML Marshaler for DType +func (i DType) MarshalYAML() (interface{}, error) { + return i.String(), nil +} + +// UnmarshalYAML implements a YAML Unmarshaler for DType +func (i *DType) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + var err error + *i, err = DTypeString(s) + return err +} diff --git a/dtypes/dtype_string.go b/dtypes/dtype_string.go deleted file mode 100644 index 8785522..0000000 --- a/dtypes/dtype_string.go +++ /dev/null @@ -1,48 +0,0 @@ -// Code generated by "stringer -type=DType gen_dtype_enum.go"; DO NOT EDIT. - -package dtypes - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[InvalidDType-0] - _ = x[Bool-1] - _ = x[Int8-2] - _ = x[Int16-3] - _ = x[Int32-4] - _ = x[Int64-5] - _ = x[Uint8-6] - _ = x[Uint16-7] - _ = x[Uint32-8] - _ = x[Uint64-9] - _ = x[Float16-10] - _ = x[Float32-11] - _ = x[Float64-12] - _ = x[BFloat16-13] - _ = x[Complex64-14] - _ = x[Complex128-15] - _ = x[F8E5M2-16] - _ = x[F8E4M3FN-17] - _ = x[F8E4M3B11FNUZ-18] - _ = x[F8E5M2FNUZ-19] - _ = x[F8E4M3FNUZ-20] - _ = x[S4-21] - _ = x[U4-22] - _ = x[TOKEN-23] - _ = x[S2-24] - _ = x[U2-25] -} - -const _DType_name = "InvalidDTypeBoolInt8Int16Int32Int64Uint8Uint16Uint32Uint64Float16Float32Float64BFloat16Complex64Complex128F8E5M2F8E4M3FNF8E4M3B11FNUZF8E5M2FNUZF8E4M3FNUZS4U4TOKENS2U2" - -var _DType_index = [...]uint8{0, 12, 16, 20, 25, 30, 35, 40, 46, 52, 58, 65, 72, 79, 87, 96, 106, 112, 120, 133, 143, 153, 155, 157, 162, 164, 166} - -func (i DType) String() string { - if i < 0 || i >= DType(len(_DType_index)-1) { - return "DType(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _DType_name[_DType_index[i]:_DType_index[i+1]] -}