Skip to content

Commit

Permalink
fix: train config and data processing param adjustment
Browse files Browse the repository at this point in the history
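* add post-processing of model weights after stage 1 (the final checkpoint of each module is copied to the save directory for use by stage 2)
* make the data-processing parameter easier to understand (the `--gpu_status` flag is replaced by an explicit `--step` of 1 or 2)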
  • Loading branch information
xumingw committed Jun 27, 2024
1 parent fc2d31a commit ce380f3
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 31 deletions.
4 changes: 2 additions & 2 deletions configs/train/stage2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ start_ratio: 0.05
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
stage1_ckpt_dir: "./pretrained_models/hallo/stage1"
stage1_ckpt_dir: "./exp_output/stage1/"

single_inference_times: 10
inference_steps: 40
Expand All @@ -107,7 +107,7 @@ cfg_scale: 3.5
seed: 42
resume_from_checkpoint: "latest"
checkpointing_steps: 500
exp_name: "stage2_test"
exp_name: "stage2"
output_dir: "./exp_output"

ref_img_path:
Expand Down
4 changes: 2 additions & 2 deletions hallo/datasets/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ class ImageProcessorForDataProcessing():
__exit__(_exc_type, _exc_val, _exc_tb):
Exits a runtime context and handles any exceptions that occurred during the processing.
"""
def __init__(self, face_analysis_model_path, landmark_model_path, gpu_status) -> None:
if gpu_status:
def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None:
if step == 2:
self.face_analysis = FaceAnalysis(
name="",
root=face_analysis_model_path,
Expand Down
49 changes: 39 additions & 10 deletions hallo/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,8 @@ def compute_snr(noise_scheduler, timesteps):
snr = (alpha / sigma) ** 2
return snr

def extract_audio_from_videos(video_path: Path, output_dir: Path) -> Path:

def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path:
"""
Extract audio from a video file and save it as a WAV file.
Expand All @@ -867,10 +868,6 @@ def extract_audio_from_videos(video_path: Path, output_dir: Path) -> Path:
Raises:
subprocess.CalledProcessError: If the ffmpeg command fails to execute.
"""
audio_output_dir = output_dir / 'audios'
audio_output_dir.mkdir(parents=True, exist_ok=True)
audio_output_path = audio_output_dir / f'{video_path.stem}.wav'

ffmpeg_command = [
'ffmpeg', '-y',
'-i', str(video_path),
Expand Down Expand Up @@ -906,14 +903,11 @@ def convert_video_to_images(video_path: Path, output_dir: Path) -> Path:
Raises:
subprocess.CalledProcessError: If the ffmpeg command fails to execute.
"""
images_output_dir = output_dir / 'images' / video_path.stem
images_output_dir.mkdir(parents=True, exist_ok=True)

ffmpeg_command = [
'ffmpeg',
'-i', str(video_path),
'-vf', 'fps=25',
str(images_output_dir / '%04d.png')
str(output_dir / '%04d.png')
]

try:
Expand All @@ -923,10 +917,22 @@ def convert_video_to_images(video_path: Path, output_dir: Path) -> Path:
print(f"Error converting video to images: {e}")
raise

return images_output_dir
return output_dir


def get_union_mask(masks):
"""
Compute the union of a list of masks.
This function takes a list of masks and computes their union by taking the maximum value at each pixel location.
Additionally, it finds the bounding box of the non-zero regions in the mask and sets the bounding box area to white.
Args:
masks (list of np.ndarray): List of masks to be combined.
Returns:
np.ndarray: The union of the input masks.
"""
union_mask = None
for mask in masks:
if union_mask is None:
Expand All @@ -949,3 +955,26 @@ def get_union_mask(masks):
union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask)

return union_mask


def move_final_checkpoint(save_dir, module_dir, prefix):
    """
    Copy the latest checkpoint file for a module into the save directory.

    Despite the name, the checkpoint is copied (via ``shutil.copy2``), not
    moved, so the original file stays in ``module_dir``. Candidate files are
    those in ``module_dir`` whose name starts with ``prefix`` and follows the
    ``<prefix>-<step>.<ext>`` pattern; the one with the highest ``<step>`` is
    written to ``<save_dir>/<prefix>.pth``.

    Args:
        save_dir (str): The directory where the final checkpoint file should be saved.
        module_dir (str): The directory containing the checkpoint files.
        prefix (str): The prefix used to identify checkpoint files.

    Raises:
        ValueError: If no checkpoint files are found with the specified prefix.
    """
    numbered = []
    for file_name in os.listdir(module_dir):
        if not file_name.startswith(prefix):
            continue
        try:
            step = int(file_name.split("-")[1].split(".")[0])
        except (IndexError, ValueError):
            # Not a "<prefix>-<step>.<ext>" file (for example an already
            # exported "<prefix>.pth"); skip it instead of crashing the sort.
            continue
        numbered.append((step, file_name))

    if not numbered:
        # Raise the documented error rather than an IndexError from [-1].
        raise ValueError(
            f"No checkpoint files with prefix '{prefix}' found in {module_dir}")

    # Stable sort on the step number only, so ties resolve exactly as before.
    numbered.sort(key=lambda item: item[0])
    latest_name = numbered[-1][1]
    shutil.copy2(os.path.join(module_dir, latest_name),
                 os.path.join(save_dir, prefix + '.pth'))
37 changes: 22 additions & 15 deletions scripts/data_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def process_single_video(video_path: Path,
output_dir: Path,
image_processor: ImageProcessorForDataProcessing,
audio_processor: AudioProcessor,
gpu_status: bool) -> None:
step: int) -> None:
"""
Process a single video file.
Expand All @@ -76,15 +76,22 @@ def process_single_video(video_path: Path,
logging.info(f"Processing video: {video_path}")

try:
if not gpu_status:
images_dir = convert_video_to_images(video_path, output_dir)
logging.info(f"Images saved to: {images_dir}")

audio_path = extract_audio_from_videos(video_path, output_dir)
logging.info(f"Audio extracted to: {audio_path}")
if step == 1:
images_output_dir = output_dir / 'images' / video_path.stem
images_output_dir.mkdir(parents=True, exist_ok=True)
images_output_dir = convert_video_to_images(
video_path, images_output_dir)
logging.info(f"Images saved to: {images_output_dir}")

audio_output_dir = output_dir / 'audios'
audio_output_dir.mkdir(parents=True, exist_ok=True)
audio_output_path = audio_output_dir / f'{video_path.stem}.wav'
audio_output_path = extract_audio_from_videos(
video_path, audio_output_path)
logging.info(f"Audio extracted to: {audio_output_path}")

face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess(
images_dir)
images_output_dir)
cv2.imwrite(
str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask)
cv2.imwrite(str(dirs["sep_pose_mask"] /
Expand All @@ -106,7 +113,7 @@ def process_single_video(video_path: Path,
logging.error(f"Failed to process video {video_path}: {e}")


def process_all_videos(input_video_list: List[Path], output_dir: Path, gpu_status: bool) -> None:
def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None:
"""
Process all videos in the input list.
Expand All @@ -128,14 +135,14 @@ def process_all_videos(input_video_list: List[Path], output_dir: Path, gpu_statu
os.path.dirname(audio_separator_model_file),
os.path.basename(audio_separator_model_file),
os.path.join(output_dir, "vocals"),
) if gpu_status else None
) if step==2 else None

image_processor = ImageProcessorForDataProcessing(
face_analysis_model_path, landmark_model_path, gpu_status)
face_analysis_model_path, landmark_model_path, step)

for video_path in tqdm(input_video_list, desc="Processing videos"):
process_single_video(video_path, output_dir,
image_processor, audio_processor, gpu_status)
image_processor, audio_processor, step)


def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]:
Expand Down Expand Up @@ -163,8 +170,8 @@ def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]
required=True, help="Directory containing videos")
parser.add_argument("-o", "--output_dir", type=Path,
help="Directory to save results, default is parent dir of input dir")
parser.add_argument("-g", "--gpu_status", action='store_true',
help="Run tasks requiring GPU or tasks not requiring GPU")
parser.add_argument("-s", "--step", type=int, default=1,
help="Specify data processing step 1 or 2, you should run 1 and 2 sequently")
parser.add_argument("-p", "--parallelism", default=1,
type=int, help="Level of parallelism")
parser.add_argument("-r", "--rank", default=0, type=int,
Expand All @@ -181,4 +188,4 @@ def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]
if not video_path_list:
logging.warning("No videos to process.")
else:
process_all_videos(video_path_list, args.output_dir, args.gpu_status)
process_all_videos(video_path_list, args.output_dir, args.step)
10 changes: 8 additions & 2 deletions scripts/train_stage1.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.util import (compute_snr, delete_additional_ckpt,
import_filename, init_output_dir,
load_checkpoint, save_checkpoint,
seed_everything)
load_checkpoint, move_final_checkpoint,
save_checkpoint, seed_everything)

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -747,6 +747,12 @@ def train_stage1_process(cfg: argparse.Namespace) -> None:
progress_bar.set_postfix(**logs)

if global_step >= cfg.solver.max_train_steps:
# process final module weight for stage2
if accelerator.is_main_process:
move_final_checkpoint(save_dir, module_dir, "reference_unet")
move_final_checkpoint(save_dir, module_dir, "imageproj")
move_final_checkpoint(save_dir, module_dir, "denoising_unet")
move_final_checkpoint(save_dir, module_dir, "face_locator")
break

accelerator.wait_for_everyone()
Expand Down

0 comments on commit ce380f3

Please sign in to comment.