diff --git a/tripy/examples/diffusion/example.py b/tripy/examples/diffusion/example.py
index 93a17f273..0bba0bff0 100644
--- a/tripy/examples/diffusion/example.py
+++ b/tripy/examples/diffusion/example.py
@@ -1,4 +1,4 @@
-import argparse, tempfile
+import argparse, os
 from tqdm import tqdm
 from pathlib import Path
 from PIL import Image
@@ -8,45 +8,36 @@ import cupy as cp
 import numpy as np
 
-from model import ClipTokenizer, StableDiffusion
+from model import ClipTokenizer, StableDiffusion, get_alphas_cumprod
 from weight_loader import load_from_diffusers
 import tripy as tp
 
-def tripy_diffusion(args):
-    model = StableDiffusion()
-    load_from_diffusers(model, tp.float32, debug=True)
-    run_start_time = time.perf_counter()
+def compile_model(model, inputs, verbose=False):
+    if verbose:
+        print(f"Compiling {model.__class__.__name__}...", end=' ')
+        compile_start_time = time.perf_counter()
 
-    # Run through CLIP to get context
-    tokenizer = ClipTokenizer()
-    prompt = tp.Tensor([tokenizer.encode(args.prompt)])
-    print(f"Got tokenized prompt.")
-    unconditional_prompt = tp.Tensor([tokenizer.encode("")])
-    print(f"Got unconditional tokenized prompt.")
+    compiler = tp.Compiler(model)
+    compiled_model = compiler.compile(*inputs)
+
+    if verbose:
+        compile_end_time = time.perf_counter()
+        print(f"took {compile_end_time - compile_start_time} seconds.")
+
+    return compiled_model
 
-    print("Compiling CLIP model...")
-    clip_compile_start_time = time.perf_counter()
-    clip_compiler = tp.Compiler(model.cond_stage_model.transformer.text_model)
-    clip_text_model = clip_compiler.compile(tp.InputInfo((1, 77), dtype=tp.int32))
-    clip_compile_end_time = time.perf_counter()
-    print(f"Compilation of CLIP took {clip_compile_end_time - clip_compile_start_time} seconds.")
-    print("Getting CLIP context...")
-    clip_run_start = time.perf_counter()
-    context = clip_text_model(prompt)
-    unconditional_context = clip_text_model(unconditional_prompt)
-    clip_run_end = time.perf_counter()
-    print(f"Got CLIP conditional and unconditional context. Inference took {clip_run_end - clip_run_start} seconds.")
+def compile_clip(model, verbose=False):
+    inputs = (tp.InputInfo((1, 77), dtype=tp.int32),)
+    return compile_model(model, inputs, verbose=verbose)
-    # Backbone of diffusion - the UNet
-    print("Compiling UNet...")
-    unet_compile_start_time = time.perf_counter()
-    compiler = tp.Compiler(model)
+
+def compile_unet(model, verbose=False):
     unconditional_context_shape = (1, 77, 768)
     conditional_context_shape = (1, 77, 768)
     latent_shape = (1, 4, 64, 64)
