|
import torch |
|
from diffusers.image_processor import VaeImageProcessor |
|
from torch.nn import functional as F |
|
import cv2 |
|
from util.utils import * |
|
from util.rife.pytorch_msssim import ssim_matlab |
|
import numpy as np |
|
import logging |
|
import skvideo.io |
|
from util.rife.RIFE_HDv3 import Model |
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Default compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
def pad_image(img, scale):
    """Pad ``img`` on the right and bottom up to a multiple of the stride.

    The stride is ``max(32, int(32 / scale))``; height and width are each
    rounded up to the next multiple of it. The original content stays in the
    top-left corner of the result.

    Args:
        img: 4-D tensor; the last two dimensions are treated as height and
            width.
        scale: Non-zero scale factor used to derive the stride.

    Returns:
        The padded tensor (``F.pad`` with its default padding mode).
    """
    _, _, h, w = img.shape
    stride = max(32, int(32 / scale))
    padded_h = ((h - 1) // stride + 1) * stride
    padded_w = ((w - 1) // stride + 1) * stride
    # F.pad reads pads starting from the LAST dimension, i.e.
    # (left, right, top, bottom). The width pad therefore goes in slot 1 and
    # the height pad in slot 3; putting both in slots 2/3 would pad only the
    # height and leave the width unaligned.
    padding = (0, padded_w - w, 0, padded_h - h)
    return F.pad(img, padding)
|
|
|
|
|
def make_inference(model, I0, I1, upscale_amount, n):
    """Recursively build ``n`` in-between frames for the pair ``(I0, I1)``.

    The midpoint is obtained from ``model.inference(I0, I1, upscale_amount)``;
    the two halves ``(I0, middle)`` and ``(middle, I1)`` are then subdivided
    with ``n // 2`` frames each. The midpoint itself is kept only when ``n``
    is odd, so the returned list always has exactly ``n`` entries for the
    ``2**k - 1`` counts used by this module.

    Args:
        model: Object exposing ``inference(frame_a, frame_b, upscale_amount)``.
        I0: First frame of the pair.
        I1: Second frame of the pair.
        upscale_amount: Forwarded unchanged to ``model.inference``.
        n: Number of in-between frames requested.

    Returns:
        List of frames ordered from ``I0`` towards ``I1``; empty when
        ``n <= 0``.
    """
    # The recursion only bottoms out at n == 1. Without this guard a
    # non-positive n keeps recursing (and calling the model) until Python
    # raises RecursionError.
    if n <= 0:
        return []

    middle = model.inference(I0, I1, upscale_amount)
    if n == 1:
        return [middle]

    half = n // 2
    first_half = make_inference(model, I0, middle, upscale_amount, n=half)
    second_half = make_inference(model, middle, I1, upscale_amount, n=half)
    if n % 2:
        return [*first_half, middle, *second_half]
    return [*first_half, *second_half]
|
|
|
|
|
@torch.inference_mode()
def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
    """Interleave the frames of ``samples`` with model-generated in-between frames.

    Each frame ``samples[b]`` is paired with the following frame, or with the
    last frame when no following frame exists. A score is computed by
    ``ssim_matlab`` on 32x32 bilinear thumbnails of the first three channels
    of both frames, and selects one of three paths:

    * score > 0.996: the partner frame is replaced by a single
      ``make_inference`` result, and that result is also emitted in place of
      the source frame;
    * score < 0.2: the float version of the source frame is repeated
      ``2**exp - 1`` times instead of calling the model;
    * otherwise: ``make_inference`` produces ``2**exp - 1`` frames (none when
      ``exp`` is falsy).

    Every emitted tensor is resized to the source frame's height and width
    with ``F.interpolate`` and moved to ``output_device``. The dtype and shape
    of ``samples`` are printed to stdout on entry.

    Args:
        model: Interpolation model, passed through to ``make_inference``.
        samples: 4-D tensor indexed by frame along dimension 0.
        exp: Exponent controlling the number of in-between frames per pair.
        upscale_amount: Forwarded unchanged to ``make_inference``.
        output_device: Device the emitted tensors are moved to.

    Returns:
        A flat list of tensors in playback order: each (possibly replaced)
        source frame followed by its in-between frames.
    """
    print(f"samples dtype:{samples.dtype}")
    print(f"samples shape:{samples.shape}")

    total = samples.shape[0]
    output = []
    pbar = ProgressBar(total, desc="frame interpolating")

    for idx in range(total):
        emitted = samples[idx : idx + 1]
        _, _, height, width = emitted.shape

        # Partner frame: the next one, or the final frame near the end.
        if idx + 2 < total:
            partner = samples[idx + 1 : idx + 2]
        else:
            partner = samples[-1:]

        anchor = emitted.to(torch.float)
        partner = partner.to(torch.float)

        # The score is computed on small thumbnails of the first 3 channels.
        anchor_small = F.interpolate(anchor, (32, 32), mode="bilinear", align_corners=False)
        partner_small = F.interpolate(partner, (32, 32), mode="bilinear", align_corners=False)
        score = ssim_matlab(anchor_small[:, :3], partner_small[:, :3])

        if score > 0.996:
            # Swap the partner for one interpolated frame and emit that frame
            # instead of the untouched source frame.
            partner = make_inference(model, anchor, partner, upscale_amount, 1)[0]
            emitted = partner

        if score < 0.2:
            # Skip the model and repeat the current frame.
            between = [anchor for _ in range((2**exp) - 1)]
        elif exp:
            between = make_inference(model, anchor, partner, upscale_amount, 2**exp - 1)
        else:
            between = []

        # Source frame first, then its in-between frames, all resized to the
        # source frame's spatial size.
        for piece in (emitted, *between):
            output.append(F.interpolate(piece, size=(height, width)).to(output_device))

        pbar.update(1)

    return output
|
|
|
|
|
def load_rife_model(model_path):
    """Instantiate ``Model``, load it from ``model_path`` and call ``eval()`` on it.

    Args:
        model_path: Location handed to ``Model.load_model``.

    Returns:
        The prepared ``Model`` instance.
    """
    rife = Model()
    # The second argument (-1) is forwarded as-is to Model.load_model.
    # NOTE(review): presumably a rank/device selector — confirm in RIFE_HDv3.
    rife.load_model(model_path, -1)
    rife.eval()
    return rife
|
|
|
|
|
|
|
def frame_generator(video_capture):
    """Yield frames from ``video_capture`` until a read fails, then release it.

    Args:
        video_capture: Object exposing ``read() -> (ok, frame)`` and
            ``release()``.

    Yields:
        Each frame for which ``read()`` reported success, in order.
    """
    try:
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break
            yield frame
    finally:
        # Release in a finally block so the capture is also freed when the
        # consumer stops early (generator closed / garbage-collected) or when
        # read() raises, not only on normal exhaustion.
        video_capture.release()
|
|
|
|
|
def rife_inference_with_path(model, video_path):
    """Interpolate the frames of a video file and save the result as a new video.

    The frames are decoded with ``skvideo.io.vreader``, transposed with axes
    ``(2, 0, 1)``, scaled by ``1 / 255``, stacked, moved to the module-level
    ``device`` and passed to ``ssim_interpolation_rife``. The result is
    converted to PIL images via ``VaeImageProcessor`` and written with
    ``save_video(..., fps=16)``.

    Args:
        model: Interpolation model forwarded to ``ssim_interpolation_rife``.
        video_path: Path of the input video.

    Returns:
        Whatever ``save_video`` returns (assigned to ``video_path`` before
        being returned — presumably the output path; confirm in util.utils).

    Raises:
        ValueError: If no frame could be decoded from ``video_path``.
    """
    # The OpenCV capture is only used to read the frame count, so release it
    # immediately instead of leaking the handle for the rest of the call.
    video_capture = cv2.VideoCapture(video_path)
    try:
        tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        video_capture.release()

    # torch.from_numpy already yields CPU tensors, so no extra device move is
    # needed before the float conversion.
    pt_frame_data = [
        torch.from_numpy(np.transpose(frame, (2, 0, 1))).float() / 255.0
        for frame in skvideo.io.vreader(video_path)
    ]
    if not pt_frame_data:
        # Same exception type the previous np.stack([]) call produced, but
        # with a message that names the offending file.
        raise ValueError(f"no frames could be read from {video_path!r}")

    # Stack the tensors directly rather than round-tripping through NumPy.
    pt_frame = torch.stack(pt_frame_data).to(device)

    pbar = ProgressBar(tot_frame, desc="RIFE inference")
    frames = ssim_interpolation_rife(model, pt_frame)
    # Each returned frame carries a leading size-1 dimension; drop it before
    # stacking into a single batch.
    pt_image = torch.stack([frame.squeeze(0) for frame in frames])
    image_np = VaeImageProcessor.pt_to_numpy(pt_image)
    image_pil = VaeImageProcessor.numpy_to_pil(image_np)
    video_path = save_video(image_pil, fps=16)
    if pbar:
        pbar.update(1)
    return video_path
|
|
|
|
|
def rife_inference_with_latents(model, latents):
    """Apply ``ssim_interpolation_rife`` to every entry along dimension 0 of ``latents``.

    ``latents`` is moved to the module-level ``device``; each ``latents[i]``
    is interpolated independently, the resulting frames are squeezed on
    dimension 0 and stacked, and the per-entry results are stacked again.

    Args:
        model: Interpolation model forwarded to ``ssim_interpolation_rife``.
        latents: Tensor whose first dimension enumerates the entries to
            process (assumed to be a batch of frame sequences — TODO confirm
            against callers).

    Returns:
        A tensor stacking the interpolated result of every entry.
    """
    latents = latents.to(device)
    interpolated = []
    for batch_idx in range(latents.size(0)):
        frames = ssim_interpolation_rife(model, latents[batch_idx])
        # Drop each frame's leading size-1 dimension before stacking.
        interpolated.append(torch.stack([frame.squeeze(0) for frame in frames]))
    return torch.stack(interpolated)
|
|