diff --git a/inference.py b/inference.py
index 3670078..4514b15 100644
--- a/inference.py
+++ b/inference.py
@@ -44,8 +44,7 @@ def extract_frames_from_video(video_path, save_dir):
     os.makedirs(save_dir, exist_ok=True)
 
     # Construct the ffmpeg command
-    ffmpeg_command = ["ffmpeg", "-i", video_path,
-                      os.path.join(save_dir, "%06d.png")]
+    ffmpeg_command = ["ffmpeg", "-i", video_path, os.path.join(save_dir, "%06d.png")]
 
     # Run the ffmpeg command
     subprocess.run(
@@ -63,8 +62,7 @@ def extract_frames_from_video(video_path, save_dir):
         raise ValueError("wrong video path : {}".format(opt.source_video_path))
     if not os.path.exists(opt.source_openface_landmark_path):
         raise ValueError(
-            "wrong openface stats path : {}".format(
-                opt.source_openface_landmark_path)
+            "wrong openface stats path : {}".format(opt.source_openface_landmark_path)
         )
 
     # extract frames from source video
@@ -73,14 +71,12 @@ def extract_frames_from_video(video_path, save_dir):
     video_frame_dir = opt.source_video_path.replace(".mp4", "")
     if not os.path.exists(video_frame_dir):
         os.mkdir(video_frame_dir)
-    video_size = extract_frames_from_video(
-        opt.source_video_path, video_frame_dir)
+    video_size = extract_frames_from_video(opt.source_video_path, video_frame_dir)
     end_time = time.time()
     logging.info(f"Frames extraction took {end_time - start_time:.2f} sec.")
 
     # extract audio features using Hubert Model from Pytorch
-    logging.info("extracting audio speech features from : %s",
-                 opt.driving_audio_path)
+    logging.info("extracting audio speech features from : %s", opt.driving_audio_path)
     start_time = time.time()
     ds_feature = feature_extractor.compute_audio_feature(
         opt.driving_audio_path
@@ -97,8 +93,7 @@ def extract_frames_from_video(video_path, save_dir):
 
     ds_feature_padding = np.pad(ds_feature, ((2, 2), (0, 0)), mode="edge")
     end_time = time.time()
-    logging.info(
-        f"Audio features extraction took {end_time - start_time:.2f} sec.")
+    logging.info(f"Audio features extraction took {end_time - start_time:.2f} sec.")
 
     # load facial landmarks
     logging.info(
@@ -118,8 +113,7 @@ def extract_frames_from_video(video_path, save_dir):
     if len(video_frame_path_list) != video_landmark_data.shape[0]:
         raise ValueError("video frames are misaligned with detected landmarks")
     video_frame_path_list.sort()
-    video_frame_path_list_cycle = video_frame_path_list + \
-        video_frame_path_list[::-1]
+    video_frame_path_list_cycle = video_frame_path_list + video_frame_path_list[::-1]
     video_landmark_data_cycle = np.concatenate(
         [video_landmark_data, np.flip(video_landmark_data, 0)], 0
     )
@@ -158,31 +152,27 @@ def extract_frames_from_video(video_path, save_dir):
     logging.info("selecting five reference images")
     ref_img_list = []
     resize_w = int(opt.mouth_region_size + opt.mouth_region_size // 4)
-    resize_h = int((opt.mouth_region_size // 2) *
-                   3 + opt.mouth_region_size // 8)
-    ref_index_list = random.sample(
-        range(5, len(res_video_frame_path_list_pad) - 2), 5)
+    resize_h = int((opt.mouth_region_size // 2) * 3 + opt.mouth_region_size // 8)
+    ref_index_list = random.sample(range(5, len(res_video_frame_path_list_pad) - 2), 5)
     for ref_index in ref_index_list:
         crop_flag, crop_radius = compute_crop_radius(
-            video_size, res_video_landmark_data_pad[ref_index -
-                                                    5: ref_index, :, :]
+            video_size, res_video_landmark_data_pad[ref_index - 5 : ref_index, :, :]
         )
         if not crop_flag:
             raise ValueError(
                 "our method cannot handle videos with large changes in facial size!!"
             )
         crop_radius_1_4 = crop_radius // 4
-        ref_img = cv2.imread(
-            res_video_frame_path_list_pad[ref_index - 3])[:, :, ::-1]
+        ref_img = cv2.imread(res_video_frame_path_list_pad[ref_index - 3])[:, :, ::-1]
        ref_landmark = res_video_landmark_data_pad[ref_index - 3, :, :]
         ref_img_crop = ref_img[
             ref_landmark[29, 1]
-            - crop_radius: ref_landmark[29, 1]
+            - crop_radius : ref_landmark[29, 1]
             + crop_radius * 2
             + crop_radius_1_4,
             ref_landmark[33, 0]
             - crop_radius
-            - crop_radius_1_4: ref_landmark[33, 0]
+            - crop_radius_1_4 : ref_landmark[33, 0]
             + crop_radius
             + crop_radius_1_4,
             :,
@@ -192,21 +182,17 @@ def extract_frames_from_video(video_path, save_dir):
         ref_img_list.append(ref_img_crop)
     ref_video_frame = np.concatenate(ref_img_list, 2)
     ref_img_tensor = (
-        torch.from_numpy(ref_video_frame).permute(
-            2, 0, 1).unsqueeze(0).float().cuda()
+        torch.from_numpy(ref_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda()
     )
 
     # load pretrained model weight
-    logging.info("loading pretrained model from: %s",
-                 opt.pretrained_clip_DINet_path)
-    model = DINet(opt.source_channel, opt.ref_channel,
-                  opt.audio_channel).cuda()
+    logging.info("loading pretrained model from: %s", opt.pretrained_clip_DINet_path)
+    model = DINet(opt.source_channel, opt.ref_channel, opt.audio_channel).cuda()
     if not os.path.exists(opt.pretrained_clip_DINet_path):
         raise ValueError(
             "wrong path of pretrained model weight: %s", opt.pretrained_clip_DINet_path
         )
-    state_dict = torch.load(opt.pretrained_clip_DINet_path)[
-        "state_dict"]["net_g"]
+    state_dict = torch.load(opt.pretrained_clip_DINet_path)["state_dict"]["net_g"]
     new_state_dict = OrderedDict()
     for k, v in state_dict.items():
         name = k[7:]  # remove module.
@@ -224,24 +210,20 @@ def extract_frames_from_video(video_path, save_dir):
     )
     if os.path.exists(res_video_path):
         os.remove(res_video_path)
-    res_face_path = res_video_path.replace(
-        "_facial_dubbing.mp4", "_synthetic_face.mp4")
+    res_face_path = res_video_path.replace("_facial_dubbing.mp4", "_synthetic_face.mp4")
     if os.path.exists(res_face_path):
         os.remove(res_face_path)
     videowriter = cv2.VideoWriter(
         res_video_path, cv2.VideoWriter_fourcc(*"XVID"), 25, video_size
     )
     videowriter_face = cv2.VideoWriter(
-        res_face_path, cv2.VideoWriter_fourcc(
-            *"XVID"), 25, (resize_w, resize_h)
+        res_face_path, cv2.VideoWriter_fourcc(*"XVID"), 25, (resize_w, resize_h)
     )
     for clip_end_index in range(5, pad_length, 1):
-        logging.info("synthesizing frame %d/%d",
-                     clip_end_index - 5, pad_length - 5)
+        logging.info("synthesizing frame %d/%d", clip_end_index - 5, pad_length - 5)
         crop_flag, crop_radius = compute_crop_radius(
             video_size,
-            res_video_landmark_data_pad[clip_end_index -
-                                        5: clip_end_index, :, :],
+            res_video_landmark_data_pad[clip_end_index - 5 : clip_end_index, :, :],
             random_scale=1.05,
         )
         if not crop_flag:
@@ -255,12 +237,12 @@ def extract_frames_from_video(video_path, save_dir):
         frame_landmark = res_video_landmark_data_pad[clip_end_index - 3, :, :]
         crop_frame_data = frame_data[
             frame_landmark[29, 1]
-            - crop_radius: frame_landmark[29, 1]
+            - crop_radius : frame_landmark[29, 1]
             + crop_radius * 2
             + crop_radius_1_4,
             frame_landmark[33, 0]
             - crop_radius
-            - crop_radius_1_4: frame_landmark[33, 0]
+            - crop_radius_1_4 : frame_landmark[33, 0]
             + crop_radius
             + crop_radius_1_4,
             :,
@@ -271,9 +253,9 @@ def extract_frames_from_video(video_path, save_dir):
         )  # [32:224, 32:224, :]
         crop_frame_data = crop_frame_data / 255.0
         crop_frame_data[
-            opt.mouth_region_size // 2: opt.mouth_region_size // 2
+            opt.mouth_region_size // 2 : opt.mouth_region_size // 2
             + opt.mouth_region_size,
-            opt.mouth_region_size // 8: opt.mouth_region_size // 8
+            opt.mouth_region_size // 8 : opt.mouth_region_size // 8
             + opt.mouth_region_size,
             :,
         ] = 0
@@ -286,29 +268,26 @@ def extract_frames_from_video(video_path, save_dir):
             .unsqueeze(0)
         )
         deepspeech_tensor = (
-            torch.from_numpy(
-                ds_feature_padding[clip_end_index - 5: clip_end_index, :])
+            torch.from_numpy(ds_feature_padding[clip_end_index - 5 : clip_end_index, :])
             .permute(1, 0)
             .unsqueeze(0)
             .float()
             .cuda()
         )
         with torch.no_grad():
-            pre_frame = model(crop_frame_tensor,
-                              ref_img_tensor, deepspeech_tensor)
+            pre_frame = model(crop_frame_tensor, ref_img_tensor, deepspeech_tensor)
             pre_frame = (
-                pre_frame.squeeze(0).permute(
-                    1, 2, 0).detach().cpu().numpy() * 255
+                pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255
             )
         videowriter_face.write(pre_frame[:, :, ::-1].copy().astype(np.uint8))
         pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w, crop_frame_h))
         frame_data[
             frame_landmark[29, 1]
-            - crop_radius: frame_landmark[29, 1]
+            - crop_radius : frame_landmark[29, 1]
             + crop_radius * 2,
             frame_landmark[33, 0]
             - crop_radius
-            - crop_radius_1_4: frame_landmark[33, 0]
+            - crop_radius_1_4 : frame_landmark[33, 0]
             + crop_radius
             + crop_radius_1_4,
             :,
@@ -324,5 +303,4 @@ def extract_frames_from_video(video_path, save_dir):
     )
     subprocess.call(cmd, shell=True)
     end_process = default_timer()
-    logging.info(
-        f"Video generation took {end_process - start_process:.2f} sec.")
+    logging.info(f"Video generation took {end_process - start_process:.2f} sec.")