crrodrvi committed on
Commit ca3a9a3 · verified · 1 Parent(s): 8ca3114

Upload 2 files

Files changed (2)
  1. app.py +144 -0
  2. vox-adv-cpk.pth +3 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ import subprocess
+
+ import yaml
+ from tqdm import tqdm
+
+ import imageio
+ import numpy as np
+ from skimage.transform import resize
+ from skimage import img_as_ubyte
+ import torch
+ from sync_batchnorm import DataParallelWithCallback
+
+ from modules.generator import OcclusionAwareGenerator
+ from modules.keypoint_detector import KPDetector
+ from animate import normalize_kp
+
+
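+ # Build the generator and keypoint detector described by the YAML config,
+ # load the trained weights, and put both networks in eval mode.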
+ def load_checkpoints(config_path, checkpoint_path, cpu=False):
+
+     with open(config_path) as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+
+     generator = OcclusionAwareGenerator(
+         **config["model_params"]["generator_params"], **config["model_params"]["common_params"]
+     )
+     if not cpu:
+         generator.cuda()
+
+     kp_detector = KPDetector(**config["model_params"]["kp_detector_params"], **config["model_params"]["common_params"])
+     if not cpu:
+         kp_detector.cuda()
+
+     if cpu:
+         checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+     else:
+         checkpoint = torch.load(checkpoint_path)
+
+     generator.load_state_dict(checkpoint["generator"])
+     kp_detector.load_state_dict(checkpoint["kp_detector"])
+
+     if not cpu:
+         generator = DataParallelWithCallback(generator)
+         kp_detector = DataParallelWithCallback(kp_detector)
+
+     generator.eval()
+     kp_detector.eval()
+
+     return generator, kp_detector
+
+
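+ # Core animation loop: estimate keypoints for the source image and every
+ # driving frame, then let the generator warp the source frame by frame.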
+ def make_animation(
+     source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False
+ ):
+     with torch.no_grad():
+         predictions = []
+         source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+         if not cpu:
+             source = source.cuda()
+         driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
+         kp_source = kp_detector(source)
+         kp_driving_initial = kp_detector(driving[:, :, 0])
+
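+         # normalize_kp makes each frame's keypoints relative to the first
+         # driving frame, so only the motion (not the absolute pose) is
+         # transferred onto the source image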
+         for frame_idx in tqdm(range(driving.shape[2])):
+             driving_frame = driving[:, :, frame_idx]
+             if not cpu:
+                 driving_frame = driving_frame.cuda()
+             kp_driving = kp_detector(driving_frame)
+             kp_norm = normalize_kp(
+                 kp_source=kp_source,
+                 kp_driving=kp_driving,
+                 kp_driving_initial=kp_driving_initial,
+                 use_relative_movement=relative,
+                 use_relative_jacobian=relative,
+                 adapt_movement_scale=adapt_movement_scale,
+             )
+             out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
+
+             predictions.append(np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0])
+     return predictions
+
+
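+ # Gradio handler: trim the upload to 8 seconds, read and resize the frames,
+ # run the animation on CPU, then re-attach the original audio and stack the
+ # driving and generated clips side by side with ffmpeg.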
+ def inference(video, image):
+     # trim the driving video to its first 8 seconds; passing the arguments as a
+     # list avoids shell word-splitting problems with paths that contain spaces
+     subprocess.run(["ffmpeg", "-y", "-ss", "00:00:00", "-i", video, "-to", "00:00:08", "-c", "copy", "video_input.mp4"])
+     video = "video_input.mp4"
+
+     source_image = imageio.imread(image)
+     reader = imageio.get_reader(video)
+     fps = reader.get_meta_data()["fps"]
+     driving_video = []
+     try:
+         for im in reader:
+             driving_video.append(im)
+     except RuntimeError:
+         pass
+     reader.close()
+
+     # the model expects 256x256 RGB inputs
+     source_image = resize(source_image, (256, 256))[..., :3]
+     driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
+
+     predictions = make_animation(
+         source_image,
+         driving_video,
+         generator,
+         kp_detector,
+         relative=True,
+         adapt_movement_scale=True,
+         cpu=True,
+     )
+     imageio.mimsave("result.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)
+     imageio.mimsave("driving.mp4", [img_as_ubyte(frame) for frame in driving_video], fps=fps)
+     # copy the audio track of the trimmed input onto the generated video
+     subprocess.run(["ffmpeg", "-y", "-i", "result.mp4", "-i", video, "-c", "copy", "-map", "0:0", "-map", "1:1", "-shortest", "out.mp4"])
+     # show the driving video and the generated video side by side
+     subprocess.run(["ffmpeg", "-y", "-i", "driving.mp4", "-i", "out.mp4", "-filter_complex", "hstack=inputs=2", "final.mp4"])
+     return "final.mp4"
+
+
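+ # Demo metadata and one-time model loading (each request reuses the models).
+ # Note: this commit uploads vox-adv-cpk.pth at the repo root, while the call
+ # below expects weights/vox-adv-cpk.pth.tar; adjust one of them if loading fails.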
+ title = "First Order Motion Model"
+ description = "Gradio demo for First Order Motion Model. Read more at the links below. Upload a video file (cropped to the face) and a facial image, and have fun :D. Please note that your video will be trimmed to the first 8 seconds."
+ article = "<p style='text-align: center'><a href='https://papers.nips.cc/paper/2019/file/31c0b36aef265d9221af80872ceb62f9-Paper.pdf' target='_blank'>First Order Motion Model for Image Animation</a> | <a href='https://github.com/AliaksandrSiarohin/first-order-model' target='_blank'>Github Repo</a></p>"
+ examples = [["bella_porch.mp4", "julien.png"]]
+ generator, kp_detector = load_checkpoints(
+     config_path="config/vox-256.yaml",
+     checkpoint_path="weights/vox-adv-cpk.pth.tar",
+     cpu=True,
+ )
+
+ iface = gr.Interface(
+     inference,
+     [
+         gr.inputs.Video(type="mp4"),
+         gr.inputs.Image(type="filepath"),
+     ],
+     outputs=gr.outputs.Video(label="Output Video"),
+     examples=examples,
+     enable_queue=True,
+     title=title,
+     article=article,
+     description=description,
+ )
+ iface.launch(debug=True)
vox-adv-cpk.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6792d6810d7f46e3c5c487a1cfec916b96fad8912c3c6cc81baa1fc300c820d3
+ size 750926934