-    compiled_model = compiler.compile(
+    inputs = (
         tp.InputInfo(unconditional_context_shape, dtype=tp.float32),
         tp.InputInfo(conditional_context_shape, dtype=tp.float32),
         tp.InputInfo(latent_shape, dtype=tp.float32),
@@ -55,56 +46,192 @@ def tripy_diffusion(args):
         tp.InputInfo((1,), dtype=tp.float32),
         tp.InputInfo((1,), dtype=tp.float32),
     )
-    unet_compile_end_time = time.perf_counter()
-    print(f"Compilation of UNet took {unet_compile_end_time - unet_compile_start_time} seconds.")
+    return compile_model(model, inputs, verbose=verbose)
+
+
+def compile_vae(model, verbose=False):
+    inputs = (tp.InputInfo((1, 4, 64, 64), dtype=tp.float32),)
+    return compile_model(model, inputs, verbose=verbose)
+
+
+# def compile_CLIP(model, verbose=False):
+#     if verbose:
+#         print("Compiling CLIP model...")
+#         clip_compile_start_time = time.perf_counter()
+
+#     clip_compiler = tp.Compiler(model)
+#     compiled_clip = clip_compiler.compile(tp.InputInfo((1, 77), dtype=tp.int32))
+
+#     if verbose:
+#         clip_compile_end_time = time.perf_counter()
+#         print(f"Compilation of CLIP took {clip_compile_end_time - clip_compile_start_time} seconds.")
-    timesteps = list(range(1, 1000, 1000 // args.steps))
-    print(f"Running for {timesteps} timesteps.")
-    alphas = model.alphas_cumprod[tp.Tensor(timesteps)]
+#     return compiled_clip
+
+
+# def compile_unet(model, verbose=False):
+#     if verbose:
+#         print("Compiling UNet...")
+#         unet_compile_start_time = time.perf_counter()
+
+#     compiler = tp.Compiler(model)
+#     unconditional_context_shape = (1, 77, 768)
+#     conditional_context_shape = (1, 77, 768)
+#     latent_shape = (1, 4, 64, 64)
+#     compiled_model = compiler.compile(
+#         tp.InputInfo(unconditional_context_shape, dtype=tp.float32),
+#         tp.InputInfo(conditional_context_shape, dtype=tp.float32),
+#         tp.InputInfo(latent_shape, dtype=tp.float32),
+#         tp.InputInfo((1,), dtype=tp.float32),
+#         tp.InputInfo((1,), dtype=tp.float32),
+#         tp.InputInfo((1,), dtype=tp.float32),
+#         tp.InputInfo((1,), dtype=tp.float32),
+#     )
+
+#     if verbose:
+#         unet_compile_end_time = time.perf_counter()
+#         print(f"Compilation of UNet took {unet_compile_end_time - unet_compile_start_time} seconds.")
+
+#     return compiled_model
+
+
+def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance):
+    timesteps = list(range(1, 1000, 1000 // steps))[::-1]
+    # print(f"t: {timesteps}")
+    alphas = get_alphas_cumprod()[tp.Tensor(timesteps)]
     alphas_prev = tp.concatenate([tp.Tensor([1.0]), alphas[:-1]], dim=0)
+    # print(f"a: {alphas}")
+    # print(f"aP: {alphas_prev}")
+
+    # unet_run_start = time.perf_counter()
+    for index, timestep in enumerate(timesteps):
+        tid = tp.Tensor([index])
+        latent = model(
+            unconditional_context,
+            context,
+            latent,
+            tp.cast(tp.Tensor([timestep]), tp.float32),
+            alphas[tid],
+            alphas_prev[tid],
+            tp.Tensor([guidance]),
+        )
+    # unet_run_end = time.perf_counter()
+    # print(f"Finished running diffusion. Inference took {unet_run_end - unet_run_start} seconds.")
+    return latent
+
+
+def tripy_diffusion(args):
+    model = StableDiffusion()
+    load_from_diffusers(model, tp.float32, debug=True)
+
+    run_start_time = time.perf_counter()
+    # if os.path.isdir("engines"):
+    #     compiled_clip = tp.Executable.load(os.path.join("engines", "clip_executable.json"))
+    #     compiled_unet = tp.Executable.load(os.path.join("engines", "unet_executable.json"))
+    #     compiled_vae = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
+    # else:
+    compiled_clip = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
+    compiled_unet = compile_unet(model, verbose=True)
+    compiled_vae = compile_vae(model.decode, verbose=True)
+
+    # os.mkdir("engines")
+    # compiled_clip.save(os.path.join("engines", "clip_executable.json"))
+    # compiled_unet.save(os.path.join("engines", "unet_executable.json"))
+    # compiled_vae.save(os.path.join("engines", "vae_executable.json"))
+
+    # Run through CLIP to get context
+    tokenizer = ClipTokenizer()
+    prompt = tp.Tensor([tokenizer.encode(args.prompt)])
+    print(f"Got tokenized prompt.")
+    unconditional_prompt = tp.Tensor([tokenizer.encode("")])
+    print(f"Got unconditional tokenized prompt.")
+
+    print("Getting CLIP conditional and unconditional context...", end=' ')
+    clip_run_start = time.perf_counter()
+    context = compiled_clip(prompt)
+    unconditional_context = compiled_clip(unconditional_prompt)
+    clip_run_end = time.perf_counter()
+    print(f"took {clip_run_end - clip_run_start} seconds.")
+
+    # Backbone of diffusion - the UNet
+    # start with random noise
     if args.seed is not None:
         torch.manual_seed(args.seed)
     torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
     latent = tp.Tensor(torch_latent)
 
-    def run(model, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
-        return model(unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance)
-
-    # This is diffusion
-    print("Running diffusion...")
-    unet_run_start = time.perf_counter()
-    for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
-        t.set_description("idx: %1d, timestep: %3d" % (index, timestep))
-        tid = tp.Tensor([index])
-        latent = run(
-            compiled_model,
+    print(f"Running diffusion loop for {args.steps} steps...", end=' ')
+
+    # compiler = tp.Compiler(run_diffusion_loop)
+    # unconditional_context_shape = (1, 77, 768)
+    # conditional_context_shape = (1, 77, 768)
+    # latent_shape = (1, 4, 64, 64)
+    # compiled_diffusion_loop = compiler.compile(
+    #     model,
+    #     tp.InputInfo(unconditional_context_shape, dtype=tp.float32),
+    #     tp.InputInfo(conditional_context_shape, dtype=tp.float32),
+    #     tp.InputInfo(latent_shape, dtype=tp.float32),
+    #     args.steps,
+    #     args.guidance,
+    # )
+
+    timesteps = list(range(1, 1000, 1000 // args.steps))[::-1]
+    alphas = get_alphas_cumprod()[tp.Tensor(timesteps)]
+    alphas_prev = tp.concatenate([tp.Tensor([1.0]), alphas[:-1]], dim=0)
+    tid = tp.Tensor([0])
+    diffusion_run_start = time.perf_counter()
+    # latent = run_diffusion_loop(compiled_unet, unconditional_context, context, latent, args.steps, args.guidance)
+    latent = compiled_unet(
         unconditional_context,
         context,
         latent,
-            tp.cast(tp.Tensor([timestep]), tp.float32),
+        tp.cast(tp.Tensor([timesteps[0]]), tp.float32),
         alphas[tid],
         alphas_prev[tid],
         tp.Tensor([args.guidance]),
     )
-    unet_run_end = time.perf_counter()
-    print(f"Finished running diffusion. Inference took {unet_run_end - unet_run_start} seconds.")
+    diffusion_run_end = time.perf_counter()
+    print(f"took {diffusion_run_end - diffusion_run_start} seconds.")
+
+    #latent = run_diffusion_loop(compiled_unet, unconditional_context, context, latent, args.steps, args.guidance)
+
+    # timesteps = list(range(1, 1000, 1000 // args.steps))
+    # print(f"Running for {timesteps} timesteps.")
+    # alphas = model.alphas_cumprod[tp.Tensor(timesteps)]
+    # alphas_prev = tp.concatenate([tp.Tensor([1.0]), alphas[:-1]], dim=0)
+
+    # def run(model, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
+    #     return model(unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance)
+
+    # # This is diffusion
+    # print("Running diffusion...")
+    # unet_run_start = time.perf_counter()
+    # for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
+    #     t.set_description("idx: %1d, timestep: %3d" % (index, timestep))
+    #     tid = tp.Tensor([index])
+    #     latent = run(
+    #         compiled_unet,
+    #         unconditional_context,
+    #         context,
+    #         latent,
+    #         tp.cast(tp.Tensor([timestep]), tp.float32),
+    #         alphas[tid],
+    #         alphas_prev[tid],
+    #         tp.Tensor([args.guidance]),
+    #     )
+    # unet_run_end = time.perf_counter()
+    # print(f"Finished running diffusion. Inference took {unet_run_end - unet_run_start} seconds.")
 
     # Upsample latent space to image with autoencoder
-    print("Compiling VAE decoder...")
-    vae_compile_start_time = time.perf_counter()
-    vae_compiler = tp.Compiler(model.decode)
-    vae_decode = vae_compiler.compile(tp.InputInfo((1, 4, 64, 64), dtype=tp.float32))
-    vae_compile_end_time = time.perf_counter()
-    print(f"Compilation took {vae_compile_end_time - vae_compile_start_time} seconds.")
-
-    print(f"Decoding latent...")
+
+    print(f"Decoding latent...", end=' ')
     vae_run_start = time.perf_counter()
-    x = vae_decode(latent)
+    x = compiled_vae(latent) # x = model.decode(latent)
     vae_run_end = time.perf_counter()
-    print(f"Finished decoding latent. Inference took {vae_run_end - vae_run_start} seconds.")
+    print(f"took {vae_run_end - vae_run_start} seconds.")
 
     run_end_time = time.perf_counter()
     x.eval()
@@ -113,10 +240,11 @@ def run(model, unconditional_context, context, latent, timestep, alphas, alphas_
     # save image
     im = Image.fromarray(cp.from_dlpack(x).get().astype(np.uint8, copy=False))
     print(f"saving {args.out}")
+    if not os.path.isdir("output"):
+        os.mkdir("output")
     im.save(args.out)
-    # Open image.
-    if not args.noshow:
-        im.show()
+
+    return im, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]
 
 def hf_diffusion(args):
     from diffusers import StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, AutoencoderKL
@@ -130,49 +258,21 @@ def hf_diffusion(args):
     scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
     vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda")
 
+    run_start_time = time.perf_counter()
+
+    print("Starting tokenization and running clip...", end=" ")
+    clip_run_start = time.perf_counter()
     text_input = hf_tokenizer(args.prompt, padding="max_length", max_length=hf_tokenizer.model_max_length, truncation=True, return_tensors="pt").to("cuda")
     max_length = text_input.input_ids.shape[-1] # 77
     uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda")
-
     text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0]
     uncond_embeddings = hf_encoder(uncond_input.input_ids)[0]
-
-    from test_acc import check_equal
-    del pipe
-    model = StableDiffusion()
-    load_from_diffusers(model, tp.float32, debug=True)
-
-    run_start_time = time.perf_counter()
-
-    # Run through CLIP to get context
-    tokenizer = ClipTokenizer()
-    prompt = tp.Tensor([tokenizer.encode(args.prompt)])
-    print(f"Got tokenized prompt.")
-    unconditional_prompt = tp.Tensor([tokenizer.encode("")])
-    print(f"Got unconditional tokenized prompt.")
-
-    print("Compiling CLIP model...")
-    clip_compile_start_time = time.perf_counter()
-    clip_compiler = tp.Compiler(model.cond_stage_model.transformer.text_model)
-    clip_text_model = clip_compiler.compile(tp.InputInfo((1, 77), dtype=tp.int32))
-    clip_compile_end_time = time.perf_counter()
-    print(f"Compilation of CLIP took {clip_compile_end_time - clip_compile_start_time} seconds.")
-
-    print("Getting CLIP context...")
-    clip_run_start = time.perf_counter()
-    context = clip_text_model(prompt)
-    unconditional_context = clip_text_model(unconditional_prompt)
     clip_run_end = time.perf_counter()
-    print(f"Got CLIP conditional and unconditional context. Inference took {clip_run_end - clip_run_start} seconds.")
-    check_equal(context, text_embeddings, debug=True)
-    check_equal(unconditional_context, uncond_embeddings, debug=True)
-
-    print("DONE")
+    print(f"took {clip_run_end - clip_run_start} seconds.")
 
-    # HF DIFFUSERS UNET
-    # start with random noise
-    # if args.seed is not None:
-    torch.manual_seed(0)
+    # Diffusion loop with UNet
+    if args.seed is not None:
+        torch.manual_seed(args.seed)
     torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
     scheduler.set_timesteps(args.steps)
@@ -192,19 +292,34 @@ def hf_diffusion(args):
         # compute the previous noisy sample x_t -> x_t-1
         latents = scheduler.step(noise_pred, 999, torch_latent).prev_sample
-    print(f"TORCH LATENT: {latents}")
 
     torch_latent = 1 / 0.18215 * torch_latent
     decoder_out = vae.decode(torch_latent)
+
+def print_summary(denoising_steps, times):
+    stages_ms = [1000 * (times[i+1] - times[i]) for i in range(0, 6, 2)]
+    total_ms = sum(stages_ms)
+    print('|-----------------|--------------|')
+    print('| {:^15} | {:^12} |'.format('Module', 'Latency'))
+    print('|-----------------|--------------|')
+    print('| {:^15} | {:>9.2f} ms |'.format('CLIP', stages_ms[0]))
+    print('| {:^15} | {:>9.2f} ms |'.format('UNet'+' x '+str(denoising_steps), stages_ms[1]))
+    print('| {:^15} | {:>9.2f} ms |'.format('VAE-Dec', stages_ms[2]))
+    print('|-----------------|--------------|')
+    print('| {:^15} | {:>9.2f} ms |'.format('Pipeline', total_ms))
+    print('|-----------------|--------------|')
+    print('Throughput: {:.2f} image/s'.format(1000. / total_ms))
+
+
 def main():
     default_prompt = "a horse sized cat eating a bagel"
     parser = argparse.ArgumentParser(
         description="Run Stable Diffusion", formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
-    parser.add_argument("--steps", type=int, default=10, help="Number of steps in diffusion")
+    parser.add_argument("--steps", type=int, default=10, help="Number of denoising steps in diffusion")
     parser.add_argument("--prompt", type=str, default=default_prompt, help="Phrase to render")
-    parser.add_argument("--out", type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
+    parser.add_argument("--out", type=str, default=os.path.join("output", "rendered.png"), help="Output filename")
     parser.add_argument("--noshow", action="store_true", help="Don't show the image")
     parser.add_argument("--fp16", action="store_true", help="Cast the weights to float16")
    parser.add_argument("--timing", action="store_true", help="Print timing per step")
@@ -216,7 +331,8 @@ def main():
     if args.torch_inference:
         hf_diffusion(args)
     else:
-        tripy_diffusion(args)
+        _, times = tripy_diffusion(args)
+        print_summary(args.steps, times)
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/tripy/examples/diffusion/model.py b/tripy/examples/diffusion/model.py
index fb2851b6c..7d8203281 100644
--- a/tripy/examples/diffusion/model.py
+++ b/tripy/examples/diffusion/model.py
@@ -11,6 +11,19 @@ import numpy as np
 
 import tripy as tp
 
+from dataclasses import dataclass
+
+# @dataclass
+# class StableDiffusion15Config:
+#     block_size: int = 1024
+#     vocab_size: int = 50257
+#     num_layers: int = 12
+#     num_heads: int = 12
+#     embedding_size: int = 768
+#     bias: bool = True
+#     seq_len: int = 1
+#     batch_size: int = 1
+#     dtype: "tripy.datatype" = tp.float32
 
 # convenience methods adapted from tinygrad/tensor.py (https://docs.tinygrad.org/tensor/ops/)
 def scaled_dot_product_attention(
@@ -47,9 +60,6 @@ def sequential(input: tp.Tensor, ll: List[Callable[[tp.Tensor], tp.Tensor]]):
     return reduce(lambda x, f: f(x), ll, input)
 
 
-# convenience for dynamic reshapes
-one_shape = tp.Shape(tp.ones((1,), dtype=tp.int32))
-
 # TODO: change to linear layers?
 class AttnBlock(tp.Module):
     def __init__(self, in_channels):
@@ -209,7 +219,7 @@ def __init__(self, channels, emb_channels, out_channels):
     def __call__(self, x, emb):
         h = self.conv1(self.nonlinearity(self.norm1(x)))
         emb_out = self.time_emb_proj(self.nonlinearity(emb))
-        target_shape = tp.concatenate([emb_out.shape, one_shape, one_shape], dim=0)
+        target_shape = emb_out.shape + (1, 1)
         # TODO: #228: WAR to prevent computing output rank in infer_rank for reshape
         target_shape.trace_tensor.shape = (emb_out.rank + 2,)
         h = h + tp.reshape(emb_out, target_shape)
@@ -262,7 +272,7 @@ class FeedForward(tp.Module):
     def __init__(self, dim, mult=4):
         self.net = [
             GEGLU(dim, dim * mult),
-            Dummy(), # needed for weights loading code to work
+            Dummy(), # Accounts for Dropout layer, needed for weight loading
             tp.Linear(dim * mult, dim),
         ]
 
@@ -727,7 +737,7 @@ def clamp(tensor: tp.Tensor, min: int, max: int):
 
 class StableDiffusion(tp.Module):
     def __init__(self):
-        self.alphas_cumprod = get_alphas_cumprod()
+        self.alphas_cumprod = get_alphas_cumprod().eval()
         self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model=UNetModel())
         self.first_stage_model = AutoencoderKL()
         self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
diff --git a/tripy/tripy/frontend/module/module.py b/tripy/tripy/frontend/module/module.py
index fbc1d1597..f1095266a 100644
--- a/tripy/tripy/frontend/module/module.py
+++ b/tripy/tripy/frontend/module/module.py
@@ -106,7 +106,7 @@ def __setattr__(self, name: str, value: Any) -> None:
 
         if isinstance(value, List) or isinstance(value, Dict):
             container = value if isinstance(value, List) else value.values()
-            if _contains_types(container, [Parameter, Module]) and not _is_homogeneous_container(container, Parameter):
+            if _contains_types(container, [Parameter, Module]) and not (_is_homogeneous_container(container, Parameter) or _is_homogeneous_container(container, Module)):
                 logger.warning("A container of mixed types will not be registered with this module's state_dict().")
 
     def state_dict(self) -> Dict[str, Parameter]: