diff --git a/configs/train/stage2.yaml b/configs/train/stage2.yaml index bd32c0dd..baa30729 100644 --- a/configs/train/stage2.yaml +++ b/configs/train/stage2.yaml @@ -98,7 +98,7 @@ start_ratio: 0.05 noise_offset: 0.05 snr_gamma: 5.0 enable_zero_snr: True -stage1_ckpt_dir: "./pretrained_models/hallo/stage1" +stage1_ckpt_dir: "./exp_output/stage1/" single_inference_times: 10 inference_steps: 40 @@ -107,7 +107,7 @@ cfg_scale: 3.5 seed: 42 resume_from_checkpoint: "latest" checkpointing_steps: 500 -exp_name: "stage2_test" +exp_name: "stage2" output_dir: "./exp_output" ref_img_path: diff --git a/hallo/datasets/image_processor.py b/hallo/datasets/image_processor.py index 57715d18..16515226 100644 --- a/hallo/datasets/image_processor.py +++ b/hallo/datasets/image_processor.py @@ -240,8 +240,8 @@ class ImageProcessorForDataProcessing(): __exit__(_exc_type, _exc_val, _exc_tb): Exits a runtime context and handles any exceptions that occurred during the processing. """ - def __init__(self, face_analysis_model_path, landmark_model_path, gpu_status) -> None: - if gpu_status: + def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None: + if step == 2: self.face_analysis = FaceAnalysis( name="", root=face_analysis_model_path, diff --git a/hallo/utils/util.py b/hallo/utils/util.py index 2598710e..9fe29175 100644 --- a/hallo/utils/util.py +++ b/hallo/utils/util.py @@ -850,7 +850,8 @@ def compute_snr(noise_scheduler, timesteps): snr = (alpha / sigma) ** 2 return snr -def extract_audio_from_videos(video_path: Path, output_dir: Path) -> Path: + +def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path: """ Extract audio from a video file and save it as a WAV file. @@ -867,10 +868,6 @@ def extract_audio_from_videos(video_path: Path, output_dir: Path) -> Path: Raises: subprocess.CalledProcessError: If the ffmpeg command fails to execute. """ - audio_output_dir = output_dir / 'audios' - audio_output_dir.mkdir(parents=True, exist_ok=True) - audio_output_path = audio_output_dir / f'{video_path.stem}.wav' - ffmpeg_command = [ 'ffmpeg', '-y', '-i', str(video_path), @@ -906,14 +903,11 @@ def convert_video_to_images(video_path: Path, output_dir: Path) -> Path: Raises: subprocess.CalledProcessError: If the ffmpeg command fails to execute. """ - images_output_dir = output_dir / 'images' / video_path.stem - images_output_dir.mkdir(parents=True, exist_ok=True) - ffmpeg_command = [ 'ffmpeg', '-i', str(video_path), '-vf', 'fps=25', - str(images_output_dir / '%04d.png') + str(output_dir / '%04d.png') ] try: @@ -923,10 +917,22 @@ def convert_video_to_images(video_path: Path, output_dir: Path) -> Path: print(f"Error converting video to images: {e}") raise - return images_output_dir + return output_dir def get_union_mask(masks): + """ + Compute the union of a list of masks. + + This function takes a list of masks and computes their union by taking the maximum value at each pixel location. + Additionally, it finds the bounding box of the non-zero regions in the mask and sets the bounding box area to white. + + Args: + masks (list of np.ndarray): List of masks to be combined. + + Returns: + np.ndarray: The union of the input masks. + """ union_mask = None for mask in masks: if union_mask is None: @@ -949,3 +955,26 @@ def get_union_mask(masks): union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask) return union_mask + + +def move_final_checkpoint(save_dir, module_dir, prefix): + """ + Move the final checkpoint file to the save directory. + + This function identifies the latest checkpoint file based on the given prefix and moves it to the specified save directory. + + Args: + save_dir (str): The directory where the final checkpoint file should be saved. + module_dir (str): The directory containing the checkpoint files. + prefix (str): The prefix used to identify checkpoint files. + + Raises: + ValueError: If no checkpoint files are found with the specified prefix. + """ + checkpoints = os.listdir(module_dir) + checkpoints = [d for d in checkpoints if d.startswith(prefix)] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]) + ) + shutil.copy2(os.path.join( + module_dir, checkpoints[-1]), os.path.join(save_dir, prefix + '.pth')) diff --git a/scripts/data_preprocess.py b/scripts/data_preprocess.py index 7dce91c7..92efc2fc 100644 --- a/scripts/data_preprocess.py +++ b/scripts/data_preprocess.py @@ -60,7 +60,7 @@ def process_single_video(video_path: Path, output_dir: Path, image_processor: ImageProcessorForDataProcessing, audio_processor: AudioProcessor, - gpu_status: bool) -> None: + step: int) -> None: """ Process a single video file. @@ -76,15 +76,22 @@ def process_single_video(video_path: Path, logging.info(f"Processing video: {video_path}") try: - if not gpu_status: - images_dir = convert_video_to_images(video_path, output_dir) - logging.info(f"Images saved to: {images_dir}") - - audio_path = extract_audio_from_videos(video_path, output_dir) - logging.info(f"Audio extracted to: {audio_path}") + if step == 1: + images_output_dir = output_dir / 'images' / video_path.stem + images_output_dir.mkdir(parents=True, exist_ok=True) + images_output_dir = convert_video_to_images( + video_path, images_output_dir) + logging.info(f"Images saved to: {images_output_dir}") + + audio_output_dir = output_dir / 'audios' + audio_output_dir.mkdir(parents=True, exist_ok=True) + audio_output_path = audio_output_dir / f'{video_path.stem}.wav' + audio_output_path = extract_audio_from_videos( + video_path, audio_output_path) + logging.info(f"Audio extracted to: {audio_output_path}") face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess( - images_dir) + images_output_dir) cv2.imwrite( str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask) cv2.imwrite(str(dirs["sep_pose_mask"] / @@ -106,7 +113,7 @@ def process_single_video(video_path: Path, logging.error(f"Failed to process video {video_path}: {e}") -def process_all_videos(input_video_list: List[Path], output_dir: Path, gpu_status: bool) -> None: +def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None: """ Process all videos in the input list. @@ -128,14 +135,14 @@ def process_all_videos(input_video_list: List[Path], output_dir: Path, gpu_statu os.path.dirname(audio_separator_model_file), os.path.basename(audio_separator_model_file), os.path.join(output_dir, "vocals"), - ) if gpu_status else None + ) if step==2 else None image_processor = ImageProcessorForDataProcessing( - face_analysis_model_path, landmark_model_path, gpu_status) + face_analysis_model_path, landmark_model_path, step) for video_path in tqdm(input_video_list, desc="Processing videos"): process_single_video(video_path, output_dir, - image_processor, audio_processor, gpu_status) + image_processor, audio_processor, step) def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]: @@ -163,8 +170,8 @@ def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path] required=True, help="Directory containing videos") parser.add_argument("-o", "--output_dir", type=Path, help="Directory to save results, default is parent dir of input dir") - parser.add_argument("-g", "--gpu_status", action='store_true', - help="Run tasks requiring GPU or tasks not requiring GPU") + parser.add_argument("-s", "--step", type=int, default=1, + help="Specify data processing step 1 or 2, you should run 1 and 2 sequently") parser.add_argument("-p", "--parallelism", default=1, type=int, help="Level of parallelism") parser.add_argument("-r", "--rank", default=0, type=int, @@ -181,4 +188,4 @@ def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path] if not video_path_list: logging.warning("No videos to process.") else: - process_all_videos(video_path_list, args.output_dir, args.gpu_status) + process_all_videos(video_path_list, args.output_dir, args.step) diff --git a/scripts/train_stage1.py b/scripts/train_stage1.py index 305b9395..9c6265fa 100644 --- a/scripts/train_stage1.py +++ b/scripts/train_stage1.py @@ -53,8 +53,8 @@ from hallo.models.unet_3d import UNet3DConditionModel from hallo.utils.util import (compute_snr, delete_additional_ckpt, import_filename, init_output_dir, - load_checkpoint, save_checkpoint, - seed_everything) + load_checkpoint, move_final_checkpoint, + save_checkpoint, seed_everything) warnings.filterwarnings("ignore") @@ -747,6 +747,12 @@ def train_stage1_process(cfg: argparse.Namespace) -> None: progress_bar.set_postfix(**logs) if global_step >= cfg.solver.max_train_steps: + # process final module weight for stage2 + if accelerator.is_main_process: + move_final_checkpoint(save_dir, module_dir, "reference_unet") + move_final_checkpoint(save_dir, module_dir, "imageproj") + move_final_checkpoint(save_dir, module_dir, "denoising_unet") + move_final_checkpoint(save_dir, module_dir, "face_locator") break accelerator.wait_for_everyone()