Skip to content

Commit

Permalink
Add accuracy test
Browse files Browse the repository at this point in the history
Signed-off-by: Akhil Goel <[email protected]>
  • Loading branch information
akhilg-nv committed Aug 29, 2024
1 parent be2c94e commit a8aacaf
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 27 deletions.
8 changes: 4 additions & 4 deletions tripy/examples/diffusion/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import numpy as np

from transformers import CLIPTokenizer
from model import CLIPConfig, StableDiffusion, get_alphas_cumprod
from weight_loader import load_from_diffusers
from examples.diffusion.model import CLIPConfig, StableDiffusion, get_alphas_cumprod
from examples.diffusion.weight_loader import load_from_diffusers
import tripy as tp


Expand Down Expand Up @@ -266,7 +266,8 @@ def print_summary(denoising_steps, times):


# TODO: Add torch compilation modes
# TODO: Add fp16
# TODO: Add fp16 support
# TODO: Add Timing context
def main():
default_prompt = "a horse sized cat eating a bagel"
parser = argparse.ArgumentParser(
Expand All @@ -275,7 +276,6 @@ def main():
parser.add_argument("--steps", type=int, default=10, help="Number of denoising steps in diffusion")
parser.add_argument("--prompt", type=str, default=default_prompt, help="Phrase to render")
parser.add_argument("--out", type=str, default=os.path.join("output", "rendered.png"), help="Output filename")
parser.add_argument("--noshow", action="store_true", help="Don't show the image")
parser.add_argument("--fp16", action="store_true", help="Cast the weights to float16")
parser.add_argument("--timing", action="store_true", help="Print timing per step")
parser.add_argument("--seed", type=int, help="Set the random latent seed")
Expand Down
1 change: 1 addition & 0 deletions tripy/examples/diffusion/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate
diffusers==0.26.3
transformers==4.33.1
scikit-image
22 changes: 0 additions & 22 deletions tripy/examples/diffusion/test_acc.py

This file was deleted.

44 changes: 44 additions & 0 deletions tripy/tests/test_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import torch
import numpy as np
from argparse import Namespace
from skimage.metrics import structural_similarity

from examples.diffusion.example import tripy_diffusion, hf_diffusion


# Utility for debugging hidden states in model via floating-point comparison
def check_equal(tp_array, torch_tensor, dtype=torch.float32, debug=False):
if debug:
a = torch.from_dlpack(tp_array).to(dtype)
b = torch_tensor.to(dtype)
diff = a - b
print(f"tripy output shape: {a.shape}, torch output shape: {b.shape}")

max_abs_diff = torch.max(torch.abs(diff))
print(f"Maximum absolute difference: {max_abs_diff}\n")

# Add small epsilon to denominator to avoid division by 0
eps = 1e-8
rel_diff = torch.abs(diff) / (torch.abs(b) + eps)
max_rel_diff = torch.max(rel_diff)
print(f"Maximum relative difference: {max_rel_diff}\n")

assert torch.allclose(torch.from_dlpack(tp_array).to(dtype), torch_tensor.to(dtype)), f"\nTP Array:\n {tp_array} \n!= Torch Tensor:\n {torch_tensor}"

@pytest.mark.l1
class TestConvolution:
def test_ssim(self):
args = Namespace(steps=50, prompt='a beautiful photograph of Mt. Fuji during cherry blossom', out='output/rendered.png', fp16=False, seed=100, guidance=7.5, torch_inference=False)
tp_img, _ = tripy_diffusion(args)
print(f"first: {tp_img}")
tp_img = np.array(tp_img.convert('L'))
print(f"second: {tp_img}")
torch_img, _ = hf_diffusion(args)
print(f"third: {torch_img}")
torch_img = np.array(torch_img.convert('L'))
print(f"fourth: {torch_img}")

ssim = structural_similarity(tp_img, torch_img)
print(f"SSIM IS: {ssim}")
assert ssim >= 0.85, "Structural Similarity score expected >= 0.85 but got {ssim}"
2 changes: 1 addition & 1 deletion tripy/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __str__(self):
return os.path.relpath(self.path, EXAMPLES_ROOT)


EXAMPLES = [Example(["nanogpt"])]
EXAMPLES = [Example(["nanogpt"]), Example(["diffusion"])]


@pytest.mark.l1
Expand Down

0 comments on commit a8aacaf

Please sign in to comment.