Skip to content

Commit

Permalink
Update onnx.proto for int4 (ROCm#3373)
Browse files Browse the repository at this point in the history
  • Loading branch information
lakhinderwalia authored Aug 20, 2024
1 parent 016be6e commit 9cf49f9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 19 deletions.
74 changes: 56 additions & 18 deletions src/onnx/onnx.proto
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,14 @@ enum Version {
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;

// IR VERSION 9 published on TBD
// IR VERSION 9 published on May 5, 2023
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
IR_VERSION_2023_5_5 = 0x0000000000000009;

// IR VERSION 10 published on TBD
// Added UINT4, INT4.
IR_VERSION = 0x000000000000000A;
}

// Attributes
Expand All @@ -116,6 +120,8 @@ enum Version {
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
reserved 12, 16 to 19;
reserved "v";

// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
Expand Down Expand Up @@ -188,6 +194,8 @@ message ValueInfoProto {
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 4;
}

// Nodes
Expand All @@ -202,19 +210,24 @@ message NodeProto {
repeated string output = 2; // namespace Value

// An optional identifier for this node in a graph.
// This field MAY be absent in ths version of the IR.
// This field MAY be absent in this version of the IR.
optional string name = 3; // namespace Node

// The symbolic identifier of the Operator to execute.
optional string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
optional string domain = 7; // namespace Domain
// Overload identifier, used only to map this to a model-local function.
optional string overload = 8;

// Additional named attributes.
repeated AttributeProto attribute = 5;

// A human-readable documentation for this node. Markdown is allowed.
optional string doc_string = 6;

// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 9;
}

// Training information
Expand Down Expand Up @@ -259,7 +272,7 @@ message TrainingInfoProto {
//
// An execution of the training algorithm step is performed by executing the
// graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual the actual
// and the "algorithm" graph. That is, the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
Expand Down Expand Up @@ -399,9 +412,9 @@ message ModelProto {

// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// The (domain, name, overload) tuple must be unique across the function protos in this list.
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard opserator sets are given higher priotity or this is treated as error) is defined by
// or standard operator sets are given higher priotity or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
Expand Down Expand Up @@ -473,6 +486,9 @@ message GraphProto {
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;

// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 16;

reserved 3, 4, 6 to 9;
reserved "ir_version", "producer_version", "producer_tag", "domain";
}
Expand Down Expand Up @@ -515,10 +531,14 @@ message TensorProto {
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero

// 4-bit data-types
UINT4 = 21; // Unsigned integer in range [0, 15]
INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation

// Future extensions go here.
}
Expand Down Expand Up @@ -553,11 +573,13 @@ message TensorProto {
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];

// For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// For int32, uint8, int8, uint16, int16, uint4, int4, bool, float8 and float16 values
// float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// uint4 and int4 values must be packed to 4bitx2 prior to writing to the buffer, the first element is stored in
// the 4 LSB and the second element is stored in the 4 MSB.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
// INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];

// For strings.
Expand Down Expand Up @@ -587,6 +609,7 @@ message TensorProto {
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
// uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
Expand Down Expand Up @@ -629,6 +652,9 @@ message TensorProto {
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];

// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 16;
}

// A serialized sparse-tensor value
Expand Down Expand Up @@ -775,9 +801,8 @@ enum OperatorStatus {
}

message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
// The name of the function, similar to op_type in NodeProto.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
optional string name = 1;

// Deprecated since IR Version 8
Expand Down Expand Up @@ -824,11 +849,24 @@ message FunctionProto {

repeated OperatorSetIdProto opset_import = 9;

// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
// The domain which this function belongs to.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
optional string domain = 10;
}

// The overload identifier of the function.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
optional string overload = 13;

// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;

// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
}

// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
option optimize_for = LITE_RUNTIME;

2 changes: 1 addition & 1 deletion tools/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

CLANG_FORMAT_PATH = '/opt/rocm/llvm/bin'

EXCLUDE_FILES = ['requirements.in']
EXCLUDE_FILES = ['requirements.in', 'onnx.proto']


def run(cmd, **kwargs):
Expand Down

0 comments on commit 9cf49f9

Please sign in to comment.