Skip to content

Commit

Permalink
Update Sample.
Browse files Browse the repository at this point in the history
  • Loading branch information
quic-zhanweiw committed Jan 15, 2025
1 parent 2b2ef56 commit 9b56a68
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions samples/python/real_esrgan_x4plus/real_esrgan_x4plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

####################################################################

MODEL_ID = "m7qk01okn"
MODEL_ID = "mnz1l2exq"
MODEL_NAME = "real_esrgan_x4plus"
MODEL_HELP_URL = "https://github.com/quic/ai-engine-direct-helper/tree/main/samples/python/" + MODEL_NAME + "#" + MODEL_NAME + "-qnn-models"
IMAGE_SIZE = 512
Expand All @@ -46,15 +46,15 @@ def preprocess_PIL_image(image: Image) -> torch.Tensor:
transforms.CenterCrop(IMAGE_SIZE),
transforms.PILToTensor()])
img: torch.Tensor = transform(image) # type: ignore
img = img.float().unsqueeze(0) / 255.0 # int 0 - 255 to float 0.0 - 1.0
img = img.float() / 255.0 # int 0 - 255 to float 0.0 - 1.0
return img

def torch_tensor_to_PIL_image(data: torch.Tensor) -> Image:
"""
Convert a Torch tensor (dtype float32) with range [0, 1] and shape CHW into PIL image CHW
"""
out = torch.clip(data, min=0.0, max=1.0)
np_out = (out.permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
np_out = (out.detach().numpy() * 255).astype(np.uint8)
return ImageFromArray(np_out)

# RealESRGan class which inherited from the class QNNContext.
Expand Down Expand Up @@ -91,6 +91,7 @@ def Inference(input_image_path, output_image_path):
# Read and preprocess the image.
image = Image.open(input_image_path)
image = preprocess_PIL_image(image).numpy()
image = np.transpose(image, (1, 2, 0)) # CHW -> HWC

# Burst the HTP.
PerfProfile.SetPerfProfileGlobal(PerfProfile.BURST)
Expand All @@ -101,8 +102,9 @@ def Inference(input_image_path, output_image_path):
# Reset the HTP.
PerfProfile.RelPerfProfileGlobal()

# show & save the result
output_image = torch.from_numpy(output_image)
output_image = output_image.reshape(3, IMAGE_SIZE * 4, IMAGE_SIZE * 4)
output_image = output_image.reshape(IMAGE_SIZE * 4, IMAGE_SIZE * 4, 3)
output_image = torch.unsqueeze(output_image, 0)
output_image = [torch_tensor_to_PIL_image(img) for img in output_image]
image_buffer = output_image[0]
Expand Down

0 comments on commit 9b56a68

Please sign in to comment.