From be2c94e79754c1c188fcce1a6a9887b1969b7f81 Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Thu, 22 Aug 2024 20:41:34 -0700 Subject: [PATCH] Remove inlined CLIP model, use K-LMS scheduler for HF Signed-off-by: Akhil Goel --- tripy/examples/diffusion/example.py | 100 ++++++++++++------ tripy/examples/diffusion/model.py | 156 +--------------------------- 2 files changed, 73 insertions(+), 183 deletions(-) diff --git a/tripy/examples/diffusion/example.py b/tripy/examples/diffusion/example.py index 179186c67..c2131aecf 100644 --- a/tripy/examples/diffusion/example.py +++ b/tripy/examples/diffusion/example.py @@ -24,7 +24,8 @@ import cupy as cp import numpy as np -from model import ClipTokenizer, StableDiffusion, get_alphas_cumprod +from transformers import CLIPTokenizer +from model import CLIPConfig, StableDiffusion, get_alphas_cumprod from weight_loader import load_from_diffusers import tripy as tp @@ -102,6 +103,7 @@ def tripy_diffusion(args): # vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json")) # else: model = StableDiffusion() + print("[I] Loading model weights...", flush=True) load_from_diffusers(model, tp.float32, debug=True) clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True) unet_compiled = compile_unet(model, verbose=True) @@ -114,10 +116,12 @@ def tripy_diffusion(args): # vae_compiled.save(os.path.join("engines", "vae_executable.json")) # Run through CLIP to get context from prompt - tokenizer = ClipTokenizer() - prompt = tp.Tensor([tokenizer.encode(args.prompt)]) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + torch_prompt = tokenizer(args.prompt, padding="max_length", max_length=CLIPConfig.max_seq_len, truncation=True, return_tensors="pt") + prompt = tp.Tensor(torch_prompt.input_ids.to(torch.int32).to("cuda")) print(f"[I] Got tokenized prompt.") - unconditional_prompt = tp.Tensor([tokenizer.encode("")]) + torch_unconditional_prompt = tokenizer([""], padding="max_length", max_length=CLIPConfig.max_seq_len, return_tensors="pt") + unconditional_prompt = tp.Tensor(torch_unconditional_prompt.input_ids.to(torch.int32).to("cuda")) print(f"[I] Got unconditional tokenized prompt.") print("[I] Getting CLIP conditional and unconditional context...", end=" ") @@ -150,31 +154,36 @@ def tripy_diffusion(args): run_end_time = time.perf_counter() print(f"[I] Full script took {run_end_time - run_start_time} seconds.") - # save image - im = Image.fromarray(cp.from_dlpack(x).get().astype(np.uint8, copy=False)) + # Save image + image = Image.fromarray(cp.from_dlpack(x).get().astype(np.uint8, copy=False)) print(f"[I] Saving {args.out}") if not os.path.isdir("output"): print("[I] Creating 'output' directory.") os.mkdir("output") - im.save(args.out) - - return im, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end] + image.save(args.out) + return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end] +# referenced from https://huggingface.co/blog/stable_diffusion def hf_diffusion(args): - from diffusers import StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, AutoencoderKL + from transformers import CLIPTextModel, CLIPTokenizer + from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler + from tqdm.auto import tqdm + + run_start_time = time.perf_counter() + # Initialize models model_id = "runwayml/stable-diffusion-v1-5" - pipe = 
StableDiffusionPipeline.from_pretrained(model_id, dtype=torch.float32) - pipe = pipe.to("cuda") - hf_tokenizer = pipe.tokenizer - hf_encoder = pipe.text_encoder.to("cuda") + clip_id = "openai/clip-vit-large-patch14" + + print("[I] Loading models...") + hf_tokenizer = CLIPTokenizer.from_pretrained(clip_id) + hf_encoder = CLIPTextModel.from_pretrained(clip_id).to("cuda") unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda") - scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda") - - run_start_time = time.perf_counter() + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) + # Run through CLIP to get context from prompt print("[I] Starting tokenization and running clip...", end=" ") clip_run_start = time.perf_counter() text_input = hf_tokenizer(args.prompt, padding="max_length", max_length=hf_tokenizer.model_max_length, truncation=True, return_tensors="pt").to("cuda") @@ -182,35 +191,64 @@ def hf_diffusion(args): uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda") text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0] uncond_embeddings = hf_encoder(uncond_input.input_ids)[0] + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) clip_run_end = time.perf_counter() print(f"took {clip_run_end - clip_run_start} seconds.") - # Diffusion loop with UNet + # Backbone of diffusion - the UNet if args.seed is not None: torch.manual_seed(args.seed) torch_latent = torch.randn((1, 4, 64, 64)).to("cuda") - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + torch_latent *= scheduler.init_noise_sigma + scheduler.set_timesteps(args.steps) - # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. - latent_model_input = torch.cat([torch_latent] * 2) + diffusion_run_start = time.perf_counter() + for t in tqdm(scheduler.timesteps): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + latent_model_input = torch.cat([torch_latent] * 2) + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t) - latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=999) + # predict the noise residual + with torch.no_grad(): + noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # predict the noise residual - with torch.no_grad(): - noise_pred = unet(latent_model_input, 999, encoder_hidden_states=text_embeddings).sample + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + args.guidance * (noise_pred_text - noise_pred_uncond) - # perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + args.guidance * (noise_pred_text - noise_pred_uncond) + # compute the previous noisy sample x_t -> x_t-1 + torch_latent = scheduler.step(noise_pred, t, torch_latent).prev_sample - # compute the previous noisy sample x_t -> x_t-1 - latents = scheduler.step(noise_pred, 999, torch_latent).prev_sample + diffusion_run_end = time.perf_counter() + print(f"[I] Finished diffusion denoising. 
Inference took {diffusion_run_end - diffusion_run_start} seconds.") + # Upsample latent space to image with autoencoder + print(f"[I] Decoding latent...", end=" ") + vae_run_start = time.perf_counter() torch_latent = 1 / 0.18215 * torch_latent - decoder_out = vae.decode(torch_latent) + with torch.no_grad(): + image = vae.decode(torch_latent).sample + vae_run_end = time.perf_counter() + print(f"took {vae_run_end - vae_run_start} seconds.") + + # Evaluate Output + image = (image / 2 + 0.5).clamp(0, 1) + image = image.detach().cpu().permute(0, 2, 3, 1).numpy() + images = (image * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + image = pil_images[0] + run_end_time = time.perf_counter() + print(f"[I] Full script took {run_end_time - run_start_time} seconds.") + + # Save image + print(f"[I] Saving {args.out}") + if not os.path.isdir("output"): + print("[I] Creating 'output' directory.") + os.mkdir("output") + image.save(args.out) + return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end] def print_summary(denoising_steps, times): stages_ms = [1000 * (times[i+1] - times[i]) for i in range(0, 6, 2)] diff --git a/tripy/examples/diffusion/model.py b/tripy/examples/diffusion/model.py index 67f76fb17..c3542c836 100644 --- a/tripy/examples/diffusion/model.py +++ b/tripy/examples/diffusion/model.py @@ -24,7 +24,7 @@ from functools import lru_cache, reduce from tqdm import tqdm from collections import namedtuple -from typing import List, Callable, Optional, Union +from typing import List, Tuple, Callable, Optional, Union import numpy as np import tripy as tp @@ -43,8 +43,8 @@ class CLIPConfig: class StableDiffusion15UNetConfig: io_channels: int = 4 model_channels: int = 320 - channel_mult: List[int] = [1, 2, 4, 4] - attention_resolutions: List[int] = [4, 2, 1] + channel_mult: Tuple[int] = (1, 2, 4, 4) + attention_resolutions: Tuple[int] = (4, 2, 1) num_heads: int = 8 context_dim: int = 768 dtype: "tripy.datatype" = tp.float32 @@ -55,7 +55,7 @@ class StableDiffusionVAEConfig: latent_channels: int = 4 model_channel: int = 128 resolution: int = 256 - channel_mult: List[int] = [1, 2, 4, 4] + channel_mult: Tuple[int] = (1, 2, 4, 4) dtype: "tripy.datatype" = tp.float32 # convenience methods adapted from tinygrad/tensor.py (https://docs.tinygrad.org/tensor/ops/) @@ -602,154 +602,6 @@ def __call__(self, input_ids): return self.final_layer_norm(x) -# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) -@lru_cache() -def default_bpe(): - return fetch( - "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz" - ) - - -@lru_cache(maxsize=None) -def getenv(key: str, default=0): - return type(default)(os.getenv(key, default)) - - -OSX = platform.system() == "Darwin" -_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")) - - -def fetch( - url: str, name: Optional[Union[pathlib.Path, str]] = None, allow_caching=not getenv("DISABLE_HTTP_CACHE") -) -> pathlib.Path: - if url.startswith(("/", ".")): - return pathlib.Path(url) - fp = ( - pathlib.Path(name) - if name is not None and (isinstance(name, pathlib.Path) or "/" in name) - else pathlib.Path(_cache_dir) - / "tinygrad" - / "downloads" - / (name if name else hashlib.md5(url.encode("utf-8")).hexdigest()) - ) # noqa: E501 - if not fp.is_file() or not allow_caching: - with urllib.request.urlopen(url, timeout=10) as r: - 
assert r.status == 200 - total_length = int(r.headers.get("content-length", 0)) - progress_bar = tqdm(total=total_length, unit="B", unit_scale=True, desc=url) - (path := fp.parent).mkdir(parents=True, exist_ok=True) - with tempfile.NamedTemporaryFile(dir=path, delete=False) as f: - while chunk := r.read(16384): - progress_bar.update(f.write(chunk)) - f.close() - if (file_size := os.stat(f.name).st_size) < total_length: - raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}") - pathlib.Path(f.name).rename(fp) - return fp - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - return set(zip(word, word[1:])) - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a significant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -class ClipTokenizer: - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") - merges = merges[1 : 49152 - 256 - 2 + 1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v + "" for v in vocab] - for merge in merges: - vocab.append("".join(merge)) - vocab.extend(["<|startoftext|>", "<|endoftext|>"]) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + (token[-1] + "",) - pairs = get_pairs(word) - - if not pairs: - return token + "" - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - pairs = get_pairs(word) - word = " ".join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(text.strip()).lower() - for token in re.findall(self.pat, text): - token = "".join(self.byte_encoder[b] for b in 
token.encode("utf-8")) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) - # Truncation, keeping two slots for start and end tokens. - if len(bpe_tokens) > 75: - bpe_tokens = bpe_tokens[:75] - return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1) - # equivalent to LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000): betas = np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32) ** 2
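Reviewer note (not part of the patch): the two swaps above can be sanity-checked in isolation. The sketch below is illustrative only; it assumes `transformers` and `diffusers` are installed, uses the `openai/clip-vit-large-patch14` checkpoint referenced in the diff, relies on that tokenizer padding with `<|endoftext|>` (id 49407) just as the removed `ClipTokenizer.encode` did, and borrows its prompt string from the referenced Hugging Face blog post.

# Sketch: check that the HF tokenizer and K-LMS scheduler match the code they replace.
import numpy as np
from transformers import CLIPTokenizer
from diffusers import LMSDiscreteScheduler

# 1) The HF CLIPTokenizer should reproduce the removed ClipTokenizer.encode() contract:
#    49406 start token, 49407 end/padding tokens, fixed length 77 (CLIPConfig.max_seq_len).
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tokenizer(
    "a photograph of an astronaut riding a horse",
    padding="max_length",
    max_length=77,
    truncation=True,
).input_ids
assert len(ids) == 77 and ids[0] == 49406 and ids[-1] == 49407

# 2) The K-LMS scheduler constructed in hf_diffusion() should agree with the schedule that
#    get_alphas_cumprod() in model.py documents ("scaled_linear" betas, 1000 training steps).
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float32) ** 2
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
assert np.allclose(scheduler.alphas_cumprod.numpy(), alphas_cumprod, atol=1e-5)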