diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 783101e839b1..c05b22c5aa88 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -31,6 +31,11 @@ def TPU_Dialect : Dialect { let cppNamespace = "::mlir::tpu"; let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let extraClassDeclaration = [{ + static StringRef GetCoreTypeKey() { return "tpu.core_type"; } + + static std::optional GetCoreTypeAttr(Operation *op); + }]; } class TPU_Attr traits = []> @@ -46,6 +51,19 @@ class TPU_Type traits = []> let mnemonic = mnemonic_; } +def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [ + I32EnumAttrCase<"kTc", 0, "tc">, + I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">, + I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_CoreTypeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>; def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>; def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index df00093fabe6..d884ef197cda 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -68,6 +69,18 @@ void TPUDialect::initialize() { >(); } +/* static */ std::optional TPUDialect::GetCoreTypeAttr( + Operation *op) { + Attribute attr = op->getAttr(GetCoreTypeKey()); + if (attr == nullptr) { + return std::nullopt; + } + if (!mlir::isa(attr)) { + return std::nullopt; + } + return mlir::cast(attr).getValue(); +} + void VectorLayoutAttr::print(AsmPrinter &printer) const { printer << '<'; printer << getLayout();