feat: training code of hallo #101

Merged · 1 commit · Jun 26, 2024
63 changes: 63 additions & 0 deletions configs/train/stage1.yaml
@@ -0,0 +1,63 @@
data:
  train_bs: 8
  train_width: 512
  train_height: 512
  meta_paths:
    - "./data/HDTF_meta.json"
  # Margin of frame indexes between ref and tgt images
  sample_margin: 30

solver:
  gradient_accumulation_steps: 1
  mixed_precision: "no"
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: False
  max_train_steps: 30000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1.0e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: "constant"

  # optimizer
  use_8bit_adam: False
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8

val:
  validation_steps: 500

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "scaled_linear"
  steps_offset: 1
  clip_sample: false

base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
face_analysis_model_path: "./pretrained_models/face_analysis"

weight_dtype: "fp16" # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
face_locator_pretrained: False

seed: 42
resume_from_checkpoint: "latest"
checkpointing_steps: 500
exp_name: "stage1"
output_dir: "./exp_output"

ref_image_paths:
  - "examples/reference_images/1.jpg"

mask_image_paths:
  - "examples/masks/1.png"

119 changes: 119 additions & 0 deletions configs/train/stage2.yaml
@@ -0,0 +1,119 @@
data:
  train_bs: 4
  val_bs: 1
  train_width: 512
  train_height: 512
  fps: 25
  sample_rate: 16000
  n_motion_frames: 2
  n_sample_frames: 14
  audio_margin: 2
  train_meta_paths:
    - "./data/hdtf_split_stage2.json"

wav2vec_config:
  audio_type: "vocals" # audio vocals
  model_scale: "base" # base large
  features: "all" # last avg all
  model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
audio_separator:
  model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
face_expand_ratio: 1.2

solver:
  gradient_accumulation_steps: 1
  mixed_precision: "no"
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: True
  max_train_steps: 30000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: "constant"

  # optimizer
  use_8bit_adam: True
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8

val:
  validation_steps: 1000

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false

unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  use_audio_module: true
  motion_module_resolutions:
    - 1
    - 2
    - 4
    - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1
  audio_attention_dim: 768
  stack_enable_blocks_name:
    - "up"
    - "down"
    - "mid"
  stack_enable_blocks_depth: [0, 1, 2, 3]

trainable_para:
  - audio_modules
  - motion_modules

base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
face_analysis_model_path: "./pretrained_models/face_analysis"
mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"

weight_dtype: "fp16" # [fp16, fp32]
uncond_img_ratio: 0.05
uncond_audio_ratio: 0.05
uncond_ia_ratio: 0.05
start_ratio: 0.05
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
stage1_ckpt_dir: "./pretrained_models/hallo/stage1"

single_inference_times: 10
inference_steps: 40
cfg_scale: 3.5

seed: 42
resume_from_checkpoint: "latest"
checkpointing_steps: 500
exp_name: "stage2_test"
output_dir: "./exp_output"

ref_img_path:
  - "examples/reference_images/1.jpg"

audio_path:
  - "examples/driving_audios/1.wav"

Binary file added examples/masks/1.png
12 changes: 8 additions & 4 deletions hallo/datasets/talk_video.py
@@ -145,25 +145,29 @@ def __init__(
         )
         self.attn_transform_64 = transforms.Compose(
             [
-                transforms.Resize((64,64)),
+                transforms.Resize(
+                    (self.img_size[0] // 8, self.img_size[0] // 8)),
                 transforms.ToTensor(),
             ]
         )
         self.attn_transform_32 = transforms.Compose(
             [
-                transforms.Resize((32, 32)),
+                transforms.Resize(
+                    (self.img_size[0] // 16, self.img_size[0] // 16)),
                 transforms.ToTensor(),
             ]
         )
         self.attn_transform_16 = transforms.Compose(
             [
-                transforms.Resize((16, 16)),
+                transforms.Resize(
+                    (self.img_size[0] // 32, self.img_size[0] // 32)),
                 transforms.ToTensor(),
             ]
         )
         self.attn_transform_8 = transforms.Compose(
             [
-                transforms.Resize((8, 8)),
+                transforms.Resize(
+                    (self.img_size[0] // 64, self.img_size[0] // 64)),
                 transforms.ToTensor(),
             ]
         )
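
This change replaces hardcoded attention-mask sizes with sizes derived from `img_size`, matching the UNet feature maps at downsampling factors 8, 16, 32, and 64. At the 512x512 training size in the configs above the values are unchanged, so behavior differs only for other resolutions. A quick standalone check:

# Standalone sanity check (not part of the PR): at 512x512 the derived
# sizes equal the old hardcoded ones.
img_size = (512, 512)
for factor, old in [(8, 64), (16, 32), (32, 16), (64, 8)]:
    assert img_size[0] // factor == old
print([img_size[0] // f for f in (8, 16, 32, 64)])  # [64, 32, 16, 8]
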
1 change: 1 addition & 0 deletions hallo/models/motion_module.py
@@ -507,6 +507,7 @@ def extra_repr(self):
     def set_use_memory_efficient_attention_xformers(
         self,
         use_memory_efficient_attention_xformers: bool,
+        attention_op = None,
     ):
         """
         Sets the use of memory-efficient attention xformers for the VersatileAttention class.
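
The new `attention_op` parameter makes the override accept the second argument diffusers passes when it recursively enables xformers attention; without it the call would raise a TypeError. A simplified sketch of that call pattern (the exact code varies by diffusers version):

import torch

# Simplified sketch of diffusers' recursive toggle (varies by version):
# every module exposing the hook is called with two arguments, which is
# why the override must accept `attention_op`.
def fn_recursive_set_mem_eff(module: torch.nn.Module) -> None:
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(True, None)
    for child in module.children():
        fn_recursive_set_mem_eff(child)
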
148 changes: 148 additions & 0 deletions hallo/utils/util.py
@@ -67,6 +67,7 @@
 import subprocess
 import sys
 from pathlib import Path
+from typing import List

 import av
 import cv2
@@ -614,3 +615,150 @@ def get_face_region(image_path: str, detector):
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None, None


def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
    """
    Save the model's state_dict to a checkpoint file.

    If `total_limit` is positive, the oldest checkpoints are removed so that,
    after saving, at most `total_limit` checkpoints remain.

    Args:
        model (nn.Module): The model whose state_dict is to be saved.
        save_dir (str): The directory where the checkpoint will be saved.
        prefix (str): The prefix for the checkpoint file name.
        ckpt_num (int): The checkpoint number to be saved.
        total_limit (int, optional): The maximum number of checkpoints to keep.
            Defaults to -1, in which case no checkpoints are removed.

    Raises:
        FileNotFoundError: If the save directory does not exist.
        ValueError: If the checkpoint number is negative.
        OSError: If there is an error saving the checkpoint.
    """
    if not osp.exists(save_dir):
        raise FileNotFoundError(
            f"The save directory {save_dir} does not exist.")

    if ckpt_num < 0:
        raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")

    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    if total_limit > 0:
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.startswith(prefix)]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            print(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            print(
                f"Removing checkpoints: {', '.join(removing_checkpoints)}"
            )

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint_path = osp.join(
                    save_dir, removing_checkpoint)
                try:
                    os.remove(removing_checkpoint_path)
                except OSError as e:
                    print(
                        f"Error removing checkpoint {removing_checkpoint_path}: {e}")

    state_dict = model.state_dict()
    try:
        torch.save(state_dict, save_path)
        print(f"Checkpoint saved at {save_path}")
    except OSError as e:
        raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e
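
A hypothetical usage of the rotation behavior (the toy model and paths are illustrative only, not from the PR):

import os

import torch.nn as nn

net = nn.Linear(4, 4)
os.makedirs("./exp_output", exist_ok=True)
# Keep at most 3 "stage1-*.pth" files in ./exp_output.
for step in (500, 1000, 1500, 2000):
    save_checkpoint(net, "./exp_output", "stage1", step, total_limit=3)
# Afterwards only stage1-1000.pth, stage1-1500.pth, stage1-2000.pth remain.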


def init_output_dir(dir_list: List[str]):
    """
    Initialize the output directories.

    This function creates the directories specified in `dir_list`. If a directory already exists, it does nothing.

    Args:
        dir_list (List[str]): List of directory paths to create.
    """
    for path in dir_list:
        os.makedirs(path, exist_ok=True)


def load_checkpoint(cfg, save_dir, accelerator):
    """
    Load the most recent checkpoint from the specified directory.

    When `cfg.resume_from_checkpoint` is "latest", the newest checkpoint under
    `save_dir` is loaded; if a specific checkpoint directory is given instead,
    that one is used. If no checkpoint is found, training starts from scratch.

    Args:
        cfg: The configuration object containing training parameters.
        save_dir (str): The directory where checkpoints are saved.
        accelerator: The accelerator object for distributed training.

    Returns:
        int: The global step at which to resume training.
    """
    if cfg.resume_from_checkpoint != "latest":
        resume_dir = cfg.resume_from_checkpoint
    else:
        resume_dir = save_dir

    # Get the most recent checkpoint
    dirs = os.listdir(resume_dir)
    dirs = [d for d in dirs if d.startswith("checkpoint")]
    if len(dirs) > 0:
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1]
        accelerator.load_state(os.path.join(resume_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
        global_step = int(path.split("-")[1])
    else:
        accelerator.print(
            f"Could not find checkpoint under {resume_dir}, start training from scratch")
        global_step = 0

    return global_step
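
How this might be called at startup, assuming an `accelerate`-based training script like the configs imply; everything here besides `load_checkpoint` itself is an assumption:

from accelerate import Accelerator
from omegaconf import OmegaConf

accelerator = Accelerator()
cfg = OmegaConf.load("./configs/train/stage1.yaml")
# cfg.resume_from_checkpoint is "latest", so the newest "checkpoint-*"
# directory under the save dir is picked up; 0 means a fresh run.
# (In a real script the model/optimizer must be prepared with the
# accelerator before load_state can restore them.)
global_step = load_checkpoint(cfg, cfg.output_dir, accelerator)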


def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
    521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
    # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
        timesteps
    ].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
        device=timesteps.device
    )[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr
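
The configs above set `snr_gamma: 5.0`; in the Min-SNR recipe the SNR from this function becomes a per-timestep loss weight `min(snr, gamma) / snr`. A sketch of that weighting in the style of the diffusers training examples (epsilon-prediction assumed; the PR's training loop itself is not in this diff):

import torch
import torch.nn.functional as F

# Sketch of Min-SNR weighting with snr_gamma = 5.0 (pattern from the
# diffusers training examples, not code from this PR).
# Caveat: with enable_zero_snr the terminal timestep has snr == 0 and
# needs guarding before the division below.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = (
    torch.stack([snr, 5.0 * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()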