Skip to content

Commit

Permalink
[PT FE] Add support for GFPGAN model (openvinotoolkit#21371)
Browse files Browse the repository at this point in the history
* [PT FE] Add support for GFPGAN model

* Remove logs

* Fix codestyle

* Add support for aten::normal
  • Loading branch information
mvafin authored Nov 29, 2023
1 parent cb5377f commit 007b6fd
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 7 deletions.
102 changes: 95 additions & 7 deletions src/frontends/pytorch/src/op/rand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ namespace op {
using namespace ov::op;

namespace {
OutputVector make_random_normal(const NodeContext& context, Output<Node> sizes, element::Type target_type) {
OutputVector make_random_normal(const NodeContext& context,
const Output<Node>& sizes,
element::Type target_type,
const Output<Node>& scale_const,
const Output<Node>& mean_const) {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<uint64_t> distrib(0, 9999);
Expand Down Expand Up @@ -57,8 +61,6 @@ OutputVector make_random_normal(const NodeContext& context, Output<Node> sizes,
auto multiply_two_pi_uniform_2 = context.mark_node(std::make_shared<v1::Multiply>(multiply_two_pi, uniform_2));
auto cos = context.mark_node(std::make_shared<v0::Cos>(multiply_two_pi_uniform_2));

auto scale_const = context.mark_node(v0::Constant::create(target_type, Shape{1}, {1}));
auto mean_const = context.mark_node(v0::Constant::create(target_type, Shape{1}, {0}));
auto sqrt_x_cos = context.mark_node(std::make_shared<v1::Multiply>(sqrt, cos));
auto product = context.mark_node(std::make_shared<v1::Multiply>(scale_const, sqrt_x_cos));
auto sum = context.mark_node(std::make_shared<v1::Add>(product, mean_const));
Expand Down Expand Up @@ -180,7 +182,9 @@ OutputVector translate_randn(const NodeContext& context) {
// aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
// aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
if (context.get_input_size() == 2 || context.get_input_size() == 3) {
auto res = make_random_normal(context, sizes, dtype);
auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
auto res = make_random_normal(context, sizes, dtype, scale, mean);
context.mutate_input(out_id, res[0]);
return res;
}
Expand Down Expand Up @@ -210,7 +214,9 @@ OutputVector translate_randn(const NodeContext& context) {
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
}
}
auto res = make_random_normal(context, sizes, dtype);
auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
auto res = make_random_normal(context, sizes, dtype, scale, mean);
if (!dtype_applied) {
res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], convert_like_out));
}
Expand All @@ -226,7 +232,9 @@ OutputVector translate_randn_like(const NodeContext& context) {
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
auto dtype = element::f32;
if (context.get_input_size() == 3) {
auto res = make_random_normal(context, sizes, dtype);
auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
auto res = make_random_normal(context, sizes, dtype, scale, mean);
context.mutate_input(2, res[0]);
return res;
}
Expand All @@ -246,7 +254,9 @@ OutputVector translate_randn_like(const NodeContext& context) {
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
}
}
auto res = make_random_normal(context, sizes, dtype);
auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
auto res = make_random_normal(context, sizes, dtype, scale, mean);
if (!dtype_applied) {
res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], convert_like_out));
}
Expand Down Expand Up @@ -283,6 +293,84 @@ OutputVector translate_randint(const NodeContext& context) {
return {res};
};

OutputVector translate_normal_(const NodeContext& context) {
// aten::normal_(Tensor(a!) self, float mean=0., float std=1., *, Generator? generator=None) -> Tensor(a!)
num_inputs_check(context, 3, 4);
auto inp_tensor = context.get_input(0);
auto mean = context.get_input(1);
auto std = context.get_input(2);
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
auto dtype = element::f32;
auto res = make_random_normal(context, sizes, dtype, std, mean);
res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], inp_tensor));
context.mutate_input(0, res[0]);
return res;
}

