|
import os |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import bisect |
|
import shutil |
|
|
|
def init_frame_interpolation_model(checkpoint_path="./pretrained_model/film_net_fp16.pt", device="cuda"):
    """Load the frame-interpolation network in fp16 eval mode on *device*.

    Args:
        checkpoint_path: Path to the serialized model. Defaults to the bundled
            ``film_net_fp16.pt``. A TorchScript archive is expected; a full
            ``nn.Module`` pickled with ``torch.save`` is accepted as a fallback.
        device: Device the model is moved to (default ``"cuda"``).

    Returns:
        The loaded module, switched to eval mode, with half-precision weights
        on *device*.
    """
    print("Initializing frame interpolation model")

    try:
        # Load TorchScript directly: ``torch.load`` only forwards TorchScript
        # archives to ``torch.jit.load`` when ``weights_only=False``, and since
        # PyTorch 2.6 the default is True, which makes it raise instead.
        model = torch.jit.load(checkpoint_path, map_location="cpu")
    except RuntimeError:
        # Not a TorchScript archive: assume a pickled nn.Module. Unpickling can
        # execute arbitrary code, so only load checkpoints from a trusted source.
        model = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    model.eval()
    model = model.half()
    return model.to(device=device)
|
|
|
|
|
def _read_bgr_image(path):
    """Read *path* with OpenCV as an 8-bit BGR array, raising if it cannot be decoded."""
    img = cv2.imread(path)
    if img is None:
        raise ValueError(f"could not read image: {path!r}")
    return img


def _bgr_to_model_input(img, device):
    """Convert an HxWx3 uint8 BGR image to a 1x3xHxW fp16 RGB tensor in [0, 1] on *device*."""
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    return torch.from_numpy(rgb).unsqueeze(0).permute(0, 3, 1, 2).half().to(device)


def _model_output_to_bgr(tensor):
    """Convert a 1x3xHxW RGB tensor in [0, 1] to a contiguous HxWx3 uint8 BGR array."""
    # Round rather than truncate: ``.byte()`` alone floors, biasing every pixel
    # downwards (e.g. 254.99998 would become 254).
    rgb = (tensor[0].float() * 255).round().byte()
    # flip(0) reverses the channel axis (RGB -> BGR) for OpenCV.
    return rgb.flip(0).permute(1, 2, 0).contiguous().cpu().numpy()


def _interpolate_pair(model, img1, img2, inter_frames, device):
    """Return the ``inter_frames`` frames between ``img1`` and ``img2`` as BGR uint8 arrays.

    Intermediate timestamps are filled one model call at a time. Each step picks
    the remaining timestamp whose relative position inside a bracket of
    already-available frames is closest to 0.5. Earlier predictions are then
    reused as bracket endpoints for later ones.
    """
    results = [_bgr_to_model_input(img1, device), _bgr_to_model_input(img2, device)]
    idxes = [0, inter_frames + 1]  # timestamp index of each entry in ``results``
    remains = list(range(1, inter_frames + 1))  # timestamp indices still to predict
    splits = torch.linspace(0, 1, inter_frames + 2)  # normalised time of every index

    with torch.no_grad():
        while remains:
            starts = splits[idxes[:-1]]
            ends = splits[idxes[1:]]
            # distances[i, j]: how far remaining timestamp j lies from the middle of bracket i.
            distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
            start_i, step = np.unravel_index(torch.argmin(distances).item(), distances.shape)
            end_i = start_i + 1

            x0 = results[start_i].half()
            x1 = results[end_i].half()
            # Position of the target timestamp relative to the chosen bracket.
            dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
            prediction = model(x0, x1, dt)

            insert_position = bisect.bisect_left(idxes, remains[step])
            idxes.insert(insert_position, remains[step])
            results.insert(insert_position, prediction.clamp(0, 1))
            del remains[step]

    return [_model_output_to_bgr(t) for t in results[1:-1]]


def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1, device="cuda"):
    """Interpolate between consecutive images in a folder and encode the result as MP4.

    The directory entries of ``input_file`` are processed in filename-sorted
    order. Between every consecutive pair, ``inter_frames`` frames are
    synthesised with ``model``, so N images yield
    ``(N - 1) * (inter_frames + 1) + 1`` video frames.

    Args:
        input_file: Directory whose entries are all readable image files. The
            first image's size is used as the video size; the other images
            are assumed to match it (TODO confirm).
        model: Interpolation network called as ``model(x0, x1, dt)``. The inputs
            are fp16 ``1x3xHxW`` RGB tensors in [0, 1], and ``dt`` has shape
            ``(1, 1)``. It is presumably the module returned by
            ``init_frame_interpolation_model``.
        fps: Frame rate passed to the video writer.
        inter_frames: Number of frames inserted between each pair (must be >= 0).
        device: Device the model lives on; input tensors are moved there.

    Returns:
        Path of the written video: ``input_file`` with any trailing path
        separator removed, plus ``".mp4"``.

    Raises:
        ValueError: ``inter_frames`` is negative, the folder has fewer than two
            entries, or an entry cannot be decoded as an image.
        OSError: OpenCV could not open the output video for writing.
    """
    inter_frames = int(inter_frames)
    if inter_frames < 0:
        raise ValueError(f"inter_frames must be >= 0, got {inter_frames}")

    names = sorted(os.listdir(input_file))
    if len(names) < 2:
        raise ValueError(f"need at least two images in {input_file!r} to interpolate, found {len(names)}")

    # "frames/" must map to "frames.mp4", not "frames/.mp4".
    base_path = input_file.rstrip(os.sep + (os.altsep or "")) or input_file
    video_save_dir = base_path + '.mp4'

    prev = _read_bgr_image(os.path.join(input_file, names[0]))
    h, w = prev.shape[:2]
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    writer = cv2.VideoWriter(video_save_dir, fourcc, fps, (w, h))
    # Frames are streamed straight into the encoder (no temporary PNG folder),
    # so there is nothing to clean up if an error is raised midway.
    try:
        if not writer.isOpened():
            raise OSError(f"could not open video writer for {video_save_dir!r}")
        for name in names[1:]:
            cur = _read_bgr_image(os.path.join(input_file, name))
            # Each key frame is written once: here as the start of its pair,
            # or after the loop for the last image. Original pixels are used
            # unchanged, with no float round-trip.
            writer.write(prev)
            for frame in _interpolate_pair(model, prev, cur, inter_frames, device):
                writer.write(frame)
            prev = cur
        writer.write(prev)
    finally:
        writer.release()

    return video_save_dir