Commit
•
ce91763
1
Parent(s):
7639b35
update
Browse files- src/utils/frame_interpolation.py +11 -4
- src/utils/util.py +9 -1
src/utils/frame_interpolation.py
CHANGED
@@ -5,6 +5,13 @@ import torch
|
|
5 |
import bisect
|
6 |
import shutil
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
def init_frame_interpolation_model():
|
9 |
print("Initializing frame interpolation model")
|
10 |
checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
|
@@ -12,7 +19,7 @@ def init_frame_interpolation_model():
|
|
12 |
model = torch.load(checkpoint_name, map_location='cpu')
|
13 |
model.eval()
|
14 |
model = model.half()
|
15 |
-
model = model.to(device=
|
16 |
return model
|
17 |
|
18 |
|
@@ -54,8 +61,8 @@ def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
|
|
54 |
|
55 |
x0 = x0.half()
|
56 |
x1 = x1.half()
|
57 |
-
x0 = x0.
|
58 |
-
x1 = x1.
|
59 |
|
60 |
dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
|
61 |
|
@@ -87,4 +94,4 @@ def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
|
|
87 |
|
88 |
shutil.rmtree(image_save_dir)
|
89 |
|
90 |
-
return video_save_dir
|
|
|
5 |
import bisect
|
6 |
import shutil
|
7 |
|
8 |
+
if torch.backends.mps.is_available():
|
9 |
+
device = "mps"
|
10 |
+
elif torch.cuda.is_available():
|
11 |
+
device = "cuda"
|
12 |
+
else:
|
13 |
+
device = "cpu"
|
14 |
+
|
15 |
def init_frame_interpolation_model():
|
16 |
print("Initializing frame interpolation model")
|
17 |
checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
|
|
|
19 |
model = torch.load(checkpoint_name, map_location='cpu')
|
20 |
model.eval()
|
21 |
model = model.half()
|
22 |
+
model = model.to(device=device)
|
23 |
return model
|
24 |
|
25 |
|
|
|
61 |
|
62 |
x0 = x0.half()
|
63 |
x1 = x1.half()
|
64 |
+
x0 = x0.to(device)
|
65 |
+
x1 = x1.to(device)
|
66 |
|
67 |
dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
|
68 |
|
|
|
94 |
|
95 |
shutil.rmtree(image_save_dir)
|
96 |
|
97 |
+
return video_save_dir
|
src/utils/util.py
CHANGED
@@ -12,6 +12,13 @@ import torchvision
|
|
12 |
from einops import rearrange
|
13 |
from PIL import Image
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def seed_everything(seed):
|
17 |
import random
|
@@ -19,7 +26,8 @@ def seed_everything(seed):
|
|
19 |
import numpy as np
|
20 |
|
21 |
torch.manual_seed(seed)
|
22 |
-
|
|
|
23 |
np.random.seed(seed % (2**32))
|
24 |
random.seed(seed)
|
25 |
|
|
|
12 |
from einops import rearrange
|
13 |
from PIL import Image
|
14 |
|
15 |
+
if torch.backends.mps.is_available():
|
16 |
+
device = "mps"
|
17 |
+
elif torch.cuda.is_available():
|
18 |
+
device = "cuda"
|
19 |
+
else:
|
20 |
+
device = "cpu"
|
21 |
+
|
22 |
|
23 |
def seed_everything(seed):
|
24 |
import random
|
|
|
26 |
import numpy as np
|
27 |
|
28 |
torch.manual_seed(seed)
|
29 |
+
if device == "cuda":
|
30 |
+
torch.cuda.manual_seed_all(seed)
|
31 |
np.random.seed(seed % (2**32))
|
32 |
random.seed(seed)
|
33 |
|