# image_to_video.py
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video, export_to_gif
from PIL import Image
import numpy as np
import cv2
import rembg
import os
import argparse
def add_margin(pil_img, top, right, bottom, left, color):
    """Return a copy of ``pil_img`` surrounded by a solid-colour border.

    Args:
        pil_img: Source PIL image; it is not modified.
        top: Border thickness in pixels above the image.
        right: Border thickness in pixels to the right of the image.
        bottom: Border thickness in pixels below the image.
        left: Border thickness in pixels to the left of the image.
        color: Fill colour for the border, passed straight to ``Image.new``.

    Returns:
        A new image of the same mode, enlarged by the requested margins, with
        the original pasted at offset ``(left, top)``.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    # Start from a canvas filled with the border colour, then drop the
    # original on top so that only the margins keep that colour.
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas
def resize_image(image, output_size=(1024, 576)):
    """Squash ``image`` to a square and pad it sideways to ``output_size``.

    The image is first resized to ``output_size[1] x output_size[1]`` and is
    then padded on the left and right, using the colour of the resized
    image's top-left pixel, until it is ``output_size[0]`` pixels wide.

    Args:
        image: PIL image to fit.
        output_size: Target ``(width, height)`` in pixels.

    Returns:
        A new PIL image; its size equals ``output_size`` whenever the target
        width is at least the target height.
    """
    target_w, target_h = output_size[0], output_size[1]
    image = image.resize((target_h, target_h))

    # Split the horizontal padding so that an odd difference still yields
    # exactly ``target_w`` columns (the extra pixel goes to the right side).
    total_pad = target_w - target_h
    left_pad = total_pad // 2
    right_pad = total_pad - left_pad

    # ``getpixel`` returns a plain int for single-band images and a tuple of
    # ints for multi-band ones, both of which ``Image.new`` accepts as a fill.
    fill = image.getpixel((0, 0))
    return add_margin(image, 0, right_pad, 0, left_pad, fill)
def load_image(file, W, H, bg='white'):
    """Load an image, remove its background with rembg and flatten it to RGB.

    Args:
        file: Path of the image to read.
        W: Output width in pixels.
        H: Output height in pixels.
        bg: ``'white'`` composites the foreground over white using the alpha
            channel produced by rembg; ``'black'`` keeps the colour channels
            exactly as rembg returned them.

    Returns:
        A ``PIL.Image`` built from the ``(H, W, 3)`` uint8 array in RGB order.

    Raises:
        NotImplementedError: If ``bg`` is neither ``'white'`` nor ``'black'``.
        FileNotFoundError: If ``file`` does not exist.
        ValueError: If OpenCV cannot decode ``file``.
    """
    # Reject unsupported backgrounds up front instead of after the (slow)
    # background-removal pass.
    if bg not in ('white', 'black'):
        raise NotImplementedError(f'unsupported background: {bg!r}')

    print(f'[INFO] load image from {file}...')
    if not os.path.isfile(file):
        raise FileNotFoundError(f'image file not found: {file}')
    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f'could not decode image file: {file}')

    bg_remover = rembg.new_session()
    img = rembg.remove(img, session=bg_remover)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0

    # Assumes rembg returned a 4-channel array whose last channel is alpha.
    input_mask = img[..., 3:]
    if bg == 'white':
        # Alpha-blend the foreground over a white (1.0) background.
        input_img = img[..., :3] * input_mask + (1 - input_mask)
    else:
        input_img = img[..., :3]

    # OpenCV stores channels as BGR; PIL expects RGB.
    input_img = input_img[..., ::-1].copy()
    return Image.fromarray(np.uint8(input_img * 255))
def load_image_w_bg(file, W, H):
    """Load an image as RGB, keeping its background, resized to ``(W, H)``.

    Args:
        file: Path of the image to read.
        W: Output width in pixels.
        H: Output height in pixels.

    Returns:
        A ``PIL.Image`` built from the ``(H, W, 3)`` uint8 array in RGB order;
        any alpha channel in the source is dropped.

    Raises:
        FileNotFoundError: If ``file`` does not exist.
        ValueError: If OpenCV cannot decode ``file``.
    """
    print(f'[INFO] load image from {file}...')
    if not os.path.isfile(file):
        raise FileNotFoundError(f'image file not found: {file}')
    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f'could not decode image file: {file}')

    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    if img.ndim == 2:
        # Grayscale images come back 2-D; replicate the single channel so the
        # channel slicing below does not cut into the width axis.
        img = np.repeat(img[..., None], 3, axis=-1)
    img = img.astype(np.float32) / 255.0

    input_img = img[..., :3]
    # OpenCV stores channels as BGR; PIL expects RGB.
    input_img = input_img[..., ::-1].copy()
    return Image.fromarray(np.uint8(input_img * 255))
_SVD_PIPELINE = None


def _get_svd_pipeline():
    """Load the SVD img2vid pipeline once and reuse it on later calls."""
    global _SVD_PIPELINE
    if _SVD_PIPELINE is None:
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        # NOTE: pipe.enable_model_cpu_offload() is the alternative to a full
        # move onto the GPU.
        pipe.to("cuda")
        _SVD_PIPELINE = pipe
    return _SVD_PIPELINE


def gen_vid(data_path, name, seed, bg, is_pad):
    """Generate a video from one image and save it with its frames.

    Outputs are written under ``data/{name}_svd/{seed}``: an MP4 at 8 fps, a
    GIF, and one ``{idx}.png`` per generated frame.

    Args:
        data_path: Path of the input image.
        name: Name used for the output directory and output file prefix.
        seed: Integer seed passed to ``torch.manual_seed``.
        bg: Background mode forwarded to ``load_image``.
        is_pad: If true, run at 1024x576 with the image squashed to a square
            and padded sideways, then crop the saved PNG frames back to the
            central square; otherwise run at 512x512.
    """
    # Reuse one pipeline across calls; reloading the weights for every seed
    # of a sweep is pure overhead.
    pipe = _get_svd_pipeline()

    if is_pad:
        height, width = 576, 1024
    else:
        height, width = 512, 512

    save_dir = f"data/{name}_svd/{seed}"
    os.makedirs(save_dir, exist_ok=True)

    image = load_image(data_path, width, height, bg)
    if is_pad:
        image = resize_image(image, output_size=(width, height))

    generator = torch.manual_seed(seed)
    frames = pipe(image, height=height, width=width, generator=generator).frames[0]

    export_to_video(frames, f"{save_dir}/{name}_generated.mp4", fps=8)
    export_to_gif(frames, f"{save_dir}/{name}_generated.gif")

    # Box that removes the side padding again; only used when is_pad is set.
    side_pad = (width - height) // 2
    crop_box = (side_pad, 0, width - side_pad, height)
    for idx, img in enumerate(frames):
        if is_pad:
            img = img.crop(crop_box)
        img.save(f"{save_dir}/{idx}.png")
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options; unknown options are ignored.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--bg", type=str, default='white')
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True. Accept an explicit value, or the bare flag meaning True.
    parser.add_argument("--is_pad", type=_str2bool, nargs="?", const=True, default=False)
    args, _extras = parser.parse_known_args(argv)
    return args


def main(argv=None):
    """Run generation for the requested seed, or for seeds 0-29 if none is given."""
    args = parse_args(argv)
    seeds = range(30) if args.seed is None else [args.seed]
    for seed in seeds:
        gen_vid(args.data_path, args.name, seed, args.bg, args.is_pad)


if __name__ == '__main__':
    main()