# run_on_video_with_window.py (forked from open-mmlab/mmagic)
import argparse
import os
import os.path as osp

import cv2
import mmcv
import numpy as np
import torch
from tqdm import tqdm

from mmedit.apis import init_model
from mmedit.core import tensor2img
from mmedit.datasets.pipelines import Compose
from mmedit.utils import modify_args
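
# Note: these imports target the mmedit (MMEditing 0.x) API. The newer
# mmagic 1.x package this repo was forked from uses a different module
# layout, so this script is assumed to run against the 0.x install.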


def parse_args():
    modify_args()
    parser = argparse.ArgumentParser(description='Restoration demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('input_dir', help='path to the input video file')
    parser.add_argument('output_dir', help='directory for the output frames')
    parser.add_argument(
        '--max-seq-len',
        type=int,
        default=10,
        help='maximum sequence length fed to the model per forward pass')
    parser.add_argument(
        '--device',
        type=int,
        default=0,
        help='CUDA device id; a negative value falls back to CPU')
    args = parser.parse_args()
    return args
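

# Example invocation (a sketch; the config and checkpoint paths below are
# placeholders, not files verified to ship with this repo):
#   python run_on_video_with_window.py \
#       configs/restorers/basicvsr/basicvsr_reds4.py \
#       checkpoints/basicvsr_reds4.pth \
#       input.mp4 out_frames/ --max-seq-len 10 --device 0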


def main():
    """Demo for video restoration models.

    The input is a video file; the restored frames are written to
    'output_dir' as individual images and then re-encoded into a video with
    ffmpeg. Keeping the intermediate frames as images avoids an extra round
    of lossy video compression before the final encode.
    """
    args = parse_args()

    if args.device < 0 or not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', args.device)

    model = init_model(args.config, args.checkpoint, device=device)
    device = next(model.parameters()).device
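
    # Pick the inference pipeline: prefer an explicit demo_pipeline, then the
    # test pipeline, and fall back to the validation pipeline.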
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # Drop the steps that generate frame indices and load images from disk;
    # frames are decoded from the video and passed in memory instead.
    tmp_pipeline = []
    for pipeline in test_pipeline:
        if pipeline['type'] not in [
                'GenerateSegmentIndices', 'LoadImageFromFileList'
        ]:
            tmp_pipeline.append(pipeline)
    test_pipeline = Compose(tmp_pipeline)

    video_reader = mmcv.VideoReader(args.input_dir)
    frame_count = video_reader.frame_cnt
    fps = video_reader.fps

    # Alternative kept from the original script: stream frames straight into
    # a cv2.VideoWriter instead of writing intermediate images.
    # fourcc = cv2.VideoWriter_fourcc('i', 'Y', 'U', 'V')
    # video_writer = cv2.VideoWriter(
    #     args.output_dir, fourcc, video_reader.fps,
    #     (video_reader.width * 4, video_reader.height * 4))

    os.makedirs(args.output_dir, exist_ok=True)
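
    # The video is processed in non-overlapping windows of max_seq_len
    # frames; for recurrent models this resets temporal propagation at each
    # window boundary, trading some quality near the seams for bounded memory.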
    with torch.no_grad():
        for i in tqdm(range(0, frame_count, args.max_seq_len)):
            data = dict(lq=[], lq_path=None, key='')
            frames = video_reader[i:i + args.max_seq_len]
            # Convert each frame from BGR (OpenCV order) to RGB.
            data['lq'] = [np.flip(frame, axis=2) for frame in frames]
            data = test_pipeline(data)
            data = data['lq'].unsqueeze(0)
            try:
                result = model(
                    lq=data.to(device), test_mode=True)['output'].cpu()[0]
                for k, frame in enumerate(result):
                    output = tensor2img(frame)
                    # video_writer.write(output.astype(np.uint8))
                    # Write each frame with an 8-digit zero-padded index so
                    # ffmpeg can pick the files up in order.
                    cv2.imwrite(
                        osp.join(args.output_dir, f'{i + k:08d}.jpg'), output)
            except Exception as e:
                print(f'Error in window starting at frame {i}: {e}')
                continue

    # Re-encode the restored frames into an H.264 video at the original fps.
    os.system(
        f"ffmpeg -y -r {fps} -i {osp.join(args.output_dir, '%08d.jpg')} "
        f"-c:v libx264 -pix_fmt yuv420p -r {fps} {args.input_dir}.out.mp4")


if __name__ == '__main__':
    main()