Commit

vsncnn/vs_ncnn.cpp: add params `use_ncnn_network_format` and `ncnn_shape_hint`

#39
WolframRhodium committed Apr 12, 2023
1 parent b4f9a70 commit 7d967b9
Showing 1 changed file with 141 additions and 47 deletions.

vsncnn/vs_ncnn.cpp
@@ -35,6 +35,18 @@ extern std::variant<std::string, ONNX_NAMESPACE::ModelProto> loadONNX(
) noexcept;


#ifdef _WIN32
#include <locale>
#include <codecvt>
static inline std::wstring translateName(const char *name) noexcept {
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(name);
}
#else
#define translateName(n) (n)
#endif
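// on Windows, std::ifstream interprets narrow (char) paths in the ANSI code page,
// so UTF-8 file names are converted to UTF-16 before opening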


static const VSPlugin * myself = nullptr;


@@ -497,6 +509,31 @@ static void VS_CC vsNcnnCreate(
path_is_serialization = false;
}

bool use_ncnn_network_format = !!vsapi->propGetInt(in, "use_ncnn_network_format", 0, &error);
if (error) {
use_ncnn_network_format = false;
}
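// a model passed as an in-memory serialization has no .param/.bin pair on disk to read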
if (use_ncnn_network_format && path_is_serialization) {
return set_error(
"\"use_ncnn_network_format\" and \"path_is_serialization\" "
"should not be enabled at the same time"
);
}

// ncnn: select the Vulkan device and set net options before any network data is loaded
if (auto device = ncnn::get_gpu_device(device_id); device != nullptr) {
d->device = device;
} else {
return set_error("get_gpu_device failed");
}

d->net.opt.num_threads = 1;
d->net.opt.use_vulkan_compute = true;
d->net.opt.use_fp16_packed = d->fp16;
d->net.opt.use_fp16_storage = d->fp16;
d->net.opt.use_int8_storage = false;
d->net.set_vulkan_device(d->device);

std::string_view path_view;
std::string path;
if (path_is_serialization) {
@@ -518,22 +555,109 @@ static void VS_CC vsNcnnCreate(
path_view = path;
}

if (use_ncnn_network_format) {
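// a native ncnn .param/.bin pair carries no tensor shape metadata,
// so the input/output tile shapes must be supplied explicitly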
if (vsapi->propNumElements(in, "ncnn_shape_hint") != 6) {
return set_error("\"ncnn_shape_hint\" must be specified as [in_c, in_h, in_w, out_c, out_h, out_w]");
}

auto ncnn_shape_hint = vsapi->propGetIntArray(in, "ncnn_shape_hint", nullptr);
d->in_tile_c = int64ToIntS(ncnn_shape_hint[0]);
d->in_tile_h = int64ToIntS(ncnn_shape_hint[1]);
d->in_tile_w = int64ToIntS(ncnn_shape_hint[2]);
d->out_tile_c = int64ToIntS(ncnn_shape_hint[3]);
d->out_tile_h = int64ToIntS(ncnn_shape_hint[4]);
d->out_tile_w = int64ToIntS(ncnn_shape_hint[5]);

auto dot_index = static_cast<int>(path_view.size() - 1);
while (dot_index >= 0 && path_view[dot_index] != '.') {
dot_index--;
}
if (dot_index < 0) {
return set_error("invalid \"network_path\"");
}

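// derive sibling file names from network_path: "<model>.param" (graph) and "<model>.bin" (weights)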
auto temp = std::string{path_view.substr(0, dot_index + 1)} + "param";
char * ncnn_bin;

{
std::ifstream param_stream(
translateName(temp.c_str()),
std::ios::binary | std::ios::ate
);

if (!param_stream.good()) {
return set_error("open param failed");
}

auto size = param_stream.tellg();
// load_param_mem() parses a textual param and expects a null-terminated buffer,
// so reserve one extra byte and terminate after reading
ncnn_bin = reinterpret_cast<char *>(vs_aligned_malloc(static_cast<size_t>(size) + 1, sizeof(void *)));
param_stream.seekg(0);
param_stream.read(ncnn_bin, size);
ncnn_bin[size] = '\0';
}

if (d->net.load_param_mem(ncnn_bin) != 0) {
vs_aligned_free(ncnn_bin);
return set_error("load param failed");
}

vs_aligned_free(ncnn_bin);

temp = std::string{path_view.substr(0, dot_index + 1)} + "bin";

{
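// unlike the textual param, the weights file is a raw binary blob consumed as-is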
std::ifstream bin_stream(
translateName(temp.c_str()),
std::ios::binary | std::ios::ate
);

if (!bin_stream.good()) {
return set_error("open weights failed");
}

auto size = bin_stream.tellg();
ncnn_bin = reinterpret_cast<char *>(vs_aligned_malloc(size, sizeof(void *)));
bin_stream.seekg(0);
bin_stream.read(ncnn_bin, size);

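// note: load_model() returns the number of bytes consumed rather than an error status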
d->net.load_model(reinterpret_cast<const unsigned char *>(ncnn_bin));
}

vs_aligned_free(ncnn_bin);
} else {
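// ONNX path: convert the model with onnx2ncnn and take the tile shapes from the graph's input/output tensors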
auto result = loadONNX(path_view, tile_w, tile_h, path_is_serialization);
if (std::holds_alternative<std::string>(result)) {
return set_error(std::get<std::string>(result));
}

auto onnx_model = std::move(std::get<ONNX_NAMESPACE::ModelProto>(result));
{
const auto & input_shape = onnx_model.graph().input(0).type().tensor_type().shape();
d->in_tile_c = int64ToIntS(input_shape.dim(1).dim_value());
d->in_tile_h = int64ToIntS(input_shape.dim(2).dim_value());
d->in_tile_w = int64ToIntS(input_shape.dim(3).dim_value());

const auto & output_shape = onnx_model.graph().output(0).type().tensor_type().shape();
d->out_tile_c = int64ToIntS(output_shape.dim(1).dim_value());
d->out_tile_h = int64ToIntS(output_shape.dim(2).dim_value());
d->out_tile_w = int64ToIntS(output_shape.dim(3).dim_value());
}

auto ncnn_result = onnx2ncnn(onnx_model);
if (!ncnn_result.has_value()) {
return set_error("onnx2ncnn failed");
}

const auto & [ncnn_param, ncnn_model_bin] = ncnn_result.value();

if (d->net.load_param_mem(ncnn_param) != 0) {
vs_aligned_free(ncnn_param);
vs_aligned_free(ncnn_model_bin);
return set_error("load param failed");
}
vs_aligned_free(ncnn_param);
// TODO: load_model() returns the number of bytes read successfully, so failures are not detected here
d->net.load_model(ncnn_model_bin);
vs_aligned_free(ncnn_model_bin);
}

d->out_vi = std::make_unique<VSVideoInfo>(*in_vis.front()); // mutable
@@ -546,38 +670,6 @@ static void VS_CC vsNcnnCreate(
d->out_vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core);
}

d->input_index = d->net.input_indexes().front();
d->output_index = d->net.output_indexes().front();

@@ -634,6 +726,8 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
"builtindir:data:opt;"
"fp16:int:opt;"
"path_is_serialization:int:opt;"
"use_ncnn_network_format:int:opt;"
"ncnn_shape_hint:int[]:opt;"
, vsNcnnCreate,
nullptr,
plugin