OutputVector translate_normal(const NodeContext& context) {
num_inputs_check(context, 2, 8);
auto mean = context.get_input(0);
auto std = context.get_input(1);
auto dtype = element::f32;
if (context.get_input_size() == 3 || context.get_input_size() == 4) {
// aten::normal.Tensor_float(Tensor mean, float std=1., *, Generator? generator=None) -> Tensor
// aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
// aten::normal.Tensor_float_out(Tensor mean, float std=1., *, Generator? generator=None, Tensor(a!) out) ->
// Tensor(a!)
// aten::normal.Tensor_float_out(Tensor mean, float std=1., *, Generator? generator=None, Tensor(a!)
// out) -> Tensor(a!)
// aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None,
// Tensor(a!) out) -> Tensor(a!)
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(mean, element::i32));
auto res = make_random_normal(context, sizes, dtype, std, mean);
if (!context.input_is_none(3)) {
// out
auto out = context.get_input(3);
res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], out));
context.mutate_input(3, res[0]);
}
return res;
} else if (context.get_input_size() == 5) {
// aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!)
// out) -> Tensor(a!)
auto sizes = context.get_input(2);
auto res = make_random_normal(context, sizes, dtype, std, mean);
if (!context.input_is_none(4)) {
// out
auto out = context.get_input(4);
res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], out));
context.mutate_input(4, res[0]);
}
return res;
} else if (context.get_input_size() == 8) {
// aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType?
// dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
auto sizes = context.get_input(2);
Output<Node> convert_like_out;
bool dtype_applied = true;
if (!context.input_is_none(4)) {
if (std::dynamic_pointer_cast<v0::Constant>(
context.get_input_from_visible_context(3).get_node_shared_ptr())) {
dtype = convert_dtype(context.const_input<int64_t>(4));
} else if (const auto& fw_node = cast_fw_node(context.get_input(3).get_node_shared_ptr(), "prim::dtype")) {
convert_like_out = fw_node->input_value(0);
dtype_applied = false;
} else {
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
}
}
auto res = make_random_normal(context, sizes, dtype, std, mean);
if (!dtype_applied) {
res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], convert_like_out));
}
return res;
} else {
FRONT_END_OP_CONVERSION_CHECK(false,
"Unsupported number of inputs to aten::normal operation: ",
context.get_input_size());
}
}

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ OP_CONVERTER(translate_new_zeros);
OP_CONVERTER(translate_nms);
OP_CONVERTER(translate_nonzero);
OP_CONVERTER(translate_norm);
OP_CONVERTER(translate_normal);
OP_CONVERTER(translate_normal_);
OP_CONVERTER(translate_not);
OP_CONVERTER(translate_numel);
OP_CONVERTER(translate_one_hot);
Expand Down Expand Up @@ -438,6 +440,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::new_zeros", op::translate_new_zeros},
{"aten::nonzero", op::translate_nonzero},
{"aten::norm", op::translate_norm},
{"aten::normal", op::translate_normal},
{"aten::normal_", op::translate_normal_},
{"aten::numel", op::translate_numel},
{"aten::numpy_T", op::translate_t},
{"aten::one_hot", op::translate_one_hot},
Expand Down
90 changes: 90 additions & 0 deletions tests/layer_tests/pytorch_tests/test_rand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestInplaceNormal(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)

def create_model(self, mean, std):
class aten_normal(torch.nn.Module):
def __init__(self, mean, std):
super(aten_normal, self).__init__()
self.mean = mean
self.std = std

def forward(self, x):
x = x.to(torch.float32)
return x.normal_(mean=self.mean, std=self.std), x

return aten_normal(mean, std), None, "aten::normal_"

@pytest.mark.parametrize("mean,std", [(0., 1.), (5., 20.)])
@pytest.mark.nightly
@pytest.mark.precommit
def test_inplace_normal(self, mean, std, ie_device, precision, ir_version):
self._test(*self.create_model(mean, std),
ie_device, precision, ir_version, custom_eps=1e30)


