Skip to content

Commit

Permalink
Remove inlined CLIP model, use K-LMS scheduler for HF
Browse files Browse the repository at this point in the history
Signed-off-by: Akhil Goel <[email protected]>
  • Loading branch information
akhilg-nv committed Aug 29, 2024
1 parent 7010874 commit be2c94e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 183 deletions.
100 changes: 69 additions & 31 deletions tripy/examples/diffusion/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import cupy as cp
import numpy as np

from model import ClipTokenizer, StableDiffusion, get_alphas_cumprod
from transformers import CLIPTokenizer
from model import CLIPConfig, StableDiffusion, get_alphas_cumprod
from weight_loader import load_from_diffusers
import tripy as tp

Expand Down Expand Up @@ -102,6 +103,7 @@ def tripy_diffusion(args):
# vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
# else:
model = StableDiffusion()
print("[I] Loading model weights...", flush=True)
load_from_diffusers(model, tp.float32, debug=True)
clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
unet_compiled = compile_unet(model, verbose=True)
Expand All @@ -114,10 +116,12 @@ def tripy_diffusion(args):
# vae_compiled.save(os.path.join("engines", "vae_executable.json"))

# Run through CLIP to get context from prompt
tokenizer = ClipTokenizer()
prompt = tp.Tensor([tokenizer.encode(args.prompt)])
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
torch_prompt = tokenizer(args.prompt, padding="max_length", max_length=CLIPConfig.max_seq_len, truncation=True, return_tensors="pt")
prompt = tp.Tensor(torch_prompt.input_ids.to(torch.int32).to("cuda"))
print(f"[I] Got tokenized prompt.")
unconditional_prompt = tp.Tensor([tokenizer.encode("")])
torch_unconditional_prompt = tokenizer([""], padding="max_length", max_length=CLIPConfig.max_seq_len, return_tensors="pt")
unconditional_prompt = tp.Tensor(torch_unconditional_prompt.input_ids.to(torch.int32).to("cuda"))
print(f"[I] Got unconditional tokenized prompt.")

print("[I] Getting CLIP conditional and unconditional context...", end=" ")
Expand Down Expand Up @@ -150,67 +154,101 @@ def tripy_diffusion(args):
run_end_time = time.perf_counter()
print(f"[I] Full script took {run_end_time - run_start_time} seconds.")

# save image
im = Image.fromarray(cp.from_dlpack(x).get().astype(np.uint8, copy=False))
# Save image
image = Image.fromarray(cp.from_dlpack(x).get().astype(np.uint8, copy=False))
print(f"[I] Saving {args.out}")
if not os.path.isdir("output"):
print("[I] Creating 'output' directory.")
os.mkdir("output")
im.save(args.out)

return im, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]
image.save(args.out)

return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]

# referenced from https://huggingface.co/blog/stable_diffusion
def hf_diffusion(args):
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from tqdm.auto import tqdm

run_start_time = time.perf_counter()

# Initialize models
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, dtype=torch.float32)
pipe = pipe.to("cuda")
hf_tokenizer = pipe.tokenizer
hf_encoder = pipe.text_encoder.to("cuda")
clip_id = "openai/clip-vit-large-patch14"

print("[I] Loading models...")
hf_tokenizer = CLIPTokenizer.from_pretrained(clip_id)
hf_encoder = CLIPTextModel.from_pretrained(clip_id).to("cuda")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda")

