Revert "Add kwargs to init() in all Python pipelines." (openvinotoolk…
ilya-lavrenov authored Oct 21, 2024
1 parent 2756b29 commit 7bca9d2
Showing 9 changed files with 76 additions and 105 deletions.
@@ -61,7 +61,7 @@ def main():
# Cache compiled models on disk for GPU to save time on the
# next run. It's not beneficial for CPU.
enable_compile_cache["CACHE_DIR"] = "vlm_cache"
pipe = openvino_genai.VLMPipeline(args.model_dir, device, **enable_compile_cache)
pipe = openvino_genai.VLMPipeline(args.model_dir, device, enable_compile_cache)

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 100
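
For context, the hunk above (apparently from one of the VLM Python samples) reverts to passing the compile properties as a plain dict argument instead of unpacked keyword arguments. A minimal sketch of the restored calling convention, with a placeholder model directory:

import openvino_genai

model_dir = "./miniCPM-V-2_6"  # placeholder: folder with an exported VLM model
device = "GPU"

enable_compile_cache = {}
if "GPU" in device:
    # Cache compiled models on disk for GPU to save time on the next run.
    enable_compile_cache["CACHE_DIR"] = "vlm_cache"

# Before the revert: openvino_genai.VLMPipeline(model_dir, device, **enable_compile_cache)
# After the revert, the properties dict is passed positionally:
pipe = openvino_genai.VLMPipeline(model_dir, device, enable_compile_cache)

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 100
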
57 changes: 18 additions & 39 deletions src/python/py_generate_pipeline.cpp
@@ -402,18 +402,17 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
m.doc() = "Pybind11 binding for LLM Pipeline";

py::class_<LLMPipeline>(m, "LLMPipeline", "This class is used for generation with LLMs")
// init(model_path, tokenizer, device, config, kwargs) should be defined before init(model_path, device, config, kwargs)
// to prevent tokenizer treated as kwargs argument
.def(py::init([](
const std::string& model_path,
const std::string& device,
const py::kwargs& kwargs
const std::map<std::string, py::object>& config
) {
ScopedVar env_manager(utils::ov_tokenizers_module_path());
return std::make_unique<LLMPipeline>(model_path, device, utils::kwargs_to_any_map(kwargs));
return std::make_unique<LLMPipeline>(model_path, device, utils::properties_to_any_map(config));
}),
py::arg("model_path"), "folder with openvino_model.xml and openvino_tokenizer[detokenizer].xml files",
py::arg("device") = "CPU", "device on which inference will be done",
py::arg("config") = ov::AnyMap({}), "openvino.properties map",
R"(
LLMPipeline class constructor.
model_path (str): Path to the model file.
@@ -422,60 +421,40 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
)")

.def(py::init([](
const std::string& model_path,
const Tokenizer& tokenizer,
const std::string& model_path,
const std::string& device,
const std::map<std::string, py::object>& config,
const py::kwargs& kwargs
) {
ScopedVar env_manager(utils::ov_tokenizers_module_path());
auto kwargs_properies = utils::kwargs_to_any_map(kwargs);
if (config.size()) {
PyErr_WarnEx(PyExc_DeprecationWarning,
"'config' parameters is deprecated, please use kwargs to pass config properties instead.",
1);
auto properies = utils::properties_to_any_map(config);
kwargs_properies.insert(properies.begin(), properies.end());
}
return std::make_unique<LLMPipeline>(model_path, tokenizer, device, kwargs_properies);
return std::make_unique<LLMPipeline>(model_path, device, utils::kwargs_to_any_map(kwargs));
}),
py::arg("model_path"),
py::arg("tokenizer"),
py::arg("device") = "CPU",
py::arg("config") = ov::AnyMap({}), "openvino.properties map",
py::arg("model_path"), "folder with openvino_model.xml and openvino_tokenizer[detokenizer].xml files",
py::arg("device") = "CPU", "device on which inference will be done",
R"(
LLMPipeline class constructor for manualy created openvino_genai.Tokenizer.
LLMPipeline class constructor.
model_path (str): Path to the model file.
tokenizer (openvino_genai.Tokenizer): tokenizer object.
device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
kwargs: Device properties.
Add {"scheduler_config": ov_genai.SchedulerConfig} to config properties to create continuous batching pipeline.
)")

.def(py::init([](
const std::string& model_path,
const std::string& model_path,
const Tokenizer& tokenizer,
const std::string& device,
const std::map<std::string, py::object>& config,
const py::kwargs& kwargs
const std::map<std::string, py::object>& config
) {
ScopedVar env_manager(utils::ov_tokenizers_module_path());
auto kwargs_properies = utils::kwargs_to_any_map(kwargs);
if (config.size()) {
PyErr_WarnEx(PyExc_DeprecationWarning,
"'config' parameters is deprecated, please use kwargs to pass config properties instead.",
1);
auto properies = utils::properties_to_any_map(config);
kwargs_properies.insert(properies.begin(), properies.end());
}
return std::make_unique<LLMPipeline>(model_path, device, kwargs_properies);
return std::make_unique<LLMPipeline>(model_path, tokenizer, device, utils::properties_to_any_map(config));
}),
py::arg("model_path"), "folder with openvino_model.xml and openvino_tokenizer[detokenizer].xml files",
py::arg("device") = "CPU", "device on which inference will be done",
py::arg("model_path"),
py::arg("tokenizer"),
py::arg("device") = "CPU",
py::arg("config") = ov::AnyMap({}), "openvino.properties map",
R"(
LLMPipeline class constructor.
LLMPipeline class constructor for manualy created openvino_genai.Tokenizer.
model_path (str): Path to the model file.
tokenizer (openvino_genai.Tokenizer): tokenizer object.
device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
kwargs: Device properties.
Add {"scheduler_config": ov_genai.SchedulerConfig} to config properties to create continuous batching pipeline.
)")

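
A sketch of how the restored LLMPipeline constructors are invoked from Python after this revert: compile and runtime properties travel through the config dict (keyword or positional) rather than **kwargs. The model path and property values below are placeholders:

import openvino_genai as ov_genai

# Placeholder: folder with openvino_model.xml and openvino_tokenizer/detokenizer.xml
model_path = "./TinyLlama-1.1B-Chat-v1.0"

# Plain constructor: properties are passed via 'config' again.
pipe = ov_genai.LLMPipeline(model_path, "CPU", config={"ENABLE_MMAP": False})
print(pipe.generate("What is OpenVINO?", max_new_tokens=50))

# Constructor taking a manually created Tokenizer; adding a SchedulerConfig
# under "scheduler_config" switches to the continuous batching pipeline.
scheduler_config = ov_genai.SchedulerConfig()
scheduler_config.cache_size = 1
tokenizer = ov_genai.Tokenizer(model_path)
cb_pipe = ov_genai.LLMPipeline(model_path, tokenizer, "CPU",
                               config={"scheduler_config": scheduler_config})
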
2 changes: 2 additions & 0 deletions src/python/py_text2image_pipeline.cpp
@@ -144,6 +144,8 @@ ov::AnyMap text2image_kwargs_to_any_map(const py::kwargs& kwargs, bool allow_com
"Use help(openvino_genai.Text2ImagePipeline.generate) to get list of acceptable parameters."));
}
}


}
return params;
}
67 changes: 29 additions & 38 deletions src/python/py_vlm_pipeline.cpp
@@ -55,76 +55,67 @@ auto vlm_generate_kwargs_docstring = R"(
:rtype: DecodedResults
)";

py::object call_vlm_generate(
ov::genai::VLMPipeline& pipe,
const std::string& prompt,
const std::vector<ov::Tensor>& images,
const ov::genai::GenerationConfig& generation_config,
const utils::PyBindStreamerVariant& py_streamer,
const py::kwargs& kwargs
) {
auto updated_config = *ov::genai::pybind::utils::update_config_from_kwargs(generation_config, kwargs);
ov::genai::StreamerVariant streamer = ov::genai::pybind::utils::pystreamer_to_streamer(py_streamer);

return py::cast(pipe.generate(prompt, images, updated_config, streamer));
}

ov::AnyMap vlm_kwargs_to_any_map(const py::kwargs& kwargs, bool allow_compile_properties=true) {
py::object call_vlm_generate(
ov::genai::VLMPipeline& pipe,
const std::string& prompt,
const py::kwargs& kwargs
) {
ov::AnyMap params = {};

for (const auto& item : kwargs) {
std::string key = py::cast<std::string>(item.first);
py::object value = py::cast<py::object>(item.second);

if (key == "images") {
params.insert({ov::genai::images(std::move(py::cast<std::vector<ov::Tensor>>(value)))});
params.insert({ov::genai::images(std::move(py::cast<std::vector<ov::Tensor>>(item.second)))});
} else if (key == "image") {
params.insert({ov::genai::image(std::move(py::cast<ov::Tensor>(value)))});
params.insert({ov::genai::image(std::move(py::cast<ov::Tensor>(item.second)))});
} else if (key == "generation_config") {
params.insert({ov::genai::generation_config(std::move(py::cast<ov::genai::GenerationConfig>(value)))});
params.insert({ov::genai::generation_config(std::move(py::cast<ov::genai::GenerationConfig>(item.second)))});
} else if (key == "streamer") {
auto py_streamer = py::cast<utils::PyBindStreamerVariant>(value);
params.insert({ov::genai::streamer(std::move(ov::genai::pybind::utils::pystreamer_to_streamer(py_streamer)))});
}
else {
if (allow_compile_properties) {
// convert arbitrary objects to ov::Any
// not supported properties are not checked, as these properties are passed to compile(), which will throw exception in case of unsupported property
if (utils::py_object_is_any_map(value)) {
auto map = utils::py_object_to_any_map(value);
params.insert(map.begin(), map.end());
} else {
params[key] = utils::py_object_to_any(value);
}
}
else {
// generate doesn't run compile(), so only VLMPipeline specific properties are allowed
throw(std::invalid_argument("'" + key + "' is unexpected parameter name. "

} else {
throw(std::invalid_argument("'" + key + "' is unexpected parameter name. "
"Use help(openvino_genai.VLMPipeline.generate) to get list of acceptable parameters."));
}
}
}
return params;
}

py::object call_vlm_generate(
ov::genai::VLMPipeline& pipe,
const std::string& prompt,
const std::vector<ov::Tensor>& images,
const ov::genai::GenerationConfig& generation_config,
const utils::PyBindStreamerVariant& py_streamer,
const py::kwargs& kwargs
) {
auto updated_config = *ov::genai::pybind::utils::update_config_from_kwargs(generation_config, kwargs);
ov::genai::StreamerVariant streamer = ov::genai::pybind::utils::pystreamer_to_streamer(py_streamer);

return py::cast(pipe.generate(prompt, images, updated_config, streamer));
return py::cast(pipe.generate(prompt, params));
}

void init_vlm_pipeline(py::module_& m) {
py::class_<ov::genai::VLMPipeline>(m, "VLMPipeline", "This class is used for generation with VLMs")
.def(py::init([](
const std::string& model_path,
const std::string& device,
const py::kwargs& kwargs
const std::map<std::string, py::object>& config
) {
ScopedVar env_manager(utils::ov_tokenizers_module_path());
return std::make_unique<ov::genai::VLMPipeline>(model_path, device, vlm_kwargs_to_any_map(kwargs, true));
return std::make_unique<ov::genai::VLMPipeline>(model_path, device, utils::properties_to_any_map(config));
}),
py::arg("model_path"), "folder with exported model files",
py::arg("device") = "CPU", "device on which inference will be done",
py::arg("config") = ov::AnyMap({}), "openvino.properties map"
R"(
VLMPipeline class constructor.
model_path (str): Path to the folder with exported model files.
device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
kwargs: Device properties
)")

.def("start_chat", &ov::genai::VLMPipeline::start_chat, py::arg("system_message") = "")
@@ -155,7 +146,7 @@ void init_vlm_pipeline(py::module_& m) {
const std::string& prompt,
const py::kwargs& kwargs
) {
return py::cast(pipe.generate(prompt, vlm_kwargs_to_any_map(kwargs, false)));
return call_vlm_generate(pipe, prompt, kwargs);
},
py::arg("prompt"), "Input string",
(vlm_generate_kwargs_docstring + std::string(" \n ")).c_str()
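
The restored VLMPipeline.generate keeps explicit handling of its keyword arguments (image/images, generation_config, streamer) and rejects anything else rather than forwarding unknown keys as compile properties. A rough usage sketch; the model directory and the dummy image are placeholders:

import numpy as np
import openvino
import openvino_genai

pipe = openvino_genai.VLMPipeline("./miniCPM-V-2_6", "CPU")  # placeholder model dir

# A dummy 448x448 RGB frame wrapped in an openvino.Tensor; real code would load a picture.
image = openvino.Tensor(np.zeros((1, 448, 448, 3), dtype=np.uint8))

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 100

def streamer(subword: str) -> bool:
    print(subword, end="", flush=True)
    return False  # False means "continue generation"

# Only the VLM-specific keywords are accepted; any other key raises.
result = pipe.generate(
    "Describe the picture.",
    image=image,
    generation_config=config,
    streamer=streamer,
)
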
37 changes: 18 additions & 19 deletions src/python/py_whisper_pipeline.cpp
@@ -255,44 +255,43 @@ void init_whisper_pipeline(py::module_& m) {
.def_readonly("chunks", &WhisperDecodedResults::chunks);

py::class_<WhisperPipeline>(m, "WhisperPipeline")
// init(model_path, tokenizer, device, kwargs) should be defined before init(model_path, device, kwargs)
// to prevent tokenizer treated as kwargs argument
.def(py::init([](const std::string& model_path,
const Tokenizer& tokenizer,
const std::string& device,
const py::kwargs& kwargs) {
return std::make_unique<WhisperPipeline>(model_path,
tokenizer,
device,
utils::kwargs_to_any_map(kwargs));
const std::map<std::string, py::object>& config) {
ScopedVar env_manager(utils::ov_tokenizers_module_path());
return std::make_unique<WhisperPipeline>(model_path, device, utils::properties_to_any_map(config));
}),
py::arg("model_path"),
py::arg("tokenizer"),
"folder with openvino_model.xml and openvino_tokenizer[detokenizer].xml files",
py::arg("device") = "CPU",
"device on which inference will be done",
py::arg("config") = ov::AnyMap({}),
"openvino.properties map",
R"(
WhisperPipeline class constructor for manualy created openvino_genai.Tokenizer.
WhisperPipeline class constructor.
model_path (str): Path to the model file.
tokenizer (openvino_genai.Tokenizer): tokenizer object.
device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
kwargs: Device properties.
)")

.def(py::init([](const std::string& model_path,
const Tokenizer& tokenizer,
const std::string& device,
const py::kwargs& kwargs) {
ScopedVar env_manager(utils::ov_tokenizers_module_path());
return std::make_unique<WhisperPipeline>(model_path, device, utils::kwargs_to_any_map(kwargs));
const std::map<std::string, py::object>& config) {
return std::make_unique<WhisperPipeline>(model_path,
tokenizer,
device,
utils::properties_to_any_map(config));
}),
py::arg("model_path"),
"folder with openvino_model.xml and openvino_tokenizer[detokenizer].xml files",
py::arg("tokenizer"),
py::arg("device") = "CPU",
"device on which inference will be done",
py::arg("config") = ov::AnyMap({}),
"openvino.properties map",
R"(
WhisperPipeline class constructor.
WhisperPipeline class constructor for manualy created openvino_genai.Tokenizer.
model_path (str): Path to the model file.
tokenizer (openvino_genai.Tokenizer): tokenizer object.
device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
kwargs: Device properties.
)")

.def(
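
Correspondingly, a sketch of the restored WhisperPipeline constructors; the model folder is a placeholder and the raw speech can come from any loader that yields 16 kHz mono float samples:

import openvino_genai as ov_genai

model_path = "./whisper-tiny"  # placeholder: exported Whisper model folder

# Plain constructor: device properties go back into 'config'.
pipe = ov_genai.WhisperPipeline(model_path, device="CPU", config={"ENABLE_MMAP": False})

# Constructor with a manually created tokenizer remains available.
tokenizer = ov_genai.Tokenizer(model_path)
pipe_with_tokenizer = ov_genai.WhisperPipeline(model_path, tokenizer=tokenizer, device="CPU")

raw_speech = [0.0] * 16000  # placeholder: one second of silence at 16 kHz
result = pipe.generate(raw_speech)
print(result.texts[0])
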
4 changes: 2 additions & 2 deletions tests/python_tests/ov_genai_test_utils.py
@@ -197,7 +197,7 @@ def read_model(params, **tokenizer_kwargs):
path,
tokenizer,
opt_model,
ov_genai.LLMPipeline(str(path), device='CPU', **{"ENABLE_MMAP": False}),
ov_genai.LLMPipeline(str(path), device='CPU', config={"ENABLE_MMAP": False}),
)


@@ -252,4 +252,4 @@ def load_pipe(configs: List[Tuple], temp_path):
def get_continuous_batching(path):
scheduler_config = ov_genai.SchedulerConfig()
scheduler_config.cache_size = 1
return ov_genai.LLMPipeline(str(path), ov_genai.Tokenizer(str(path)), device='CPU', **{"scheduler_config": scheduler_config})
return ov_genai.LLMPipeline(str(path), ov_genai.Tokenizer(str(path)), device='CPU', config={"scheduler_config": scheduler_config})
2 changes: 1 addition & 1 deletion tests/python_tests/test_chat_generate_api.py
@@ -118,7 +118,7 @@ def test_chat_compare_statefull_vs_text_history(model_descr, generation_config:
# HF in chat scenario does not add special tokens, but openvino tokenizer by default is converted with add_special_tokens=True.
# Need to regenerate openvino_tokenizer/detokenizer.
model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'), add_special_tokens=False)
pipe_with_kv_cache = ov_genai.LLMPipeline(str(path), device, **{"ENABLE_MMAP": False})
pipe_with_kv_cache = ov_genai.LLMPipeline(str(path), device, config={"ENABLE_MMAP": False})

pipe_with_kv_cache.start_chat()
for question in quenstions:
8 changes: 4 additions & 4 deletions tests/python_tests/test_whisper_generate_api.py
@@ -68,7 +68,7 @@ def read_whisper_model(params, **tokenizer_kwargs):
path,
opt_pipe,
ov_genai.WhisperPipeline(
str(path), device="CPU", **{"ENABLE_MMAP": False}
str(path), device="CPU", config={"ENABLE_MMAP": False}
),
)

@@ -201,7 +201,7 @@ def test_whisper_constructors(model_descr, test_sample):
expected = opt_pipe(test_sample)["text"]

genai_result = ov_genai.WhisperPipeline(
str(path), device="CPU", **{"ENABLE_MMAP": False}
str(path), device="CPU", config={"ENABLE_MMAP": False}
).generate(test_sample)

assert genai_result.texts[0] == expected
@@ -213,7 +213,7 @@ def test_whisper_constructors(model_descr, test_sample):
tokenizer = ov_genai.Tokenizer(str(path))

genai_result = ov_genai.WhisperPipeline(
str(path), tokenizer=tokenizer, device="CPU", **{"ENABLE_MMAP": False}
str(path), tokenizer=tokenizer, device="CPU", config={"ENABLE_MMAP": False}
).generate(test_sample)

assert genai_result.texts[0] == expected
@@ -237,7 +237,7 @@ def test_max_new_tokens(model_descr, test_sample):
tokenizer = ov_genai.Tokenizer(str(path))

genai_pipeline = ov_genai.WhisperPipeline(
str(path), tokenizer=tokenizer, device="CPU", **{"ENABLE_MMAP": False}
str(path), tokenizer=tokenizer, device="CPU", config={"ENABLE_MMAP": False}
)
config = genai_pipeline.get_generation_config()
config.max_new_tokens = 30
2 changes: 1 addition & 1 deletion tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -201,7 +201,7 @@ def create_genai_text_gen_model(model_path, device, ov_config, **kwargs):
setattr(scheduler_config, param, value)
ov_config["scheduler_config"] = scheduler_config
start = time.perf_counter()
llm_pipe = openvino_genai.LLMPipeline(str(model_path), device.upper(), **ov_config)
llm_pipe = openvino_genai.LLMPipeline(str(model_path), device.upper(), ov_config)
end = time.perf_counter()
log.info(f'Pipeline initialization time: {end - start:.2f}s')

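
In the benchmark helper, the assembled ov_config dict (optionally carrying a "scheduler_config" entry) again becomes a single positional properties argument instead of being unpacked as keyword arguments. A rough sketch with illustrative names and values:

import openvino_genai

model_path = "./llama-2-7b-chat"  # placeholder model folder
ov_config = {}                    # compile/runtime properties collected from the benchmark CLI

# Optional continuous batching: attach a SchedulerConfig under "scheduler_config".
scheduler_config = openvino_genai.SchedulerConfig()
scheduler_config.cache_size = 1
ov_config["scheduler_config"] = scheduler_config

# Before the revert: openvino_genai.LLMPipeline(model_path, "CPU", **ov_config)
# After the revert, the dict is passed positionally:
llm_pipe = openvino_genai.LLMPipeline(model_path, "CPU", ov_config)
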