class TestNormal(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
if isinstance(self.inputs, list):
return (np.random.randn(*self.inputs).astype(np.float32),)
return self.inputs

class aten_normal1(torch.nn.Module):
def forward(self, mean, std):
return torch.normal(mean, std)

class aten_normal2(torch.nn.Module):
def forward(self, mean, std):
x = torch.empty_like(mean, dtype=torch.float32)
return torch.normal(mean, std, out=x), x

class aten_normal3(torch.nn.Module):
def forward(self, mean):
return torch.normal(mean)

class aten_normal4(torch.nn.Module):
def forward(self, mean):
x = torch.empty_like(mean, dtype=torch.float32)
return torch.normal(mean, out=x), x

class aten_normal5(torch.nn.Module):
def forward(self, mean):
x = torch.empty_like(mean, dtype=torch.float32)
return torch.normal(mean, 2., out=x), x

class aten_normal6(torch.nn.Module):
def forward(self, x):
x = x.to(torch.float32)
return torch.normal(0., 1., x.shape)

class aten_normal7(torch.nn.Module):
def forward(self, x):
x = x.to(torch.float32)
return torch.normal(0., 1., x.shape, out=x), x

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("model,inputs", [
(aten_normal1(), (torch.arange(1., 11.).numpy(), torch.arange(1, 0, -0.1).numpy())),
(aten_normal2(), (torch.arange(1., 11.).numpy(), torch.arange(1, 0, -0.1).numpy())),
(aten_normal3(), (torch.arange(1., 11.).numpy(),)),
(aten_normal4(), (torch.arange(1., 11.).numpy(),)),
(aten_normal5(), (torch.arange(1., 11.).numpy(),)),
(aten_normal6(), [1, 3, 224, 224]),
(aten_normal7(), [1, 3, 224, 224]),
])
def test_inplace_normal(self, model, inputs, ie_device, precision, ir_version):
self.inputs = inputs
self._test(model, None, "aten::normal",
ie_device, precision, ir_version, custom_eps=1e30)
2 changes: 2 additions & 0 deletions tests/model_hub_tests/torch_tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ protobuf
soundfile
pandas
super-image
basicsr
facexlib
81 changes: 81 additions & 0 deletions tests/model_hub_tests/torch_tests/test_gfpgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import subprocess
import sys
import tempfile

import pytest
import torch

from torch_utils import TestTorchConvertModel
from openvino import convert_model
import numpy as np

# To make tests reproducible we seed the random generator
torch.manual_seed(0)


class TestGFPGANConvertModel(TestTorchConvertModel):
def setup_class(self):
self.repo_dir = tempfile.TemporaryDirectory()
os.system(
f"git clone https://github.com/TencentARC/GFPGAN.git {self.repo_dir.name}")
subprocess.check_call(
["git", "checkout", "bc5a5deb95a4a9653851177985d617af1b9bfa8b"], cwd=self.repo_dir.name)
checkpoint_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
subprocess.check_call(
["wget", "-nv", checkpoint_url], cwd=self.repo_dir.name)

def load_model(self, model_name, model_link):
sys.path.append(self.repo_dir.name)
from gfpgan import GFPGANer

filename = os.path.join(self.repo_dir.name, 'GFPGANv1.3.pth')
arch = 'clean'
channel_multiplier = 2
restorer = GFPGANer(
model_path=filename,
upscale=2,
arch=arch,
channel_multiplier=channel_multiplier,
bg_upsampler=None)

self.example = (torch.randn(1, 3, 512, 512),)
self.inputs = (torch.randn(1, 3, 512, 512),)
return restorer.gfpgan

def convert_model(self, model_obj):
ov_model = convert_model(
model_obj, example_input=self.example, input=[1, 3, 512, 512], verbose=True)
return ov_model

def compare_results(self, fw_outputs, ov_outputs):
assert len(fw_outputs) == len(ov_outputs), \
"Different number of outputs between framework and OpenVINO:" \
" {} vs. {}".format(len(fw_outputs), len(ov_outputs))

fw_eps = 5e-2
is_ok = True
for i in range(len(ov_outputs)):
cur_fw_res = fw_outputs[i]
cur_ov_res = ov_outputs[i]
try:
np.testing.assert_allclose(
cur_ov_res, cur_fw_res, fw_eps, fw_eps)
except AssertionError as e:
print(e)
# The model has aten::normal_ operation which produce random numbers.
# Cannot reliably validate the output 0
if i != 0:
is_ok = False
assert is_ok, "Accuracy validation failed"

def teardown_class(self):
# remove all downloaded files from cache
self.repo_dir.cleanup()

@pytest.mark.nightly
def test_convert_model(self, ie_device):
self.run("GFPGAN", None, ie_device)

0 comments on commit 007b6fd

Please sign in to comment.