diff --git a/train.py b/train.py index 592a761..69dfae0 100644 --- a/train.py +++ b/train.py @@ -706,16 +706,21 @@ def main( # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") + + unet_train_enabled = False + text_train_enabled = False def finetune_unet(batch, train_encoder=False): nonlocal use_offset_noise nonlocal rescale_schedule - + nonlocal unet_train_enabled + nonlocal text_train_enabled + # Check if we are training the text encoder - text_trainable = (train_text_encoder or lora_manager.use_text_lora) + text_trainable = (train_text_encoder or use_text_lora) # Unfreeze UNET Layers - if global_step == 0: + if global_step == 0 and not unet_train_enabled: already_printed_trainables = False unet.train() handle_trainable_modules( @@ -724,6 +729,7 @@ def finetune_unet(batch, train_encoder=False): is_enabled=True, negation=unet_negation ) + unet_train_enabled = True # Convert videos to latent space pixel_values = batch["pixel_values"] @@ -736,9 +742,6 @@ def finetune_unet(batch, train_encoder=False): # Get video length video_length = latents.shape[2] - # Sample noise that we'll add to the latents - use_offset_noise = use_offset_noise and not rescale_schedule - noise = sample_noise(latents, offset_noise_strength, use_offset_noise) bsz = latents.shape[0] # Sample a random timestep for each video @@ -747,10 +750,16 @@ def finetune_unet(batch, train_encoder=False): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) + #latents = rearrange(latents, 'b c f h w -> (b f) c h w') + + # Sample noise that we'll add to the latents + use_offset_noise = use_offset_noise and not rescale_schedule + noise = sample_noise(latents, offset_noise_strength, use_offset_noise) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - + # Enable text encoder training - if text_trainable: + if text_trainable and not text_train_enabled: text_encoder.train() if lora_manager.use_text_lora: @@ -763,6 +772,7 @@ def finetune_unet(batch, train_encoder=False): negation=text_encoder_negation ) cast_to_gpu_and_type([text_encoder], accelerator, torch.float32) + text_train_enabled = True # *Potentially* Fixes gradient checkpointing training. # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb @@ -783,41 +793,22 @@ def finetune_unet(batch, train_encoder=False): else: raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}") - - - # Here we do two passes for video and text training. - # If we are on the second iteration of the loop, get one frame. + # This allows us to train text information only on the spatial layers. - losses = [] - should_truncate_video = (video_length > 1 and text_trainable) + should_detach = video_length > 1 # We detach the encoder hidden states for the first pass (video frames > 1) # Then we make a clone of the initial state to ensure we can train it in the loop. detached_encoder_state = encoder_hidden_states.clone().detach() trainable_encoder_state = encoder_hidden_states.clone() - for i in range(2): - - should_detach = noisy_latents.shape[2] > 1 and i == 0 - - if should_truncate_video and i == 1: - noisy_latents = noisy_latents[:,:,1,:,:].unsqueeze(2) - target = target[:,:,1,:,:].unsqueeze(2) - - encoder_hidden_states = ( - detached_encoder_state if should_detach else trainable_encoder_state - ) - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - losses.append(loss) - - # This was most likely single frame training or a single image. - if video_length == 1 and i == 0: break - - loss = losses[0] if len(losses) == 1 else losses[0] + losses[1] + encoder_hidden_states = ( + detached_encoder_state if should_detach else trainable_encoder_state + ) + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + return loss, latents for epoch in range(first_epoch, num_train_epochs):