forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PT FE] Add support for GFPGAN model (openvinotoolkit#21371)
* [PT FE] Add support for GFPGAN model * Remove logs * Fix codestyle * Add support for aten::normal
- Loading branch information
Showing
5 changed files
with
272 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
|
||
class TestInplaceNormal(PytorchLayerTest): | ||
def _prepare_input(self): | ||
import numpy as np | ||
return (np.random.randn(1, 3, 224, 224).astype(np.float32),) | ||
|
||
def create_model(self, mean, std): | ||
class aten_normal(torch.nn.Module): | ||
def __init__(self, mean, std): | ||
super(aten_normal, self).__init__() | ||
self.mean = mean | ||
self.std = std | ||
|
||
def forward(self, x): | ||
x = x.to(torch.float32) | ||
return x.normal_(mean=self.mean, std=self.std), x | ||
|
||
return aten_normal(mean, std), None, "aten::normal_" | ||
|
||
@pytest.mark.parametrize("mean,std", [(0., 1.), (5., 20.)]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
def test_inplace_normal(self, mean, std, ie_device, precision, ir_version): | ||
self._test(*self.create_model(mean, std), | ||
ie_device, precision, ir_version, custom_eps=1e30) | ||
|
||
|
||
class TestNormal(PytorchLayerTest): | ||
def _prepare_input(self): | ||
import numpy as np | ||
if isinstance(self.inputs, list): | ||
return (np.random.randn(*self.inputs).astype(np.float32),) | ||
return self.inputs | ||
|
||
class aten_normal1(torch.nn.Module): | ||
def forward(self, mean, std): | ||
return torch.normal(mean, std) | ||
|
||
class aten_normal2(torch.nn.Module): | ||
def forward(self, mean, std): | ||
x = torch.empty_like(mean, dtype=torch.float32) | ||
return torch.normal(mean, std, out=x), x | ||
|
||
class aten_normal3(torch.nn.Module): | ||
def forward(self, mean): | ||
return torch.normal(mean) | ||
|
||
class aten_normal4(torch.nn.Module): | ||
def forward(self, mean): | ||
x = torch.empty_like(mean, dtype=torch.float32) | ||
return torch.normal(mean, out=x), x | ||
|
||
class aten_normal5(torch.nn.Module): | ||
def forward(self, mean): | ||
x = torch.empty_like(mean, dtype=torch.float32) | ||
return torch.normal(mean, 2., out=x), x | ||
|
||
class aten_normal6(torch.nn.Module): | ||
def forward(self, x): | ||
x = x.to(torch.float32) | ||
return torch.normal(0., 1., x.shape) | ||
|
||
class aten_normal7(torch.nn.Module): | ||
def forward(self, x): | ||
x = x.to(torch.float32) | ||
return torch.normal(0., 1., x.shape, out=x), x | ||
|
||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
@pytest.mark.parametrize("model,inputs", [ | ||
(aten_normal1(), (torch.arange(1., 11.).numpy(), torch.arange(1, 0, -0.1).numpy())), | ||
(aten_normal2(), (torch.arange(1., 11.).numpy(), torch.arange(1, 0, -0.1).numpy())), | ||
(aten_normal3(), (torch.arange(1., 11.).numpy(),)), | ||
(aten_normal4(), (torch.arange(1., 11.).numpy(),)), | ||
(aten_normal5(), (torch.arange(1., 11.).numpy(),)), | ||
(aten_normal6(), [1, 3, 224, 224]), | ||
(aten_normal7(), [1, 3, 224, 224]), | ||
]) | ||
def test_inplace_normal(self, model, inputs, ie_device, precision, ir_version): | ||
self.inputs = inputs | ||
self._test(model, None, "aten::normal", | ||
ie_device, precision, ir_version, custom_eps=1e30) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,5 @@ protobuf | |
soundfile | ||
pandas | ||
super-image | ||
basicsr | ||
facexlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import subprocess | ||
import sys | ||
import tempfile | ||
|
||
import pytest | ||
import torch | ||
|
||
from torch_utils import TestTorchConvertModel | ||
from openvino import convert_model | ||
import numpy as np | ||
|
||
# To make tests reproducible we seed the random generator | ||
torch.manual_seed(0) | ||
|
||
|
||
class TestGFPGANConvertModel(TestTorchConvertModel): | ||
def setup_class(self): | ||
self.repo_dir = tempfile.TemporaryDirectory() | ||
os.system( | ||
f"git clone https://github.com/TencentARC/GFPGAN.git {self.repo_dir.name}") | ||
subprocess.check_call( | ||
["git", "checkout", "bc5a5deb95a4a9653851177985d617af1b9bfa8b"], cwd=self.repo_dir.name) | ||
checkpoint_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth" | ||
subprocess.check_call( | ||
["wget", "-nv", checkpoint_url], cwd=self.repo_dir.name) | ||
|
||
def load_model(self, model_name, model_link): | ||
sys.path.append(self.repo_dir.name) | ||
from gfpgan import GFPGANer | ||
|
||
filename = os.path.join(self.repo_dir.name, 'GFPGANv1.3.pth') | ||
arch = 'clean' | ||
channel_multiplier = 2 | ||
restorer = GFPGANer( | ||
model_path=filename, | ||
upscale=2, | ||
arch=arch, | ||
channel_multiplier=channel_multiplier, | ||
bg_upsampler=None) | ||
|
||
self.example = (torch.randn(1, 3, 512, 512),) | ||
self.inputs = (torch.randn(1, 3, 512, 512),) | ||
return restorer.gfpgan | ||
|
||
def convert_model(self, model_obj): | ||
ov_model = convert_model( | ||
model_obj, example_input=self.example, input=[1, 3, 512, 512], verbose=True) | ||
return ov_model | ||
|
||
def compare_results(self, fw_outputs, ov_outputs): | ||
assert len(fw_outputs) == len(ov_outputs), \ | ||
"Different number of outputs between framework and OpenVINO:" \ | ||
" {} vs. {}".format(len(fw_outputs), len(ov_outputs)) | ||
|
||
fw_eps = 5e-2 | ||
is_ok = True | ||
for i in range(len(ov_outputs)): | ||
cur_fw_res = fw_outputs[i] | ||
cur_ov_res = ov_outputs[i] | ||
try: | ||
np.testing.assert_allclose( | ||
cur_ov_res, cur_fw_res, fw_eps, fw_eps) | ||
except AssertionError as e: | ||
print(e) | ||
# The model has aten::normal_ operation which produce random numbers. | ||
# Cannot reliably validate the output 0 | ||
if i != 0: | ||
is_ok = False | ||
assert is_ok, "Accuracy validation failed" | ||
|
||
def teardown_class(self): | ||
# remove all downloaded files from cache | ||
self.repo_dir.cleanup() | ||
|
||
@pytest.mark.nightly | ||
def test_convert_model(self, ie_device): | ||
self.run("GFPGAN", None, ie_device) |