Commit

linting
loky-op committed Aug 15, 2023
1 parent d1ead35 commit cf51160
Showing 1 changed file with 30 additions and 52 deletions.
inference.py: 82 changes (30 additions, 52 deletions)
@@ -44,8 +44,7 @@ def extract_frames_from_video(video_path, save_dir):
 
     os.makedirs(save_dir, exist_ok=True)
     # Construct the ffmpeg command
-    ffmpeg_command = ["ffmpeg", "-i", video_path,
-                      os.path.join(save_dir, "%06d.png")]
+    ffmpeg_command = ["ffmpeg", "-i", video_path, os.path.join(save_dir, "%06d.png")]
 
     # Run the ffmpeg command
     subprocess.run(
@@ -63,8 +62,7 @@ def extract_frames_from_video(video_path, save_dir):
         raise ValueError("wrong video path : {}".format(opt.source_video_path))
     if not os.path.exists(opt.source_openface_landmark_path):
         raise ValueError(
-            "wrong openface stats path : {}".format(
-                opt.source_openface_landmark_path)
+            "wrong openface stats path : {}".format(opt.source_openface_landmark_path)
         )
 
     # extract frames from source video
@@ -73,14 +71,12 @@ def extract_frames_from_video(video_path, save_dir):
     video_frame_dir = opt.source_video_path.replace(".mp4", "")
     if not os.path.exists(video_frame_dir):
         os.mkdir(video_frame_dir)
-    video_size = extract_frames_from_video(
-        opt.source_video_path, video_frame_dir)
+    video_size = extract_frames_from_video(opt.source_video_path, video_frame_dir)
     end_time = time.time()
     logging.info(f"Frames extraction took {end_time - start_time:.2f} sec.")
 
     # extract audio features using Hubert Model from Pytorch
-    logging.info("extracting audio speech features from : %s",
-                 opt.driving_audio_path)
+    logging.info("extracting audio speech features from : %s", opt.driving_audio_path)
     start_time = time.time()
     ds_feature = feature_extractor.compute_audio_feature(
         opt.driving_audio_path
@@ -97,8 +93,7 @@ def extract_frames_from_video(video_path, save_dir):
     ds_feature_padding = np.pad(ds_feature, ((2, 2), (0, 0)), mode="edge")
 
     end_time = time.time()
-    logging.info(
-        f"Audio features extraction took {end_time - start_time:.2f} sec.")
+    logging.info(f"Audio features extraction took {end_time - start_time:.2f} sec.")
 
     # load facial landmarks
     logging.info(
@@ -118,8 +113,7 @@ def extract_frames_from_video(video_path, save_dir):
     if len(video_frame_path_list) != video_landmark_data.shape[0]:
         raise ValueError("video frames are misaligned with detected landmarks")
     video_frame_path_list.sort()
-    video_frame_path_list_cycle = video_frame_path_list + \
-        video_frame_path_list[::-1]
+    video_frame_path_list_cycle = video_frame_path_list + video_frame_path_list[::-1]
     video_landmark_data_cycle = np.concatenate(
         [video_landmark_data, np.flip(video_landmark_data, 0)], 0
     )
@@ -158,31 +152,27 @@ def extract_frames_from_video(video_path, save_dir):
     logging.info("selecting five reference images")
     ref_img_list = []
     resize_w = int(opt.mouth_region_size + opt.mouth_region_size // 4)
-    resize_h = int((opt.mouth_region_size // 2) *
-                   3 + opt.mouth_region_size // 8)
-    ref_index_list = random.sample(
-        range(5, len(res_video_frame_path_list_pad) - 2), 5)
+    resize_h = int((opt.mouth_region_size // 2) * 3 + opt.mouth_region_size // 8)
+    ref_index_list = random.sample(range(5, len(res_video_frame_path_list_pad) - 2), 5)
     for ref_index in ref_index_list:
         crop_flag, crop_radius = compute_crop_radius(
-            video_size, res_video_landmark_data_pad[ref_index -
-                                                    5: ref_index, :, :]
+            video_size, res_video_landmark_data_pad[ref_index - 5 : ref_index, :, :]
         )
         if not crop_flag:
             raise ValueError(
                 "our method cannot handle videos with large changes in facial size!!"
             )
         crop_radius_1_4 = crop_radius // 4
-        ref_img = cv2.imread(
-            res_video_frame_path_list_pad[ref_index - 3])[:, :, ::-1]
+        ref_img = cv2.imread(res_video_frame_path_list_pad[ref_index - 3])[:, :, ::-1]
         ref_landmark = res_video_landmark_data_pad[ref_index - 3, :, :]
         ref_img_crop = ref_img[
             ref_landmark[29, 1]
-            - crop_radius: ref_landmark[29, 1]
+            - crop_radius : ref_landmark[29, 1]
             + crop_radius * 2
             + crop_radius_1_4,
             ref_landmark[33, 0]
             - crop_radius
-            - crop_radius_1_4: ref_landmark[33, 0]
+            - crop_radius_1_4 : ref_landmark[33, 0]
             + crop_radius
             + crop_radius_1_4,
             :,
@@ -192,21 +182,17 @@ def extract_frames_from_video(video_path, save_dir):
         ref_img_list.append(ref_img_crop)
     ref_video_frame = np.concatenate(ref_img_list, 2)
     ref_img_tensor = (
-        torch.from_numpy(ref_video_frame).permute(
-            2, 0, 1).unsqueeze(0).float().cuda()
+        torch.from_numpy(ref_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda()
     )
 
     # load pretrained model weight
-    logging.info("loading pretrained model from: %s",
-                 opt.pretrained_clip_DINet_path)
-    model = DINet(opt.source_channel, opt.ref_channel,
-                  opt.audio_channel).cuda()
+    logging.info("loading pretrained model from: %s", opt.pretrained_clip_DINet_path)
+    model = DINet(opt.source_channel, opt.ref_channel, opt.audio_channel).cuda()
     if not os.path.exists(opt.pretrained_clip_DINet_path):
         raise ValueError(
             "wrong path of pretrained model weight: %s", opt.pretrained_clip_DINet_path
         )
-    state_dict = torch.load(opt.pretrained_clip_DINet_path)[
-        "state_dict"]["net_g"]
+    state_dict = torch.load(opt.pretrained_clip_DINet_path)["state_dict"]["net_g"]
     new_state_dict = OrderedDict()
     for k, v in state_dict.items():
         name = k[7:]  # remove module.
@@ -224,24 +210,20 @@ def extract_frames_from_video(video_path, save_dir):
     )
     if os.path.exists(res_video_path):
         os.remove(res_video_path)
-    res_face_path = res_video_path.replace(
-        "_facial_dubbing.mp4", "_synthetic_face.mp4")
+    res_face_path = res_video_path.replace("_facial_dubbing.mp4", "_synthetic_face.mp4")
    if os.path.exists(res_face_path):
         os.remove(res_face_path)
     videowriter = cv2.VideoWriter(
         res_video_path, cv2.VideoWriter_fourcc(*"XVID"), 25, video_size
     )
     videowriter_face = cv2.VideoWriter(
-        res_face_path, cv2.VideoWriter_fourcc(
-            *"XVID"), 25, (resize_w, resize_h)
+        res_face_path, cv2.VideoWriter_fourcc(*"XVID"), 25, (resize_w, resize_h)
     )
     for clip_end_index in range(5, pad_length, 1):
-        logging.info("synthesizing frame %d/%d",
-                     clip_end_index - 5, pad_length - 5)
+        logging.info("synthesizing frame %d/%d", clip_end_index - 5, pad_length - 5)
         crop_flag, crop_radius = compute_crop_radius(
             video_size,
-            res_video_landmark_data_pad[clip_end_index -
-                                        5: clip_end_index, :, :],
+            res_video_landmark_data_pad[clip_end_index - 5 : clip_end_index, :, :],
             random_scale=1.05,
         )
         if not crop_flag:
@@ -255,12 +237,12 @@ def extract_frames_from_video(video_path, save_dir):
         frame_landmark = res_video_landmark_data_pad[clip_end_index - 3, :, :]
         crop_frame_data = frame_data[
             frame_landmark[29, 1]
-            - crop_radius: frame_landmark[29, 1]
+            - crop_radius : frame_landmark[29, 1]
             + crop_radius * 2
             + crop_radius_1_4,
             frame_landmark[33, 0]
             - crop_radius
-            - crop_radius_1_4: frame_landmark[33, 0]
+            - crop_radius_1_4 : frame_landmark[33, 0]
             + crop_radius
             + crop_radius_1_4,
             :,
@@ -271,9 +253,9 @@ def extract_frames_from_video(video_path, save_dir):
         )  # [32:224, 32:224, :]
         crop_frame_data = crop_frame_data / 255.0
         crop_frame_data[
-            opt.mouth_region_size // 2: opt.mouth_region_size // 2
+            opt.mouth_region_size // 2 : opt.mouth_region_size // 2
             + opt.mouth_region_size,
-            opt.mouth_region_size // 8: opt.mouth_region_size // 8
+            opt.mouth_region_size // 8 : opt.mouth_region_size // 8
             + opt.mouth_region_size,
             :,
         ] = 0
@@ -286,29 +268,26 @@ def extract_frames_from_video(video_path, save_dir):
             .unsqueeze(0)
         )
         deepspeech_tensor = (
-            torch.from_numpy(
-                ds_feature_padding[clip_end_index - 5: clip_end_index, :])
+            torch.from_numpy(ds_feature_padding[clip_end_index - 5 : clip_end_index, :])
             .permute(1, 0)
             .unsqueeze(0)
             .float()
             .cuda()
         )
         with torch.no_grad():
-            pre_frame = model(crop_frame_tensor,
-                              ref_img_tensor, deepspeech_tensor)
+            pre_frame = model(crop_frame_tensor, ref_img_tensor, deepspeech_tensor)
             pre_frame = (
-                pre_frame.squeeze(0).permute(
-                    1, 2, 0).detach().cpu().numpy() * 255
+                pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255
             )
         videowriter_face.write(pre_frame[:, :, ::-1].copy().astype(np.uint8))
         pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w, crop_frame_h))
         frame_data[
             frame_landmark[29, 1]
-            - crop_radius: frame_landmark[29, 1]
+            - crop_radius : frame_landmark[29, 1]
             + crop_radius * 2,
             frame_landmark[33, 0]
             - crop_radius
-            - crop_radius_1_4: frame_landmark[33, 0]
+            - crop_radius_1_4 : frame_landmark[33, 0]
             + crop_radius
             + crop_radius_1_4,
             :,
@@ -324,5 +303,4 @@ def extract_frames_from_video(video_path, save_dir):
     )
     subprocess.call(cmd, shell=True)
     end_process = default_timer()
-    logging.info(
-        f"Video generation took {end_process - start_process:.2f} sec.")
+    logging.info(f"Video generation took {end_process - start_process:.2f} sec.")

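For context on the state_dict handling visible in the diff (name = k[7:]  # remove module.): checkpoints saved from a model wrapped in torch.nn.DataParallel prefix every parameter key with "module.", and those prefixes must be stripped before loading into an unwrapped model. The sketch below shows that pattern in isolation; the helper name is illustrative, and the checkpoint layout mirrors the diff under that assumption:

    from collections import OrderedDict

    def strip_dataparallel_prefix(state_dict):
        # nn.DataParallel stores keys as "module.<name>"; len("module.") == 7,
        # which is why the diff slices with k[7:].
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            name = key[7:] if key.startswith("module.") else key
            new_state_dict[name] = value
        return new_state_dict

    # Usage, assuming the checkpoint layout from the diff:
    #   checkpoint = torch.load(opt.pretrained_clip_DINet_path)
    #   model.load_state_dict(strip_dataparallel_prefix(checkpoint["state_dict"]["net_g"]))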