Skip to content

Commit

Permalink
[CPU]PageAttn with 4bit-quantization (#27992)
Browse files Browse the repository at this point in the history
### Details:
 - *Add new hint to set group_size for key/value cache*
- *Add grouped 4bit sym/asym quantization support for PageAttentionNode*
 - *Add grouped quantization for U8 quantization for PageAttentionNode*

### Tickets:
 - *CVS-151586*

---------

Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Zhang Yi3 <[email protected]>
Signed-off-by: Zhang Yi <[email protected]>
  • Loading branch information
zhangYiIntel authored Jan 8, 2025
1 parent 345163f commit b319014
Show file tree
Hide file tree
Showing 18 changed files with 1,413 additions and 411 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from openvino._pyopenvino.properties import loaded_from_cache
from openvino._pyopenvino.properties import cache_encryption_callbacks
from openvino._pyopenvino.properties import weights_path
from openvino._pyopenvino.properties import key_cache_precision
from openvino._pyopenvino.properties import value_cache_precision
from openvino._pyopenvino.properties import key_cache_group_size
from openvino._pyopenvino.properties import value_cache_group_size

# Submodules
from openvino.runtime.properties import hint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ void regmodule_properties(py::module m) {
wrap_property_RW(m_properties, ov::force_tbb_terminate, "force_tbb_terminate");
wrap_property_RW(m_properties, ov::enable_mmap, "enable_mmap");
wrap_property_RW(m_properties, ov::weights_path, "weights_path");
wrap_property_RW(m_properties, ov::key_cache_precision, "key_cache_precision");
wrap_property_RW(m_properties, ov::value_cache_precision, "value_cache_precision");
wrap_property_RW(m_properties, ov::key_cache_group_size, "key_cache_group_size");
wrap_property_RW(m_properties, ov::value_cache_group_size, "value_cache_group_size");

wrap_property_RO(m_properties, ov::supported_properties, "supported_properties");
wrap_property_RO(m_properties, ov::available_devices, "available_devices");
Expand Down
12 changes: 12 additions & 0 deletions src/bindings/python/tests/test_runtime/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,18 @@ def test_properties_ro(ov_property_ro, expected_value):
"WEIGHTS_PATH",
(("./model.bin", "./model.bin"),),
),
(
props.key_cache_group_size,
"KEY_CACHE_GROUP_SIZE",
((64, 64),),
),
(
props.value_cache_group_size,
"VALUE_CACHE_GROUP_SIZE",
((64, 64),),
),
(props.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)),
(props.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)),
(hints.inference_precision, "INFERENCE_PRECISION_HINT", ((Type.f32, Type.f32),)),
(
hints.model_priority,
Expand Down
24 changes: 24 additions & 0 deletions src/inference/include/openvino/runtime/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1301,4 +1301,28 @@ static constexpr Property<std::vector<std::string>, PropertyMutability::RO> exec
* @note This property is used for weightless caching. Only used when ov::CacheMode Property is set to "OPTIMIZE_SIZE".
*/
static constexpr Property<std::string, PropertyMutability::RW> weights_path{"WEIGHTS_PATH"};

/**
* @brief The precision of key cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<element::Type, PropertyMutability::RW> key_cache_precision{"KEY_CACHE_PRECISION"};

/**
* @brief The precision of value cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<element::Type, PropertyMutability::RW> value_cache_precision{"VALUE_CACHE_PRECISION"};

/**
* @brief The group_size of key cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<uint64_t, PropertyMutability::RW> key_cache_group_size{"KEY_CACHE_GROUP_SIZE"};

/**
* @brief The group_size of value cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<uint64_t, PropertyMutability::RW> value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"};
} // namespace ov
12 changes: 12 additions & 0 deletions src/plugins/intel_cpu/src/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ ov::Any CompiledModel::get_property(const std::string& name) const {
RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()),
RO_property(ov::hint::dynamic_quantization_group_size.name()),
RO_property(ov::hint::kv_cache_precision.name()),
RO_property(ov::key_cache_precision.name()),
RO_property(ov::value_cache_precision.name()),
RO_property(ov::key_cache_group_size.name()),
RO_property(ov::value_cache_group_size.name()),
};

return ro_properties;
Expand Down Expand Up @@ -313,6 +317,14 @@ ov::Any CompiledModel::get_property(const std::string& name) const {
return decltype(ov::hint::dynamic_quantization_group_size)::value_type(config.fcDynamicQuantizationGroupSize);
} else if (name == ov::hint::kv_cache_precision) {
return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision);
} else if (name == ov::key_cache_precision) {
return decltype(ov::key_cache_precision)::value_type(config.keyCachePrecision);
} else if (name == ov::value_cache_precision) {
return decltype(ov::value_cache_precision)::value_type(config.valueCachePrecision);
} else if (name == ov::key_cache_group_size) {
return decltype(ov::key_cache_group_size)::value_type(config.keyCacheGroupSize);
} else if (name == ov::value_cache_group_size) {
return decltype(ov::value_cache_group_size)::value_type(config.valueCacheGroupSize);
}
OPENVINO_THROW("Unsupported property: ", name);
}
Expand Down
86 changes: 85 additions & 1 deletion src/plugins/intel_cpu/src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,60 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
ov::hint::kv_cache_precision.name(),
". Supported values: u8, bf16, f16, f32");
}
} else if (key == ov::key_cache_precision.name()) {
try {
keyCachePrecisionSetExplicitly = true;
auto const prec = val.as<ov::element::Type>();
if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
keyCachePrecision = prec;
} else {
OPENVINO_THROW("keyCachePrecision doesn't support value ", prec);
}
} catch (ov::Exception&) {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::key_cache_precision.name(),
". Supported values: u8, bf16, f16, f32");
}
} else if (key == ov::value_cache_precision.name()) {
try {
valueCachePrecisionSetExplicitly = true;
auto const prec = val.as<ov::element::Type>();
if (one_of(prec,
ov::element::f32,
ov::element::f16,
ov::element::bf16,
ov::element::u8,
ov::element::u4)) {
valueCachePrecision = prec;
} else {
OPENVINO_THROW("valueCachePrecision doesn't support value ", prec);
}
} catch (ov::Exception&) {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::value_cache_precision.name(),
". Supported values: u4, u8, bf16, f16, f32");
}
} else if (key == ov::key_cache_group_size.name() || key == ov::value_cache_group_size.name()) {
try {
auto const groupSize = val.as<uint64_t>();
if (key == ov::key_cache_group_size.name()) {
keyCacheGroupSizeSetExplicitly = true;
keyCacheGroupSize = groupSize;
} else {
valueCacheGroupSizeSetExplicitly = true;
valueCacheGroupSize = groupSize;
}
} catch (ov::Exception&) {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
key,
". Expected only unsinged integer numbers");
}
} else if (key == ov::cache_encryption_callbacks.name()) {
try {
const auto& encryption_callbacks = val.as<EncryptionCallbacks>();
Expand Down Expand Up @@ -344,6 +398,13 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
aclFastMath = true;
}
#endif
// key/value cache precision has higher priority, if not defined use kvCachePrecision
if (!keyCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) {
keyCachePrecision = kvCachePrecision;
}
if (!valueCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) {
valueCachePrecision = kvCachePrecision;
}
// disable dynamic quantization and kv quantization for best accuracy
if (executionMode == ov::hint::ExecutionMode::ACCURACY) {
if (!fcDynamicQuantizationGroupSizeSetExplicitly) {
Expand All @@ -352,6 +413,12 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
if (!kvCachePrecisionSetExplicitly) {
kvCachePrecision = ov::element::f32;
}
if (!keyCachePrecisionSetExplicitly) {
keyCachePrecision = ov::element::f32;
}
if (!valueCachePrecisionSetExplicitly) {
valueCachePrecision = ov::element::f32;
}
}

if (!prop.empty())
Expand Down Expand Up @@ -398,14 +465,31 @@ void Config::applyRtInfo(const std::shared_ptr<const ov::Model>& model) {
// if user sets explicitly, it will be higher priority than rt_info
if (!kvCachePrecisionSetExplicitly &&
model->has_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()})) {
this->kvCachePrecision =
this->kvCachePrecision = this->keyCachePrecision = this->valueCachePrecision =
model->get_rt_info<ov::element::Type>({"runtime_options", ov::hint::kv_cache_precision.name()});
}
if (!fcDynamicQuantizationGroupSizeSetExplicitly &&
model->has_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()})) {
this->fcDynamicQuantizationGroupSize =
model->get_rt_info<uint64_t>({"runtime_options", ov::hint::dynamic_quantization_group_size.name()});
}
if (!keyCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_precision.name()})) {
this->keyCachePrecision =
model->get_rt_info<ov::element::Type>({"runtime_options", ov::key_cache_precision.name()});
}
if (!valueCachePrecisionSetExplicitly &&
model->has_rt_info({"runtime_options", ov::value_cache_precision.name()})) {
this->valueCachePrecision =
model->get_rt_info<ov::element::Type>({"runtime_options", ov::value_cache_precision.name()});
}
if (!keyCacheGroupSizeSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_group_size.name()})) {
this->keyCacheGroupSize = model->get_rt_info<uint64_t>({"runtime_options", ov::key_cache_group_size.name()});
}
if (!valueCacheGroupSizeSetExplicitly &&
model->has_rt_info({"runtime_options", ov::value_cache_group_size.name()})) {
this->valueCacheGroupSize =
model->get_rt_info<uint64_t>({"runtime_options", ov::value_cache_group_size.name()});
}
}

} // namespace intel_cpu
Expand Down
10 changes: 10 additions & 0 deletions src/plugins/intel_cpu/src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,27 @@ struct Config {
uint64_t fcDynamicQuantizationGroupSize = 32;
bool fcDynamicQuantizationGroupSizeSetExplicitly = false;
bool kvCachePrecisionSetExplicitly = false;
bool keyCachePrecisionSetExplicitly = false;
bool valueCachePrecisionSetExplicitly = false;
bool keyCacheGroupSizeSetExplicitly = false;
bool valueCacheGroupSizeSetExplicitly = false;
#if defined(OV_CPU_WITH_ACL)
bool aclFastMath = false;
#endif
#if defined(OPENVINO_ARCH_X86_64)
ov::element::Type kvCachePrecision = ov::element::u8;
ov::element::Type keyCachePrecision = ov::element::u8;
ov::element::Type valueCachePrecision = ov::element::u8;
size_t rtCacheCapacity = 5000ul;
#else
ov::element::Type kvCachePrecision = ov::element::f16;
ov::element::Type keyCachePrecision = ov::element::f16;
ov::element::Type valueCachePrecision = ov::element::f16;
// TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives
size_t rtCacheCapacity = 0ul;
#endif
size_t keyCacheGroupSize = 0ul;
size_t valueCacheGroupSize = 0ul;
ov::threading::IStreamsExecutor::Config streamExecutorConfig;
int streams = 1;
bool streamsChanged = false;
Expand Down
Loading

0 comments on commit b319014

Please sign in to comment.