run_start_time = time.perf_counter()
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# Run through CLIP to get context from prompt
print("[I] Starting tokenization and running clip...", end=" ")
clip_run_start = time.perf_counter()
text_input = hf_tokenizer(args.prompt, padding="max_length", max_length=hf_tokenizer.model_max_length, truncation=True, return_tensors="pt").to("cuda")
max_length = text_input.input_ids.shape[-1] # 77
uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda")
text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0]
uncond_embeddings = hf_encoder(uncond_input.input_ids)[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
clip_run_end = time.perf_counter()
print(f"took {clip_run_end - clip_run_start} seconds.")

# Diffusion loop with UNet
# Backbone of diffusion - the UNet
if args.seed is not None:
torch.manual_seed(args.seed)
torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
torch_latent *= scheduler.init_noise_sigma

scheduler.set_timesteps(args.steps)

# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([torch_latent] * 2)
diffusion_run_start = time.perf_counter()
for t in tqdm(scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([torch_latent] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=999)
# predict the noise residual
with torch.no_grad():
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# predict the noise residual
with torch.no_grad():
noise_pred = unet(latent_model_input, 999, encoder_hidden_states=text_embeddings).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + args.guidance * (noise_pred_text - noise_pred_uncond)

# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + args.guidance * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
torch_latent = scheduler.step(noise_pred, t, torch_latent).prev_sample

# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, 999, torch_latent).prev_sample
diffusion_run_end = time.perf_counter()
print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.")

# Upsample latent space to image with autoencoder
print(f"[I] Decoding latent...", end=" ")
vae_run_start = time.perf_counter()
torch_latent = 1 / 0.18215 * torch_latent
decoder_out = vae.decode(torch_latent)
with torch.no_grad():
image = vae.decode(torch_latent).sample
vae_run_end = time.perf_counter()
print(f"took {vae_run_end - vae_run_start} seconds.")

# Evaluate Output
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
image = pil_images[0]

run_end_time = time.perf_counter()
print(f"[I] Full script took {run_end_time - run_start_time} seconds.")

# Save image
print(f"[I] Saving {args.out}")
if not os.path.isdir("output"):
print("[I] Creating 'output' directory.")
os.mkdir("output")
image.save(args.out)
return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]

def print_summary(denoising_steps, times):
stages_ms = [1000 * (times[i+1] - times[i]) for i in range(0, 6, 2)]
Expand Down
156 changes: 4 additions & 152 deletions tripy/examples/diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from functools import lru_cache, reduce
from tqdm import tqdm
from collections import namedtuple
from typing import List, Callable, Optional, Union
from typing import List, Tuple, Callable, Optional, Union

import numpy as np
import tripy as tp
Expand All @@ -43,8 +43,8 @@ class CLIPConfig:
class StableDiffusion15UNetConfig:
io_channels: int = 4
model_channels: int = 320
channel_mult: List[int] = [1, 2, 4, 4]
attention_resolutions: List[int] = [4, 2, 1]
channel_mult: Tuple[int] = (1, 2, 4, 4)
attention_resolutions: Tuple[int] = (4, 2, 1)
num_heads: int = 8
context_dim: int = 768
dtype: "tripy.datatype" = tp.float32
Expand All @@ -55,7 +55,7 @@ class StableDiffusionVAEConfig:
latent_channels: int = 4
model_channel: int = 128
resolution: int = 256
channel_mult: List[int] = [1, 2, 4, 4]
channel_mult: Tuple[int] = (1, 2, 4, 4)
dtype: "tripy.datatype" = tp.float32

# convenience methods adapted from tinygrad/tensor.py (https://docs.tinygrad.org/tensor/ops/)
Expand Down Expand Up @@ -602,154 +602,6 @@ def __call__(self, input_ids):
return self.final_layer_norm(x)


# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe():
return fetch(
"https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz"
)


@lru_cache(maxsize=None)
def getenv(key: str, default=0):
return type(default)(os.getenv(key, default))


OSX = platform.system() == "Darwin"
_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))


def fetch(
url: str, name: Optional[Union[pathlib.Path, str]] = None, allow_caching=not getenv("DISABLE_HTTP_CACHE")
) -> pathlib.Path:
if url.startswith(("/", ".")):
return pathlib.Path(url)
fp = (
pathlib.Path(name)
if name is not None and (isinstance(name, pathlib.Path) or "/" in name)
else pathlib.Path(_cache_dir)
/ "tinygrad"
/ "downloads"
/ (name if name else hashlib.md5(url.encode("utf-8")).hexdigest())
) # noqa: E501
if not fp.is_file() or not allow_caching:
with urllib.request.urlopen(url, timeout=10) as r:
assert r.status == 200
total_length = int(r.headers.get("content-length", 0))
progress_bar = tqdm(total=total_length, unit="B", unit_scale=True, desc=url)
(path := fp.parent).mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
while chunk := r.read(16384):
progress_bar.update(f.write(chunk))
f.close()
if (file_size := os.stat(f.name).st_size) < total_length:
raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
pathlib.Path(f.name).rename(fp)
return fp


def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
return set(zip(word, word[1:]))


def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text


def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


class ClipTokenizer:
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)

if not pairs:
return token + "</w>"

while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception:
new_word.extend(word[i:])
break

if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word

def encode(self, text):
bpe_tokens = []
text = whitespace_clean(text.strip()).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
# Truncation, keeping two slots for start and end tokens.
if len(bpe_tokens) > 75:
bpe_tokens = bpe_tokens[:75]
return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)

# equivalent to LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
betas = np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32) ** 2
Expand Down

0 comments on commit be2c94e

Please sign in to comment.