This repository has been archived by the owner on Dec 14, 2023. It is now read-only.

Simplify training loop #111

Open · wants to merge 2 commits into base: rc/dev-v4
train.py: 61 changes (26 additions & 35 deletions)
@@ -706,16 +706,21 @@ def main(
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Steps")

+    unet_train_enabled = False
+    text_train_enabled = False
+
     def finetune_unet(batch, train_encoder=False):
         nonlocal use_offset_noise
         nonlocal rescale_schedule

+        nonlocal unet_train_enabled
+        nonlocal text_train_enabled
+
         # Check if we are training the text encoder
-        text_trainable = (train_text_encoder or lora_manager.use_text_lora)
+        text_trainable = (train_text_encoder or use_text_lora)

         # Unfreeze UNET Layers
-        if global_step == 0:
+        if global_step == 0 and not unet_train_enabled:
             already_printed_trainables = False
             unet.train()
             handle_trainable_modules(
@@ -724,6 +729,7 @@ def finetune_unet(batch, train_encoder=False):
                 is_enabled=True,
                 negation=unet_negation
             )
+            unet_train_enabled = True

         # Convert videos to latent space
         pixel_values = batch["pixel_values"]
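The `unet_train_enabled` flag guards the unfreeze block so it runs at most once per process; plausibly this matters because, under gradient accumulation, several micro-batches can share `global_step == 0`, so the old check repeated the unfreeze and its logging. A minimal, self-contained sketch of the pattern (the `enable_once` helper and the `nn.Linear` stand-in are hypothetical, not from this repo):

```python
import torch.nn as nn

# Hypothetical stand-in; the real code guards `unet` and `text_encoder`.
unet = nn.Linear(4, 4)

unet_train_enabled = False  # one-time-enable guard, as introduced in this PR

def enable_once(module: nn.Module, already_enabled: bool) -> bool:
    """Switch `module` to train mode and unfreeze it, but only on the first call."""
    if not already_enabled:
        module.train()
        for p in module.parameters():
            p.requires_grad_(True)
    return True

for step in range(3):  # simulate several calls to finetune_unet
    unet_train_enabled = enable_once(unet, unet_train_enabled)
```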
@@ -736,9 +742,6 @@ def finetune_unet(batch, train_encoder=False):
         # Get video length
         video_length = latents.shape[2]

-        # Sample noise that we'll add to the latents
-        use_offset_noise = use_offset_noise and not rescale_schedule
-        noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
         bsz = latents.shape[0]

         # Sample a random timestep for each video
@@ -747,10 +750,16 @@ def finetune_unet(batch, train_encoder=False):

         # Add noise to the latents according to the noise magnitude at each timestep
         # (this is the forward diffusion process)
         #latents = rearrange(latents, 'b c f h w -> (b f) c h w')

+        # Sample noise that we'll add to the latents
+        use_offset_noise = use_offset_noise and not rescale_schedule
+        noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
+
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

         # Enable text encoder training
-        if text_trainable:
+        if text_trainable and not text_train_enabled:
             text_encoder.train()

             if lora_manager.use_text_lora:
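The noise sampling simply moves below the timestep sampling; it is still disabled whenever the rescaled schedule is active, presumably because a rescaled (zero terminal SNR) schedule targets the same dark/bright-image problem that offset noise does. The repo's `sample_noise` is not shown in this diff; a common offset-noise implementation, written here as an assumption about what it does, looks roughly like this:

```python
import torch

def sample_noise(latents: torch.Tensor, offset_strength: float, use_offset: bool) -> torch.Tensor:
    """Gaussian noise, optionally biased by a per-sample, per-channel offset.

    The offset shifts a whole frame brighter or darker, which helps diffusion
    models learn very dark or very bright outputs. Video latents are assumed
    to be shaped (batch, channels, frames, height, width).
    """
    noise = torch.randn_like(latents)
    if use_offset:
        b, c = latents.shape[:2]
        offset = torch.randn(b, c, 1, 1, 1, device=latents.device, dtype=latents.dtype)
        noise = noise + offset_strength * offset  # broadcast over f, h, w
    return noise

latents = torch.randn(2, 4, 8, 32, 32)
noise = sample_noise(latents, offset_strength=0.1, use_offset=True)
```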
@@ -763,6 +772,7 @@ def finetune_unet(batch, train_encoder=False):
                     negation=text_encoder_negation
                 )
                 cast_to_gpu_and_type([text_encoder], accelerator, torch.float32)
+            text_train_enabled = True

         # *Potentially* Fixes gradient checkpointing training.
         # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
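The next hunk begins at the tail of the scheduler's prediction-type branch, whose first half is collapsed in this view. Assuming a diffusers-style scheduler (which the `noise_scheduler.prediction_type` access suggests), the target selection would work like this sketch; shapes and the scheduler instance are hypothetical:

```python
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(prediction_type="v_prediction")

latents = torch.randn(2, 4, 8, 32, 32)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (2,))

# epsilon-prediction regresses the injected noise; v-prediction regresses a
# velocity term that mixes the clean sample and the noise.
if noise_scheduler.config.prediction_type == "epsilon":
    target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
```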
@@ -783,41 +793,22 @@ def finetune_unet(batch, train_encoder=False):

         else:
             raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")


-        # Here we do two passes for video and text training.
-        # If we are on the second iteration of the loop, get one frame.
-
-        # This allows us to train text information only on the spatial layers.
-        losses = []
-        should_truncate_video = (video_length > 1 and text_trainable)
+        should_detach = video_length > 1

         # We detach the encoder hidden states for the first pass (video frames > 1)
         # Then we make a clone of the initial state to ensure we can train it in the loop.
         detached_encoder_state = encoder_hidden_states.clone().detach()
         trainable_encoder_state = encoder_hidden_states.clone()

-        for i in range(2):
-
-            should_detach = noisy_latents.shape[2] > 1 and i == 0
-
-            if should_truncate_video and i == 1:
-                noisy_latents = noisy_latents[:,:,1,:,:].unsqueeze(2)
-                target = target[:,:,1,:,:].unsqueeze(2)
-
-            encoder_hidden_states = (
-                detached_encoder_state if should_detach else trainable_encoder_state
-            )
-
-            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            losses.append(loss)
-
-            # This was most likely single frame training or a single image.
-            if video_length == 1 and i == 0: break
-
-        loss = losses[0] if len(losses) == 1 else losses[0] + losses[1]
+        encoder_hidden_states = (
+            detached_encoder_state if should_detach else trainable_encoder_state
+        )
+
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
         return loss, latents

     for epoch in range(first_epoch, num_train_epochs):
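Net effect of the last hunk: the old two-pass loop (a detached pass over all frames, then a trainable pass over a single truncated frame for text information) collapses into one pass whose encoder states are detached whenever the batch is multi-frame. One consequence worth flagging in review is that `encoder_hidden_states` now only receives gradients on single-frame batches. A small runnable sketch of that gradient behavior, with hypothetical shapes and a scalar stand-in for the UNet:

```python
import torch

encoder_hidden_states = torch.randn(1, 77, 768, requires_grad=True)
noisy_latents = torch.randn(1, 4, 8, 16, 16)  # (b, c, f, h, w): 8 frames

video_length = noisy_latents.shape[2]
should_detach = video_length > 1  # True for multi-frame batches

detached_encoder_state = encoder_hidden_states.clone().detach()
trainable_encoder_state = encoder_hidden_states.clone()

chosen = detached_encoder_state if should_detach else trainable_encoder_state

# Stand-in for the UNet forward pass + MSE loss.
unet_weight = torch.randn(768, requires_grad=True)
loss = (chosen * unet_weight).pow(2).mean()
loss.backward()

print(encoder_hidden_states.grad)   # None: the detached path blocks text grads
print(unet_weight.grad is not None) # True: the "UNet" still trains
```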