Support INFERENCE_NUM_THREADS options in LLAMA_CPP (#898)
* Support INFERENCE_NUM_THREADS options in LLAMA_CPP

* Run func tests during build

* Update modules/llama_cpp_plugin/tests/functional/src/threading.cpp

Co-authored-by: Ilya Lavrenov <[email protected]>

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
vshampor and ilya-lavrenov authored Apr 17, 2024
1 parent 8f23817 commit cb1c767
Showing 19 changed files with 382 additions and 27 deletions.
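The change plumbs the standard ov::inference_num_threads property through to the llama.cpp context's n_threads. A minimal usage sketch against the OpenVINO 2.0 API follows; the gpt2.gguf path and the thread count of 4 are illustrative assumptions, not part of this commit.

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;

    // Request four llama.cpp worker threads for this compiled model only.
    ov::CompiledModel model =
        core.compile_model("gpt2.gguf", "LLAMA_CPP", ov::inference_num_threads(4));

    ov::InferRequest request = model.create_infer_request();
    // ... set input_ids / position_ids / attention_mask / beam_idx, then request.infer()
    return 0;
}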
8 changes: 7 additions & 1 deletion .github/workflows/llama_cpp_plugin_build_and_test.yml
@@ -31,7 +31,7 @@ jobs:
run: cmake -B build -DCMAKE_BUILD_TYPE=Release -DOPENVINO_EXTRA_MODULES=${{ github.workspace }}/openvino_contrib/modules/llama_cpp_plugin -DENABLE_TESTS=ON -DENABLE_FUNCTIONAL_TESTS=ON -DENABLE_PLUGINS_XML=ON -DENABLE_LLAMA_CPP_PLUGIN_REGISTRATION=ON openvino

- name: CMake - build
run: cmake --build build -j`nproc` -- llama_cpp_plugin llama_cpp_e2e_tests
run: cmake --build build -j`nproc` -- llama_cpp_plugin llama_cpp_e2e_tests llama_cpp_func_tests


- name: Upload build artifacts
@@ -69,6 +69,12 @@ jobs:
mkdir -p tbb
tar xvzf oneapi-tbb-2021.2.4-lin.tgz
- name: Run functional tests
run: |
chmod +x ${{ github.workspace }}/binaries/llama_cpp_func_tests
export LD_LIBRARY_PATH=${{ github.workspace }}/binaries:${{ github.workspace }}/tbb/lib
${{ github.workspace }}/binaries/llama_cpp_func_tests
- name: Run E2E tests
run: |
chmod +x ${{ github.workspace }}/binaries/llama_cpp_e2e_tests
2 changes: 2 additions & 0 deletions modules/llama_cpp_plugin/CMakeLists.txt
@@ -24,7 +24,9 @@ FetchContent_MakeAvailable(llama_cpp)
if(ENABLE_TESTS)
include(CTest)
enable_testing()
add_subdirectory(tests/common)
add_subdirectory(tests/e2e)
add_subdirectory(tests/functional)
endif()

# install
2 changes: 1 addition & 1 deletion modules/llama_cpp_plugin/include/compiled_model.hpp
@@ -15,7 +15,7 @@ class LlamaCppPlugin;
class LlamaCppState;
class LlamaCppModel : public ICompiledModel {
public:
LlamaCppModel(const std::string& gguf_fname, const std::shared_ptr<const IPlugin>& plugin);
LlamaCppModel(const std::string& gguf_fname, const std::shared_ptr<const IPlugin>& plugin, size_t num_threads = 0);
/**
* @brief Export compiled model to stream
*
3 changes: 3 additions & 0 deletions modules/llama_cpp_plugin/include/plugin.hpp
@@ -39,6 +39,9 @@ class LlamaCppPlugin : public IPlugin {

virtual ov::SupportedOpsMap query_model(const std::shared_ptr<const ov::Model>& model,
const ov::AnyMap& properties) const override;

private:
size_t m_num_threads = 0;
};
} // namespace llama_cpp_plugin
} // namespace ov
7 changes: 4 additions & 3 deletions modules/llama_cpp_plugin/src/compiled_model.cpp
@@ -23,16 +23,17 @@ LlamaCppModel::~LlamaCppModel() {
llama_backend_free();
}

LlamaCppModel::LlamaCppModel(const std::string& gguf_fname, const std::shared_ptr<const IPlugin>& plugin)
LlamaCppModel::LlamaCppModel(const std::string& gguf_fname,
const std::shared_ptr<const IPlugin>& plugin,
size_t num_threads)
: ICompiledModel(nullptr, plugin),
m_gguf_fname(gguf_fname) {
OPENVINO_DEBUG << "llama_cpp_plugin: loading llama model directly from GGUF... " << std::endl;
llama_model_params mparams = llama_model_default_params();
mparams.n_gpu_layers = 99;
m_llama_model_ptr = llama_load_model_from_file(gguf_fname.c_str(), mparams);
llama_context_params cparams = llama_context_default_params();
cparams.n_threads =
std::thread::hardware_concurrency(); // TODO (vshampor): reuse equivalent setting defined by OV API
cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
cparams.n_ctx = 0; // this means that the actual n_ctx will be taken equal to the model's train-time value
m_llama_ctx = llama_new_context_with_model(m_llama_model_ptr, cparams);
OPENVINO_DEBUG << "llama_cpp_plugin: llama model loaded successfully from GGUF..." << std::endl;
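The effect of the change above: an explicit non-zero INFERENCE_NUM_THREADS value takes precedence, and the previous behaviour (use all hardware threads) remains the default. The same rule as a standalone sketch; resolve_n_threads() is an illustrative helper, not plugin code.

#include <cstddef>
#include <thread>

// 0 means "not set", so fall back to the host's hardware concurrency,
// mirroring the cparams.n_threads assignment above.
inline std::size_t resolve_n_threads(std::size_t requested) {
    return requested != 0 ? requested : std::thread::hardware_concurrency();
}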
25 changes: 23 additions & 2 deletions modules/llama_cpp_plugin/src/plugin.cpp
@@ -35,19 +35,36 @@ std::shared_ptr<ov::ICompiledModel> LlamaCppPlugin::compile_model(const std::sha
}
std::shared_ptr<ov::ICompiledModel> LlamaCppPlugin::compile_model(const std::string& fname,
const ov::AnyMap& properties) const {
return std::make_shared<LlamaCppModel>(fname, shared_from_this());
size_t num_threads = 0;
auto it = properties.find(ov::inference_num_threads.name());
if (it != properties.end()) {
num_threads = it->second.as<int>();
if (num_threads < 0) {
OPENVINO_THROW("INFERENCE_NUM_THREADS cannot be negative");
}
} else {
num_threads = m_num_threads;
}
return std::make_shared<LlamaCppModel>(fname, shared_from_this(), num_threads);
}

void LlamaCppPlugin::set_property(const ov::AnyMap& properties) {
for (const auto& map_entry : properties) {
if (ov::inference_num_threads == map_entry.first) {
int num_threads = map_entry.second.as<int>();
if (num_threads < 0) {
OPENVINO_THROW("INFERENCE_NUM_THREADS cannot be negative");
}
m_num_threads = num_threads;
}
OPENVINO_THROW_NOT_IMPLEMENTED("llama_cpp_plugin: setting property ", map_entry.first, "not implemented");
}
}

ov::Any LlamaCppPlugin::get_property(const std::string& name, const ov::AnyMap& arguments) const {
if (ov::supported_properties == name) {
return decltype(ov::supported_properties)::value_type(
std::vector<PropertyName>({ov::device::capabilities, ov::device::full_name}));
std::vector<PropertyName>({ov::device::capabilities, ov::device::full_name, ov::inference_num_threads}));
}
if (ov::device::capabilities == name) {
return decltype(ov::device::capabilities)::value_type(
@@ -66,6 +83,10 @@ ov::Any LlamaCppPlugin::get_property(const std::string& name, const ov::AnyMap&
return std::string("LLAMA_CPP");
}

if (ov::inference_num_threads == name) {
return m_num_threads;
}

OPENVINO_THROW_NOT_IMPLEMENTED("llama_cpp_plugin: getting property ", name, "not implemented");
}

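With the additions above, INFERENCE_NUM_THREADS is advertised via SUPPORTED_PROPERTIES and its current value can be read back. A sketch of querying the plugin, assuming the "LLAMA_CPP" device name used elsewhere in this commit; the string-based readback is an assumption chosen to avoid committing to the stored integer type.

#include <openvino/openvino.hpp>

#include <iostream>

int main() {
    ov::Core core;

    // After this commit the list includes INFERENCE_NUM_THREADS alongside
    // DEVICE_CAPABILITIES and FULL_DEVICE_NAME.
    for (const auto& prop : core.get_property("LLAMA_CPP", ov::supported_properties)) {
        std::cout << prop << std::endl;
    }

    // Read the current value back; 0 means "fall back to hardware concurrency".
    ov::Any n_threads = core.get_property("LLAMA_CPP", ov::inference_num_threads.name(), {});
    std::cout << "INFERENCE_NUM_THREADS = " << n_threads.as<std::string>() << std::endl;
    return 0;
}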
8 changes: 8 additions & 0 deletions modules/llama_cpp_plugin/tests/common/CMakeLists.txt
@@ -0,0 +1,8 @@
project(llama_cpp_test_common)

add_library(llama_cpp_test_common STATIC
${CMAKE_CURRENT_SOURCE_DIR}/src/llm_inference.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/benchmarking.cpp
)
target_include_directories(llama_cpp_test_common PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_link_libraries(llama_cpp_test_common gtest common_test_utils)
12 changes: 12 additions & 0 deletions modules/llama_cpp_plugin/tests/common/include/benchmarking.hpp
@@ -0,0 +1,12 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#ifndef BENCHMARKING_HPP
#define BENCHMARKING_HPP

#include <functional>
#include <vector>

double measure_iterations_per_second(std::function<void(void)> iteration_fn, size_t iterations);

#endif /* BENCHMARKING_HPP */
22 changes: 22 additions & 0 deletions modules/llama_cpp_plugin/tests/common/include/llm_inference.hpp
@@ -0,0 +1,22 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#ifndef LLM_INFERENCE_HPP
#define LLM_INFERENCE_HPP

#include "model_fixture.hpp"
#include "openvino/openvino.hpp"

std::vector<float> infer_logits_for_tokens_with_positions(ov::InferRequest& lm,
const std::vector<int64_t>& tokens,
int64_t position_ids_start_value);

std::vector<int64_t> generate_n_tokens_with_positions(ov::InferRequest& lm,
int64_t last_token,
size_t n_tokens,
int64_t position_ids_start_value);

inline int64_t get_token_from_logits(const std::vector<float>& logits) {
return std::max_element(logits.cbegin(), logits.cend()) - logits.cbegin();
}
#endif /* LLM_INFERENCE_HPP */
38 changes: 38 additions & 0 deletions modules/llama_cpp_plugin/tests/common/include/model_fixture.hpp
@@ -0,0 +1,38 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#ifndef MODEL_FIXTURE_HPP
#define MODEL_FIXTURE_HPP

#include <gtest/gtest.h>

#include "common_test_utils/file_utils.hpp"
#include "openvino/openvino.hpp"
#include "openvino/runtime/infer_request.hpp"

const std::string TEST_FILES_DIR = "test_data";
const auto SEP = ov::util::FileTraits<char>::file_separator;

class CompiledModelTest : public ::testing::Test {
public:
static void fill_unused_inputs(ov::InferRequest& infer_request, const ov::Shape& input_ids_reference_shape) {
infer_request.set_tensor("attention_mask", ov::Tensor(ov::element::Type_t::i64, input_ids_reference_shape));

size_t batch_size = input_ids_reference_shape[0];
infer_request.set_tensor("beam_idx", ov::Tensor(ov::element::Type_t::i32, ov::Shape{batch_size}));
}

protected:
void SetUp() override {
const std::string plugin_name = "LLAMA_CPP";
ov::Core core;

const std::string model_file_name = "gpt2.gguf";
const std::string model_file =
ov::test::utils::getCurrentWorkingDir() + SEP + TEST_FILES_DIR + SEP + model_file_name;
model = core.compile_model(model_file, plugin_name);
}
ov::CompiledModel model;
};

#endif /* MODEL_FIXTURE_HPP */
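For illustration, a minimal test built on this fixture could look as follows; the test name and the single-token input are assumptions, and SetUp() has already compiled test_data/gpt2.gguf on the LLAMA_CPP plugin.

#include <gtest/gtest.h>

#include "model_fixture.hpp"

TEST_F(CompiledModelTest, CreatesInferRequestSketch) {
    ov::InferRequest request = model.create_infer_request();

    // Single token, batch of 1.
    ov::Tensor input_ids(ov::element::i64, ov::Shape{1, 1});
    input_ids.data<int64_t>()[0] = 0;
    request.set_tensor("input_ids", input_ids);

    ov::Tensor position_ids(ov::element::i64, ov::Shape{1, 1});
    position_ids.data<int64_t>()[0] = 0;
    request.set_tensor("position_ids", position_ids);

    // attention_mask and beam_idx are filled by the fixture helper.
    fill_unused_inputs(request, input_ids.get_shape());

    EXPECT_NO_THROW(request.infer());
}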
26 changes: 26 additions & 0 deletions modules/llama_cpp_plugin/tests/common/src/benchmarking.cpp
@@ -0,0 +1,26 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#ifndef BENCHMARKING_CPP
#define BENCHMARKING_CPP

#include "benchmarking.hpp"

#include <algorithm>
#include <chrono>

double measure_iterations_per_second(std::function<void(void)> iteration_fn, size_t iterations) {
std::vector<float> iteration_times_s(iterations);

for (size_t i = 0; i < iterations; i++) {
auto start = std::chrono::steady_clock::now();
iteration_fn();
auto end = std::chrono::steady_clock::now();
iteration_times_s[i] = std::chrono::duration<double>(end - start).count();
}

std::sort(iteration_times_s.begin(), iteration_times_s.end());
return 1.0 / iteration_times_s[iteration_times_s.size() / 2];
}

#endif /* BENCHMARKING_CPP */
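A possible use of this helper, following how the tests measure token throughput; the iteration count, the starting token id, and the tokens_per_second() wrapper are illustrative assumptions.

#include <cstddef>
#include <cstdint>

#include "benchmarking.hpp"
#include "llm_inference.hpp"

// Reports the reciprocal of the median time needed to generate one token.
double tokens_per_second(ov::InferRequest& lm) {
    constexpr std::size_t NUM_ITERATIONS = 32;
    int64_t token = 0;  // arbitrary starting token id
    return measure_iterations_per_second(
        [&lm, &token]() {
            std::vector<float> logits =
                infer_logits_for_tokens_with_positions(lm, {token}, /* position_ids_start_value = */ 0);
            token = get_token_from_logits(logits);
        },
        NUM_ITERATIONS);
}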
42 changes: 42 additions & 0 deletions modules/llama_cpp_plugin/tests/common/src/llm_inference.cpp
@@ -0,0 +1,42 @@
#include "llm_inference.hpp"

std::vector<float> infer_logits_for_tokens_with_positions(ov::InferRequest& lm,
const std::vector<int64_t>& tokens,
int64_t position_ids_start_value) {
auto input_ids_tensor = ov::Tensor(ov::element::Type_t::i64, {1, tokens.size()});
std::copy(tokens.begin(), tokens.end(), input_ids_tensor.data<int64_t>());
lm.set_tensor("input_ids", input_ids_tensor);

ov::Tensor position_ids = lm.get_tensor("position_ids");
position_ids.set_shape(input_ids_tensor.get_shape());
std::iota(position_ids.data<int64_t>(),
position_ids.data<int64_t>() + position_ids.get_size(),
position_ids_start_value);

CompiledModelTest::fill_unused_inputs(lm, input_ids_tensor.get_shape());
lm.infer();

size_t vocab_size = lm.get_tensor("logits").get_shape().back();
float* logits = lm.get_tensor("logits").data<float>() + (input_ids_tensor.get_size() - 1) * vocab_size;
std::vector<float> logits_vector(vocab_size);
std::copy(logits, logits + vocab_size, logits_vector.begin());
return logits_vector;
}

std::vector<int64_t> generate_n_tokens_with_positions(ov::InferRequest& lm,
int64_t last_token,
size_t n_tokens,
int64_t position_ids_start_value) {
size_t cnt = 0;
std::vector<int64_t> out_token_ids;
out_token_ids.push_back(last_token);

while (cnt < n_tokens) {
std::vector<float> logits_curr =
infer_logits_for_tokens_with_positions(lm, {out_token_ids.back()}, cnt + position_ids_start_value);
int64_t out_token = std::max_element(logits_curr.begin(), logits_curr.end()) - logits_curr.begin();
out_token_ids.push_back(out_token);
cnt++;
}
return out_token_ids;
}
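Taken together, the two helpers cover the usual prefill-then-decode flow; a hedged sketch follows, where continue_prompt() is illustrative and not part of the test suite.

#include <cstdint>
#include <vector>

#include "llm_inference.hpp"

std::vector<int64_t> continue_prompt(ov::InferRequest& lm,
                                     const std::vector<int64_t>& prompt_ids,
                                     size_t n_new_tokens) {
    // Prefill: run the whole prompt once, positions starting at 0.
    std::vector<float> logits = infer_logits_for_tokens_with_positions(lm, prompt_ids, 0);
    int64_t first_new_token = get_token_from_logits(logits);

    // Decode: generate n_new_tokens more, one token at a time, continuing the positions.
    return generate_n_tokens_with_positions(lm,
                                            first_new_token,
                                            n_new_tokens,
                                            static_cast<int64_t>(prompt_ids.size()));
}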
7 changes: 5 additions & 2 deletions modules/llama_cpp_plugin/tests/e2e/CMakeLists.txt
@@ -10,11 +10,14 @@ ov_add_test_target(
llama_cpp_plugin
LINK_LIBRARIES
openvino::runtime::dev
openvino::funcSharedTests
common_test_utils
gtest
llama_cpp_test_common
INCLUDES
"${LlamaCppPlugin_SOURCE_DIR}/include"
"${LlamaCppPlugin_SOURCE_DIR}/tests/common/include"
ADD_CLANG_FORMAT
LABELS
OV UNIT TEMPLATE
OV UNIT LLAMA_CPP
)

Empty file.
@@ -3,10 +3,7 @@

#include <gtest/gtest.h>

#include "common_test_utils/file_utils.hpp"
#include "openvino/openvino.hpp"

const std::string TEST_FILES_DIR = "test_data";
#include "model_fixture.hpp"

// "Why is the Sun yellow?"
const std::vector<int64_t> GPT2_PROMPT_TOKEN_IDS = {5195, 318, 262, 3825, 7872, 30};
@@ -16,34 +16,24 @@ const std::vector<int64_t> GPT2_REFERENCE_RESPONSE_TOKEN_IDS = {
198, 464, 3825, 318, 257, 6016, 2266, 11, 543, 1724, 340, 318, 257, 6016, 2266, 13,
383, 3825, 318, 257, 6016, 2266, 780, 340, 318, 257, 6016, 2266, 13, 198, 198, 464};

const auto SEP = ov::util::FileTraits<char>::file_separator;

TEST(PromptResponseTest, TestGPT2) {
const std::string plugin_name = "LLAMA_CPP";
ov::Core core;

const std::string model_file_name = "gpt2.gguf";
const std::string model_file =
ov::test::utils::getCurrentWorkingDir() + SEP + TEST_FILES_DIR + SEP + model_file_name;
ov::InferRequest lm = core.compile_model(model_file, plugin_name).create_infer_request();
TEST_F(CompiledModelTest, TestPromptResponseGPT2) {
ov::InferRequest lm = model.create_infer_request();
auto input_ids_tensor = ov::Tensor(ov::element::Type_t::i64, {1, GPT2_PROMPT_TOKEN_IDS.size()});
std::copy(GPT2_PROMPT_TOKEN_IDS.begin(), GPT2_PROMPT_TOKEN_IDS.end(), input_ids_tensor.data<int64_t>());
lm.set_tensor("input_ids", input_ids_tensor);
lm.set_tensor("attention_mask", ov::Tensor(ov::element::Type_t::i64, {1, GPT2_PROMPT_TOKEN_IDS.size()}));
ov::Tensor position_ids = lm.get_tensor("position_ids");
position_ids.set_shape(input_ids_tensor.get_shape());
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), 0);

constexpr size_t BATCH_SIZE = 1;
lm.get_tensor("beam_idx").set_shape({BATCH_SIZE});
lm.get_tensor("beam_idx").data<int32_t>()[0] = 0;
fill_unused_inputs(lm, input_ids_tensor.get_shape());

lm.infer();

size_t vocab_size = lm.get_tensor("logits").get_shape().back();
float* logits = lm.get_tensor("logits").data<float>() + (input_ids_tensor.get_size() - 1) * vocab_size;
int64_t out_token = std::max_element(logits, logits + vocab_size) - logits;

constexpr size_t BATCH_SIZE = 1;
lm.get_tensor("input_ids").set_shape({BATCH_SIZE, 1});
position_ids.set_shape({BATCH_SIZE, 1});

23 changes: 23 additions & 0 deletions modules/llama_cpp_plugin/tests/functional/CMakeLists.txt
@@ -0,0 +1,23 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

set(TARGET_NAME llama_cpp_func_tests)

ov_add_test_target(
NAME ${TARGET_NAME}
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDENCIES
llama_cpp_plugin
LINK_LIBRARIES
openvino::runtime::dev
common_test_utils
gtest
llama_cpp_test_common
INCLUDES
"${LlamaCppPlugin_SOURCE_DIR}/include"
"${LlamaCppPlugin_SOURCE_DIR}/tests/common/include"
ADD_CLANG_FORMAT
LABELS
OV UNIT LLAMA_CPP
)
