From fc2d31abbb2f6ede6c9ebd9d8a165d8f71541e68 Mon Sep 17 00:00:00 2001
From: xumingw
Date: Thu, 27 Jun 2024 00:36:21 +0800
Subject: [PATCH] feat: data preprocessing code of hallo

* add data preprocessing
* add utils functions of data preprocessing
* add image processor and audio processor of data preprocessing
---
 hallo/datasets/audio_processor.py   |   5 +-
 hallo/datasets/image_processor.py   | 139 ++++++++++++++++++-
 hallo/utils/util.py                 | 207 ++++++++++++++++++++++++++--
 scripts/data_preprocess.py          | 184 +++++++++++++++++++++++++
 scripts/extract_meta_info_stage1.py | 100 ++++++++++++++
 scripts/extract_meta_info_stage2.py | 192 ++++++++++++++++++++++++++
 6 files changed, 814 insertions(+), 13 deletions(-)
 create mode 100644 scripts/data_preprocess.py
 create mode 100644 scripts/extract_meta_info_stage1.py
 create mode 100644 scripts/extract_meta_info_stage2.py

diff --git a/hallo/datasets/audio_processor.py b/hallo/datasets/audio_processor.py
index 50738970..f340a52f 100644
--- a/hallo/datasets/audio_processor.py
+++ b/hallo/datasets/audio_processor.py
@@ -73,7 +73,7 @@ def __init__(
         self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
 
-    def preprocess(self, wav_file: str, clip_length: int):
+    def preprocess(self, wav_file: str, clip_length: int = -1):
         """
         Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
         The separated vocal track is then converted into wav2vec2 features for further processing or analysis.
@@ -109,7 +109,8 @@ def preprocess(self, wav_file: str, clip_length: int):
             audio_length = seq_len
 
         audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
-        if seq_len % clip_length != 0:
+
+        if clip_length > 0 and seq_len % clip_length != 0:
             audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
             seq_len += clip_length - seq_len % clip_length
         audio_feature = audio_feature.unsqueeze(0)
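Reviewer note: a minimal, self-contained sketch of the padding rule above, with made-up numbers. With the new `clip_length = -1` default the modulo branch is skipped entirely; with a positive `clip_length` the feature sequence is right-padded with zeros until the frame count is a multiple of `clip_length`, at 16000 / 25 = 640 audio samples per video frame (the sample rate and fps used by `AudioProcessor` in this PR):

```python
import torch

sample_rate, fps = 16000, 25             # values passed to AudioProcessor below
samples_per_frame = sample_rate // fps   # 640
seq_len, clip_length = 70, 16            # hypothetical: 70 frames, clips of 16

audio_feature = torch.randn(seq_len * samples_per_frame)
if clip_length > 0 and seq_len % clip_length != 0:
    pad = (clip_length - seq_len % clip_length) * samples_per_frame
    audio_feature = torch.nn.functional.pad(audio_feature, (0, pad), 'constant', 0.0)
    seq_len += clip_length - seq_len % clip_length

# 70 frames are padded up to 80 (= 5 clips of 16)
assert seq_len == 80 and audio_feature.numel() == 80 * samples_per_frame
```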
diff --git a/hallo/datasets/image_processor.py b/hallo/datasets/image_processor.py
index 2c093907..57715d18 100644
--- a/hallo/datasets/image_processor.py
+++ b/hallo/datasets/image_processor.py
@@ -1,3 +1,4 @@
+# pylint: disable=W0718
 """
 This module is responsible for processing images, particularly for face-related tasks.
 It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
@@ -8,13 +9,15 @@
 from typing import List
 
 import cv2
+import mediapipe as mp
 import numpy as np
 import torch
 from insightface.app import FaceAnalysis
 from PIL import Image
 from torchvision import transforms
 
-from ..utils.util import get_mask
+from ..utils.util import (blur_mask, get_landmark_overframes, get_mask,
+                          get_union_face_mask, get_union_lip_mask)
 
 MEAN = 0.5
 STD = 0.5
@@ -207,3 +210,137 @@ def __enter__(self):
 
     def __exit__(self, _exc_type, _exc_val, _exc_tb):
         self.close()
+
+
+class ImageProcessorForDataProcessing():
+    """
+    ImageProcessorForDataProcessing is a class responsible for processing images, particularly for face-related tasks.
+    It takes in a directory of frames and performs operations such as face detection,
+    face embedding extraction, and rendering of face and lip masks. The processed outputs
+    are then used for further analysis or training.
+
+    Attributes:
+        face_analysis: InsightFace FaceAnalysis instance (GPU mode) or None.
+        landmarker: mediapipe FaceLandmarker instance (CPU mode) or None.
+
+    Methods:
+        preprocess(source_image_path):
+            Preprocesses the input frames by performing face detection,
+            face embedding extraction, and rendering of face and lip masks.
+
+        close():
+            Closes the processor and releases any resources being used.
+
+        _augmentation(images, transform, state=None):
+            Applies image augmentation to the input images using the given transform and state.
+
+        __enter__():
+            Enters a runtime context and returns the processor object.
+
+        __exit__(_exc_type, _exc_val, _exc_tb):
+            Exits a runtime context and handles any exceptions that occurred during the processing.
+    """
+    def __init__(self, face_analysis_model_path, landmark_model_path, gpu_status) -> None:
+        if gpu_status:
+            self.face_analysis = FaceAnalysis(
+                name="",
+                root=face_analysis_model_path,
+                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+            )
+            self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
+            self.landmarker = None
+        else:
+            BaseOptions = mp.tasks.BaseOptions
+            FaceLandmarker = mp.tasks.vision.FaceLandmarker
+            FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
+            VisionRunningMode = mp.tasks.vision.RunningMode
+            # Create a face landmarker instance in image mode:
+            options = FaceLandmarkerOptions(
+                base_options=BaseOptions(model_asset_path=landmark_model_path),
+                running_mode=VisionRunningMode.IMAGE,
+            )
+            self.landmarker = FaceLandmarker.create_from_options(options)
+            self.face_analysis = None
+
+    def preprocess(self, source_image_path: str):
+        """
+        Apply preprocessing to the source frames to prepare them for face analysis.
+
+        Parameters:
+            source_image_path (str): The path to the directory of source frames.
+
+        Returns:
+            tuple: (face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask);
+            entries that are not computed in the current mode are None.
+        """
+        # 1. get face embedding
+        face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None
+        if self.face_analysis:
+            for frame in sorted(os.listdir(source_image_path)):
+                try:
+                    source_image = Image.open(
+                        os.path.join(source_image_path, frame))
+                    ref_image_pil = source_image.convert("RGB")
+                    # 1.1 detect face
+                    faces = self.face_analysis.get(cv2.cvtColor(
+                        np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
+                    # use max size face
+                    face = sorted(faces, key=lambda x: (
+                        x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
+                    # 1.2 face embedding
+                    face_emb = face["embedding"]
+                    if face_emb is not None:
+                        break
+                except Exception as _:
+                    continue
+
+        if self.landmarker:
+            # 2. get landmarks
+            landmarks, height, width = get_landmark_overframes(
+                self.landmarker, source_image_path)
+            assert len(landmarks) == len(os.listdir(source_image_path))
+
+            # 3. render face and lip masks
+            face_mask = get_union_face_mask(landmarks, height, width)
+            lip_mask = get_union_lip_mask(landmarks, height, width)
+
+            # 4. gaussian blur
+            blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51))
+            blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31))
+
+            # 5. separate masks
+            sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask)
+            sep_pose_mask = 255.0 - blur_face_mask
+            sep_lip_mask = blur_lip_mask
+
+        return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask
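Reviewer note: the class is meant to be run twice over the same frame directory, once per mode; a hedged usage sketch (the frame path is hypothetical, the model locations follow the defaults used in scripts/data_preprocess.py later in this patch):

```python
from hallo.datasets.image_processor import ImageProcessorForDataProcessing

FACE_ANALYSIS = "pretrained_models/face_analysis"
LANDMARKER = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
FRAMES = "data/output/images/clip_0001"  # directory of extracted frames

# CPU pass (gpu_status=False): mediapipe landmarks -> face/lip/pose masks.
with ImageProcessorForDataProcessing(FACE_ANALYSIS, LANDMARKER, gpu_status=False) as proc:
    face_mask, _, sep_pose, sep_face, sep_lip = proc.preprocess(FRAMES)

# GPU pass (gpu_status=True): InsightFace -> face embedding (masks are None here).
with ImageProcessorForDataProcessing(FACE_ANALYSIS, LANDMARKER, gpu_status=True) as proc:
    _, face_emb, _, _, _ = proc.preprocess(FRAMES)
```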
+
+    def close(self):
+        """
+        Closes the processor and releases any resources held by the FaceAnalysis
+        instance (a no-op in landmark-only mode, where no FaceAnalysis exists).
+
+        Args:
+            self: The ImageProcessorForDataProcessing instance.
+
+        Returns:
+            None.
+        """
+        for _, model in (self.face_analysis.models.items() if self.face_analysis else []):
+            if hasattr(model, "Dispose"):
+                model.Dispose()
+
+    def _augmentation(self, images, transform, state=None):
+        if state is not None:
+            torch.set_rng_state(state)
+        if isinstance(images, List):
+            transformed_images = [transform(img) for img in images]
+            ret_tensor = torch.stack(transformed_images, dim=0)  # (f, c, h, w)
+        else:
+            ret_tensor = transform(images)  # (c, h, w)
+        return ret_tensor
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        self.close()
diff --git a/hallo/utils/util.py b/hallo/utils/util.py
index 9dc61fb8..2598710e 100644
--- a/hallo/utils/util.py
+++ b/hallo/utils/util.py
@@ -1,6 +1,7 @@
 # pylint: disable=C0116
 # pylint: disable=W0718
 # pylint: disable=R1732
+# pylint: disable=R0801
 """
 utils.py
@@ -378,7 +379,32 @@ def get_landmark(file):
     return np.array(face_landmark), height, width
 
-def get_lip_mask(landmarks, height, width, out_path):
+def get_landmark_overframes(landmark_model, frames_path):
+    """
+    This function iterates over the frames in a directory and returns the facial landmarks detected in each frame.
+
+    Args:
+        landmark_model: mediapipe landmark model instance
+        frames_path (str): The path to the video frames.
+
+    Returns:
+        tuple: (face_landmarks, height, width), where face_landmarks is a list holding the landmark coordinates detected in each frame.
+    """
+
+    face_landmarks = []
+
+    for file in sorted(os.listdir(frames_path)):
+        image = mp.Image.create_from_file(os.path.join(frames_path, file))
+        height, width = image.height, image.width
+        landmarker_result = landmark_model.detect(image)
+        frame_landmark = compute_face_landmarks(
+            landmarker_result, height, width)
+        face_landmarks.append(frame_landmark)
+
+    return face_landmarks, height, width
+
+
+def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2.0):
     """
     Extracts the lip region from the given landmarks and saves it as an image.
 
@@ -387,19 +413,42 @@
     Parameters:
         landmarks (numpy.ndarray): Array of facial landmarks.
         height (int): Height of the output lip mask image.
         width (int): Width of the output lip mask image.
         out_path (pathlib.Path): Path to save the lip mask image.
+        expand_ratio (float): Expand ratio of mask.
     """
     lip_landmarks = np.take(landmarks, lip_ids, 0)
     min_xy_lip = np.round(np.min(lip_landmarks, 0))
     max_xy_lip = np.round(np.max(lip_landmarks, 0))
     min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1] = expand_region(
-        [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, 2.0)
+        [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, expand_ratio)
     lip_mask = np.zeros((height, width), dtype=np.uint8)
     lip_mask[round(min_xy_lip[1]):round(max_xy_lip[1]),
              round(min_xy_lip[0]):round(max_xy_lip[0])] = 255
-    cv2.imwrite(str(out_path), lip_mask)
+    if out_path:
+        cv2.imwrite(str(out_path), lip_mask)
+        return None
+
+    return lip_mask
+
+
+def get_union_lip_mask(landmarks, height, width, expand_ratio=1):
+    """
+    Computes the union of the per-frame lip masks derived from the given landmarks.
+
+    Parameters:
+        landmarks (list): Per-frame arrays of facial landmarks.
+        height (int): Height of the output lip mask image.
+        width (int): Width of the output lip mask image.
+        expand_ratio (float): Expand ratio of mask.
+    """
+    lip_masks = []
+    for landmark in landmarks:
+        lip_masks.append(get_lip_mask(landmarks=landmark, height=height,
+                         width=width, expand_ratio=expand_ratio))
+    union_mask = get_union_mask(lip_masks)
+    return union_mask
 
-def get_face_mask(landmarks, height, width, out_path, expand_ratio):
+def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=1.2):
     """
     Generate a face mask based on the given landmarks.
 
@@ -408,7 +457,7 @@
     Args:
         landmarks (numpy.ndarray): The landmarks of the face.
         height (int): The height of the output face mask image.
         width (int): The width of the output face mask image.
         out_path (pathlib.Path): The path to save the face mask image.
-
+        expand_ratio (float): Expand ratio of mask.
     Returns:
         None. The face mask image is saved at the specified path.
     """
@@ -420,8 +469,30 @@
     face_mask = np.zeros((height, width), dtype=np.uint8)
     face_mask[round(min_xy_face[1]):round(max_xy_face[1]),
               round(min_xy_face[0]):round(max_xy_face[0])] = 255
-    cv2.imwrite(str(out_path), face_mask)
+    if out_path:
+        cv2.imwrite(str(out_path), face_mask)
+        return None
+    return face_mask
+
+
+def get_union_face_mask(landmarks, height, width, expand_ratio=1):
+    """
+    Generate the union of per-frame face masks from the given landmarks.
+
+    Args:
+        landmarks (list): Per-frame facial landmarks.
+        height (int): The height of the output face mask image.
+        width (int): The width of the output face mask image.
+        expand_ratio (float): Expand ratio of mask.
+    Returns:
+        numpy.ndarray: The union face mask.
+    """
+    face_masks = []
+    for landmark in landmarks:
+        face_masks.append(get_face_mask(landmarks=landmark, height=height, width=width, expand_ratio=expand_ratio))
+    union_mask = get_union_mask(face_masks)
+    return union_mask
 
 def get_mask(file, cache_dir, face_expand_raio):
     """
@@ -507,6 +578,25 @@ def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=
     mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
 
     # Check if the image is loaded successfully
+    if mask is not None:
+        normalized_mask = blur_mask(mask, resize_dim=resize_dim, kernel_size=kernel_size)
+        # Save the normalized mask image
+        cv2.imwrite(output_file_path, normalized_mask)
+        return f"Processed, normalized, and saved: {output_file_path}"
+    return f"Failed to load image: {file_path}"
+
+
+def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)):
+    """
+    Resize, blur, and normalize a mask image, returning the result.
+
+    Parameters:
+        mask (numpy.ndarray): The input mask image.
+        resize_dim (tuple): Dimensions to resize the images to.
+        kernel_size (tuple): Size of the kernel to use for Gaussian blur.
+    """
+    # Check if the image is loaded successfully
+    normalized_mask = None
     if mask is not None:
         # Resize the mask image
         resized_mask = cv2.resize(mask, resize_dim)
         # Apply Gaussian blur to the resized mask image
         blurred_mask = cv2.GaussianBlur(resized_mask, kernel_size, 0)
         # Normalize the blurred image
         normalized_mask = cv2.normalize(
             blurred_mask, None, 0, 255, cv2.NORM_MINMAX)
         # Save the normalized mask image
-        cv2.imwrite(output_file_path, normalized_mask)
-        return f"Processed, normalized, and saved: {output_file_path}"
-    return f"Failed to load image: {file_path}"
-
+    return normalized_mask
 
 def get_background_mask(file_path, output_file_path):
     """
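Reviewer note: the signature change above makes saving optional; with `out_path=None` the mask comes back as an array, which is what the new union helpers rely on. A toy sketch, assuming the 478-point mediapipe landmark layout used elsewhere in util.py (the landmark values are random and purely illustrative):

```python
import numpy as np

from hallo.utils.util import get_face_mask

# Hypothetical landmarks: mediapipe FaceLandmarker yields 478 (x, y) points per face.
rng = np.random.default_rng(0)
landmarks = rng.uniform(32, 96, size=(478, 2))

# New behavior: out_path=None returns the mask instead of writing it.
mask = get_face_mask(landmarks, height=128, width=128, expand_ratio=1.2)
assert mask.shape == (128, 128) and mask.max() == 255

# Old behavior is preserved: passing out_path writes a PNG and returns None.
get_face_mask(landmarks, 128, 128, out_path="face_mask.png", expand_ratio=1.2)
```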
@@ -762,3 +849,103 @@ def compute_snr(noise_scheduler, timesteps):
     # Compute SNR.
     snr = (alpha / sigma) ** 2
     return snr
+
+def extract_audio_from_videos(video_path: Path, output_dir: Path) -> Path:
+    """
+    Extract audio from a video file and save it as a WAV file.
+
+    This function uses ffmpeg to extract the audio stream from a given video file and saves it as a WAV file
+    in the specified output directory.
+
+    Args:
+        video_path (Path): The path to the input video file.
+        output_dir (Path): The directory where the extracted audio file will be saved.
+
+    Returns:
+        Path: The path to the extracted audio file.
+
+    Raises:
+        subprocess.CalledProcessError: If the ffmpeg command fails to execute.
+    """
+    audio_output_dir = output_dir / 'audios'
+    audio_output_dir.mkdir(parents=True, exist_ok=True)
+    audio_output_path = audio_output_dir / f'{video_path.stem}.wav'
+
+    ffmpeg_command = [
+        'ffmpeg', '-y',
+        '-i', str(video_path),
+        '-vn', '-acodec',
+        'pcm_s16le', '-ar', '16000', '-ac', '2',
+        str(audio_output_path)
+    ]
+
+    try:
+        print(f"Running command: {' '.join(ffmpeg_command)}")
+        subprocess.run(ffmpeg_command, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Error extracting audio from video: {e}")
+        raise
+
+    return audio_output_path
+
+
+def convert_video_to_images(video_path: Path, output_dir: Path) -> Path:
+    """
+    Convert a video file into a sequence of images.
+
+    This function uses ffmpeg to convert each frame of the given video file into an image. The images are saved
+    in a directory named after the video file stem under the specified output directory.
+
+    Args:
+        video_path (Path): The path to the input video file.
+        output_dir (Path): The directory where the extracted images will be saved.
+
+    Returns:
+        Path: The path to the directory containing the extracted images.
+
+    Raises:
+        subprocess.CalledProcessError: If the ffmpeg command fails to execute.
+    """
+    images_output_dir = output_dir / 'images' / video_path.stem
+    images_output_dir.mkdir(parents=True, exist_ok=True)
+
+    ffmpeg_command = [
+        'ffmpeg',
+        '-i', str(video_path),
+        '-vf', 'fps=25',
+        str(images_output_dir / '%04d.png')
+    ]
+
+    try:
+        print(f"Running command: {' '.join(ffmpeg_command)}")
+        subprocess.run(ffmpeg_command, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Error converting video to images: {e}")
+        raise
+
+    return images_output_dir
+
+
+def get_union_mask(masks):
+    union_mask = None
+    for mask in masks:
+        if union_mask is None:
+            union_mask = mask
+        else:
+            union_mask = np.maximum(union_mask, mask)
+
+    if union_mask is not None:
+        # Find the bounding box of the non-zero regions in the mask
+        rows = np.any(union_mask, axis=1)
+        cols = np.any(union_mask, axis=0)
+        try:
+            ymin, ymax = np.where(rows)[0][[0, -1]]
+            xmin, xmax = np.where(cols)[0][[0, -1]]
+        except Exception as e:
+            print(str(e))
+            return 0.0
+
+        # Set bounding box area to white
+        union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask)
+
+    return union_mask
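Reviewer note: a usage sketch for the two new ffmpeg helpers (paths are hypothetical; ffmpeg must be on PATH, and both functions raise `subprocess.CalledProcessError` if the underlying command fails):

```python
from pathlib import Path

from hallo.utils.util import convert_video_to_images, extract_audio_from_videos

video = Path("data/videos/clip_0001.mp4")
output_dir = Path("data/output")

frames_dir = convert_video_to_images(video, output_dir)
# -> data/output/images/clip_0001/0001.png, 0002.png, ... (sampled at 25 fps)

wav_path = extract_audio_from_videos(video, output_dir)
# -> data/output/audios/clip_0001.wav (16 kHz PCM)
```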
diff --git a/scripts/data_preprocess.py b/scripts/data_preprocess.py
new file mode 100644
index 00000000..7dce91c7
--- /dev/null
+++ b/scripts/data_preprocess.py
@@ -0,0 +1,184 @@
+# pylint: disable=W1203,W0718
+"""
+This module is used to process videos to prepare data for training. It utilizes various libraries and models
+to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction.
+The script takes in command-line arguments to specify the input and output directories, GPU status, level of parallelism,
+and rank for distributed processing.
+
+Usage:
+    python -m scripts.data_preprocess --input_dir /path/to/video_dir --output_dir /path/to/output_dir --gpu_status --parallelism 4 --rank 0
+
+Example:
+    python -m scripts.data_preprocess -i data/videos -o data/output -g -p 4 -r 0
+"""
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import List
+
+import cv2
+import torch
+from tqdm import tqdm
+
+from hallo.datasets.audio_processor import AudioProcessor
+from hallo.datasets.image_processor import ImageProcessorForDataProcessing
+from hallo.utils.util import convert_video_to_images, extract_audio_from_videos
+
+# Configure logging
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s - %(levelname)s - %(message)s')
+
+
+def setup_directories(video_path: Path) -> dict:
+    """
+    Setup directories for storing processed files.
+
+    Args:
+        video_path (Path): Path to the video file.
+
+    Returns:
+        dict: A dictionary containing paths for various directories.
+    """
+    base_dir = video_path.parent.parent
+    dirs = {
+        "face_mask": base_dir / "face_mask",
+        "sep_pose_mask": base_dir / "sep_pose_mask",
+        "sep_face_mask": base_dir / "sep_face_mask",
+        "sep_lip_mask": base_dir / "sep_lip_mask",
+        "face_emb": base_dir / "face_emb",
+        "audio_emb": base_dir / "audio_emb"
+    }
+
+    for path in dirs.values():
+        path.mkdir(parents=True, exist_ok=True)
+
+    return dirs
+
+
+def process_single_video(video_path: Path,
+                         output_dir: Path,
+                         image_processor: ImageProcessorForDataProcessing,
+                         audio_processor: AudioProcessor,
+                         gpu_status: bool) -> None:
+    """
+    Process a single video file.
+
+    Args:
+        video_path (Path): Path to the video file.
+        output_dir (Path): Directory to save the output.
+        image_processor (ImageProcessorForDataProcessing): Image processor object.
+        audio_processor (AudioProcessor): Audio processor object.
+        gpu_status (bool): Whether to use GPU for processing.
+    """
+    assert video_path.exists(), f"Video path {video_path} does not exist"
+    dirs = setup_directories(video_path)
+    logging.info(f"Processing video: {video_path}")
+
+    try:
+        if not gpu_status:
+            images_dir = convert_video_to_images(video_path, output_dir)
+            logging.info(f"Images saved to: {images_dir}")
+
+            audio_path = extract_audio_from_videos(video_path, output_dir)
+            logging.info(f"Audio extracted to: {audio_path}")
+
+            face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess(
+                images_dir)
+            cv2.imwrite(
+                str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask)
+            cv2.imwrite(str(dirs["sep_pose_mask"] /
+                        f"{video_path.stem}.png"), sep_pose_mask)
+            cv2.imwrite(str(dirs["sep_face_mask"] /
+                        f"{video_path.stem}.png"), sep_face_mask)
+            cv2.imwrite(str(dirs["sep_lip_mask"] /
+                        f"{video_path.stem}.png"), sep_lip_mask)
+        else:
+            images_dir = output_dir / "images" / video_path.stem
+            audio_path = output_dir / "audios" / f"{video_path.stem}.wav"
+            _, face_emb, _, _, _ = image_processor.preprocess(images_dir)
+            torch.save(face_emb, str(
+                dirs["face_emb"] / f"{video_path.stem}.pt"))
+            audio_emb, _ = audio_processor.preprocess(audio_path)
+            torch.save(audio_emb, str(
+                dirs["audio_emb"] / f"{video_path.stem}.pt"))
+    except Exception as e:
+        logging.error(f"Failed to process video {video_path}: {e}")
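Reviewer note: with the default `--output_dir` (the parent of the input directory), the two passes populate the layout sketched below, which the stage-1/stage-2 meta-info scripts later in this patch assume. Note that the input directory must literally be named `videos/` for the stage-2 path rewriting to work:

```python
# dataset_root/                    (parent of the videos/ input directory)
#   videos/<stem>.mp4              (input)
#   images/<stem>/0001.png ...     (CPU pass: frame extraction)
#   audios/<stem>.wav              (CPU pass: audio extraction)
#   vocals/...                     (GPU pass: audio-separator intermediates)
#   face_mask/<stem>.png           (CPU pass: union face mask)
#   sep_pose_mask/<stem>.png       (CPU pass)
#   sep_face_mask/<stem>.png       (CPU pass)
#   sep_lip_mask/<stem>.png        (CPU pass)
#   face_emb/<stem>.pt             (GPU pass: InsightFace embedding)
#   audio_emb/<stem>.pt            (GPU pass: wav2vec embedding)
```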
+def process_all_videos(input_video_list: List[Path], output_dir: Path, gpu_status: bool) -> None:
+    """
+    Process all videos in the input list.
+
+    Args:
+        input_video_list (List[Path]): List of video paths to process.
+        output_dir (Path): Directory to save the output.
+        gpu_status (bool): Whether to use GPU for processing.
+    """
+    face_analysis_model_path = "pretrained_models/face_analysis"
+    landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
+    audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx"
+    wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h'
+
+    audio_processor = AudioProcessor(
+        16000,
+        25,
+        wav2vec_model_path,
+        False,
+        os.path.dirname(audio_separator_model_file),
+        os.path.basename(audio_separator_model_file),
+        os.path.join(output_dir, "vocals"),
+    ) if gpu_status else None
+
+    image_processor = ImageProcessorForDataProcessing(
+        face_analysis_model_path, landmark_model_path, gpu_status)
+
+    for video_path in tqdm(input_video_list, desc="Processing videos"):
+        process_single_video(video_path, output_dir,
+                             image_processor, audio_processor, gpu_status)
+
+
+def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]:
+    """
+    Get paths of videos to process, partitioned for parallel processing.
+
+    Args:
+        source_dir (Path): Source directory containing videos.
+        parallelism (int): Level of parallelism.
+        rank (int): Rank for distributed processing.
+
+    Returns:
+        List[Path]: List of video paths to process.
+    """
+    video_paths = [item for item in sorted(
+        source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4']
+    return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank]
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Process videos to prepare data for training. Run this script twice with different GPU status parameters."
+    )
+    parser.add_argument("-i", "--input_dir", type=Path,
+                        required=True, help="Directory containing videos")
+    parser.add_argument("-o", "--output_dir", type=Path,
+                        help="Directory to save results, default is parent dir of input dir")
+    parser.add_argument("-g", "--gpu_status", action='store_true',
+                        help="Run the GPU tasks (face/audio embedding extraction); omit to run the CPU tasks (frame/audio extraction and mask generation)")
+    parser.add_argument("-p", "--parallelism", default=1,
+                        type=int, help="Level of parallelism")
+    parser.add_argument("-r", "--rank", default=0, type=int,
+                        help="Rank for distributed processing")
+
+    args = parser.parse_args()
+
+    if args.output_dir is None:
+        args.output_dir = args.input_dir.parent
+
+    video_path_list = get_video_paths(
+        args.input_dir, args.parallelism, args.rank)
+
+    if not video_path_list:
+        logging.warning("No videos to process.")
+    else:
+        process_all_videos(video_path_list, args.output_dir, args.gpu_status)
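Reviewer note: `get_video_paths` shards the work by taking every `parallelism`-th video from the sorted list, so `--rank` must range over `0 .. parallelism-1`; a quick self-contained illustration:

```python
videos = [f"{i:04d}.mp4" for i in range(10)]  # stand-in for the sorted *.mp4 list
parallelism, rank = 4, 1

mine = [videos[i] for i in range(len(videos)) if i % parallelism == rank]
assert mine == ["0001.mp4", "0005.mp4", "0009.mp4"]
```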
diff --git a/scripts/extract_meta_info_stage1.py b/scripts/extract_meta_info_stage1.py
new file mode 100644
index 00000000..d25123e1
--- /dev/null
+++ b/scripts/extract_meta_info_stage1.py
@@ -0,0 +1,100 @@
+# pylint: disable=R0801
+"""
+This module is used to extract meta information from video directories.
+
+It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path`
+specifies the path to the video directory, while the `dataset_name` specifies the name
+of the dataset. The module then collects all the video folder paths, and for each video
+folder, it checks that the corresponding face mask exists. If it does, it appends
+a dictionary containing the image path, mask path, and face embedding path to a list.
+
+Finally, the module writes the list of dictionaries to a JSON file with the filename
+constructed using the `dataset_name`.
+
+Usage:
+    python scripts/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf
+
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+
+
+def collect_video_folder_paths(root_path: Path) -> list:
+    """
+    Collect all video folder paths from the root path.
+
+    Args:
+        root_path (Path): The root directory containing video folders.
+
+    Returns:
+        list: List of video folder paths.
+    """
+    return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()]
+
+
+def construct_meta_info(frames_dir_path: Path) -> dict:
+    """
+    Construct meta information for a given frames directory.
+
+    Args:
+        frames_dir_path (Path): The path to the frames directory.
+
+    Returns:
+        dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist.
+    """
+    mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png"
+    face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt"
+
+    if not os.path.exists(mask_path):
+        print(f"Mask path not found: {mask_path}")
+        return None
+
+    return {
+        "image_path": str(frames_dir_path),
+        "mask_path": mask_path,
+        "face_emb": face_emb_path,
+    }
+
+
+def main():
+    """
+    Main function to extract meta info for training.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-r", "--root_path", type=str,
+                        required=True, help="Root path of the video directories")
+    parser.add_argument("-n", "--dataset_name", type=str,
+                        required=True, help="Name of the dataset")
+    parser.add_argument("--meta_info_name", type=str,
+                        help="Name of the meta information file")
+
+    args = parser.parse_args()
+
+    if args.meta_info_name is None:
+        args.meta_info_name = args.dataset_name
+
+    image_dir = Path(args.root_path) / "images"
+    output_dir = Path("./data")
+    output_dir.mkdir(exist_ok=True)
+
+    # Collect all video folder paths
+    frames_dir_paths = collect_video_folder_paths(image_dir)
+
+    meta_infos = []
+    for frames_dir_path in frames_dir_paths:
+        meta_info = construct_meta_info(frames_dir_path)
+        if meta_info:
+            meta_infos.append(meta_info)
+
+    output_file = output_dir / f"{args.meta_info_name}_stage1.json"
+    with output_file.open("w", encoding="utf-8") as f:
+        json.dump(meta_infos, f, indent=4)
+
+    print(f"Final data count: {len(meta_infos)}")
+
+
+if __name__ == "__main__":
+    main()
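Reviewer note: the shape of one entry in the resulting `data/<dataset_name>_stage1.json`, mirroring the dict returned by `construct_meta_info` (paths hypothetical):

```python
entry = {
    "image_path": "/data/dataset_root/images/clip_0001",
    "mask_path": "/data/dataset_root/face_mask/clip_0001.png",
    "face_emb": "/data/dataset_root/face_emb/clip_0001.pt",
}
```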
diff --git a/scripts/extract_meta_info_stage2.py b/scripts/extract_meta_info_stage2.py
new file mode 100644
index 00000000..e2d9301c
--- /dev/null
+++ b/scripts/extract_meta_info_stage2.py
@@ -0,0 +1,192 @@
+# pylint: disable=R0801
+"""
+This module is used to extract meta information from video files and store it in a JSON file.
+
+The script takes in command-line arguments to specify the root path of the video files,
+the dataset name, and the name of the meta information file. It then generates a list of
+dictionaries containing the meta information for each video file and writes it to a JSON
+file with the specified name.
+
+The meta information for each video includes the video path, the face mask path, the
+separate pose/face/lip mask paths, the face embedding path, the audio path, and the
+vocal (audio) embedding path. An entry is added to the list only if all of these files
+exist and the video's frame count roughly matches the length of its audio embedding.
+
+Usage:
+    python scripts/extract_meta_info_stage2.py --root_path <root_path> --dataset_name <dataset_name> --meta_info_name <meta_info_name>
+
+Example:
+    python scripts/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+
+import torch
+from decord import VideoReader, cpu
+from tqdm import tqdm
+
+
+def get_video_paths(root_path: Path, extensions: list) -> list:
+    """
+    Get a list of video paths from the root path with the specified extensions.
+
+    Args:
+        root_path (Path): The root directory containing video files.
+        extensions (list): List of file extensions to include.
+
+    Returns:
+        list: List of video file paths.
+    """
+    return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions]
+
+
+def file_exists(file_path: str) -> bool:
+    """
+    Check if a file exists.
+
+    Args:
+        file_path (str): The path to the file.
+
+    Returns:
+        bool: True if the file exists, False otherwise.
+    """
+    return os.path.exists(file_path)
+
+
+def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str:
+    """
+    Construct a new path by replacing the base directory and extension in the original path.
+
+    Args:
+        video_path (str): The original video path.
+        base_dir (str): The base directory to be replaced.
+        new_dir (str): The new directory to replace the base directory.
+        new_ext (str): The new file extension.
+
+    Returns:
+        str: The constructed path.
+    """
+    return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext)
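Reviewer note: `construct_paths` is a plain string rewrite, so it relies on the `videos` directory name appearing exactly once in the path; a self-contained illustration (replicating the one-liner defined above):

```python
def construct_paths(video_path, base_dir, new_dir, new_ext):
    # Same logic as scripts/extract_meta_info_stage2.py
    return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext)

video_path = "/data/dataset_root/videos/clip_0001.mp4"

mask_path = construct_paths(video_path, "videos", "face_mask", ".png")
assert mask_path == "/data/dataset_root/face_mask/clip_0001.png"

audio_path = construct_paths(video_path, "videos", "audios", ".wav")
assert audio_path == "/data/dataset_root/audios/clip_0001.wav"
```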
+
+
+def extract_meta_info(video_path: str) -> dict:
+    """
+    Extract meta information for a given video file.
+
+    Args:
+        video_path (str): The path to the video file.
+
+    Returns:
+        dict: A dictionary containing the meta information for the video, or None if any required file is missing or inconsistent.
+    """
+    mask_path = construct_paths(
+        video_path, "videos", "face_mask", ".png")
+    sep_mask_border = construct_paths(
+        video_path, "videos", "sep_pose_mask", ".png")
+    sep_mask_face = construct_paths(
+        video_path, "videos", "sep_face_mask", ".png")
+    sep_mask_lip = construct_paths(
+        video_path, "videos", "sep_lip_mask", ".png")
+    face_emb_path = construct_paths(
+        video_path, "videos", "face_emb", ".pt")
+    audio_path = construct_paths(video_path, "videos", "audios", ".wav")
+    vocal_emb_base_all = construct_paths(
+        video_path, "videos", "audio_emb", ".pt")
+
+    assert_flag = True
+
+    if not file_exists(mask_path):
+        print(f"Mask path not found: {mask_path}")
+        assert_flag = False
+    if not file_exists(sep_mask_border):
+        print(f"Separate mask border not found: {sep_mask_border}")
+        assert_flag = False
+    if not file_exists(sep_mask_face):
+        print(f"Separate mask face not found: {sep_mask_face}")
+        assert_flag = False
+    if not file_exists(sep_mask_lip):
+        print(f"Separate mask lip not found: {sep_mask_lip}")
+        assert_flag = False
+    if not file_exists(face_emb_path):
+        print(f"Face embedding path not found: {face_emb_path}")
+        assert_flag = False
+    if not file_exists(audio_path):
+        print(f"Audio path not found: {audio_path}")
+        assert_flag = False
+    if not file_exists(vocal_emb_base_all):
+        print(f"Vocal embedding base all not found: {vocal_emb_base_all}")
+        assert_flag = False
+
+    if not assert_flag:  # bail out before loading files that were just reported missing
+        return None
+    video_frames = VideoReader(video_path, ctx=cpu(0))
+    audio_emb = torch.load(vocal_emb_base_all)
+    if abs(len(video_frames) - audio_emb.shape[0]) > 3:
+        print(f"Frame count mismatch for video: {video_path}")
+        assert_flag = False
+    face_emb = torch.load(face_emb_path)
+    if face_emb is None:
+        print(f"Face embedding is None for video: {video_path}")
+        assert_flag = False
+    del video_frames, audio_emb
+
+    if assert_flag:
+        return {
+            "video_path": str(video_path),
+            "mask_path": mask_path,
+            "sep_mask_border": sep_mask_border,
+            "sep_mask_face": sep_mask_face,
+            "sep_mask_lip": sep_mask_lip,
+            "face_emb_path": face_emb_path,
+            "audio_path": audio_path,
+            "vocals_emb_base_all": vocal_emb_base_all,
+        }
+    return None
+
+
+def main():
+    """
+    Main function to extract meta info for training.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-r", "--root_path", type=str,
+                        required=True, help="Root path of the video files")
+    parser.add_argument("-n", "--dataset_name", type=str,
+                        required=True, help="Name of the dataset")
+    parser.add_argument("--meta_info_name", type=str,
+                        help="Name of the meta information file")
+
+    args = parser.parse_args()
+
+    if args.meta_info_name is None:
+        args.meta_info_name = args.dataset_name
+
+    video_dir = Path(args.root_path) / "videos"
+    video_paths = get_video_paths(video_dir, [".mp4"])
+
+    meta_infos = []
+
+    for video_path in tqdm(video_paths, desc="Extracting meta info"):
+        meta_info = extract_meta_info(video_path)
+        if meta_info:
+            meta_infos.append(meta_info)
+
+    print(f"Final data count: {len(meta_infos)}")
+
+    output_file = Path(f"./data/{args.meta_info_name}_stage2.json")
+    output_file.parent.mkdir(parents=True, exist_ok=True)
+
+    with output_file.open("w", encoding="utf-8") as f:
+        json.dump(meta_infos, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
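Reviewer note: for reference, the shape of one entry in the final `data/<meta_info_name>_stage2.json`, mirroring the dict returned by `extract_meta_info` (paths hypothetical):

```python
entry = {
    "video_path": "/data/dataset_root/videos/clip_0001.mp4",
    "mask_path": "/data/dataset_root/face_mask/clip_0001.png",
    "sep_mask_border": "/data/dataset_root/sep_pose_mask/clip_0001.png",
    "sep_mask_face": "/data/dataset_root/sep_face_mask/clip_0001.png",
    "sep_mask_lip": "/data/dataset_root/sep_lip_mask/clip_0001.png",
    "face_emb_path": "/data/dataset_root/face_emb/clip_0001.pt",
    "audio_path": "/data/dataset_root/audios/clip_0001.wav",
    "vocals_emb_base_all": "/data/dataset_root/audio_emb/clip_0001.pt",
}
```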