# abc / app.py — Hugging Face Space by CHEN11102
# Last commit: "Update app.py" (38d4a7e, verified); file size 2.57 kB
import os
import subprocess
import sys
import tempfile

import gradio as gr
import mediapy
import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download
from PIL import Image
# Clone the FILM repository (only if it is not already present) and make its
# `eval` package importable. Paths are relative to the current working dir,
# as before.
_FILM_REPO_URL = "https://github.com/google-research/frame-interpolation"
_FILM_REPO_DIR = "frame-interpolation"
if not os.path.isdir(_FILM_REPO_DIR):
    # Argument list instead of a shell string; check=True surfaces a failed
    # clone right here instead of as a confusing ImportError further down.
    subprocess.run(["git", "clone", _FILM_REPO_URL, _FILM_REPO_DIR], check=True)
sys.path.append(_FILM_REPO_DIR)
# Import after appending the path
from eval import interpolator, util
def load_model(model_name):
    """Fetch a model repo from the Hugging Face Hub and wrap it in a FILM Interpolator.

    Args:
        model_name: Hub repo id, e.g. "akhaliq/frame-interpolation-film-style".

    Returns:
        An ``interpolator.Interpolator`` built from the downloaded snapshot.
    """
    local_snapshot = snapshot_download(repo_id=model_name)
    # Second positional argument is passed as None, as in the original call
    # (presumably FILM's `align` option — TODO confirm against eval/interpolator.py).
    return interpolator.Interpolator(local_snapshot, None)
# Hugging Face Hub repo ids offered in the "Model" dropdown.
model_names = [
    "akhaliq/frame-interpolation-film-style",
    "CHEN11102/sportmodel",
    "CHEN11102/sportModel2",
]
# Every listed model is downloaded and instantiated once, at import time, and
# kept in memory keyed by its repo id.
models = dict(zip(model_names, map(load_model, model_names)))
# Tell mediapy which ffmpeg binary to use for writing videos.
ffmpeg_path = util.get_ffmpeg_path()
mediapy.set_ffmpeg(ffmpeg_path)
def resize(width, img):
    """Scale an image to ``width`` pixels wide, preserving its aspect ratio.

    Args:
        width: Target width in pixels.
        img: The image as an array accepted by ``Image.fromarray`` (what
            ``gr.Image`` delivers by default), or an already-open
            ``PIL.Image.Image``.

    Returns:
        A new ``PIL.Image.Image`` of size ``(width, height)``, resampled with
        LANCZOS. The height is scaled by the same factor as the width and
        truncated to an int.
    """
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    scale = width / float(img.size[0])
    # Truncate as before, but never to 0: a zero height is not a valid resize
    # target, and very wide inputs (aspect ratio > width:1) would produce one.
    height = max(1, int(img.size[1] * scale))
    return img.resize((width, height), Image.LANCZOS)
def resize_and_crop(img_path, size, crop_origin="middle"):
    """Open the image at ``img_path`` and resize it to exactly ``size``.

    Despite the name, no cropping is performed. The image is stretched to
    ``size`` = (width, height) with LANCZOS resampling, so its aspect ratio is
    not preserved. ``crop_origin`` is accepted for call compatibility but is
    currently ignored.

    Args:
        img_path: Path of the image file to read.
        size: Target (width, height) in pixels.
        crop_origin: Unused.

    Returns:
        A new ``PIL.Image.Image`` of the requested size.
    """
    # resize() loads the pixels and returns an independent image, so the source
    # file can be closed immediately instead of leaking its handle until GC.
    with Image.open(img_path) as source:
        return source.resize(size, Image.LANCZOS)
def resize_img(img1, img2_path, out_path="resized_img2.png"):
    """Resize the image at ``img2_path`` to the pixel size of ``img1`` and save it.

    Args:
        img1: Path of the reference image whose (width, height) is matched.
        img2_path: Path of the image to resize (stretched, see resize_and_crop).
        out_path: Where to write the result. Defaults to the historical fixed
            name ``resized_img2.png`` in the current working directory.

    Returns:
        ``out_path``. The original returned None, so callers that ignore the
        return value are unaffected.
    """
    # Only the header is needed for .size. The context manager closes the
    # handle that was previously left open.
    with Image.open(img1) as reference:
        target_size = reference.size
    resized = resize_and_crop(img2_path, target_size, crop_origin="middle")
    resized.save(out_path)
    return out_path
def predict(frame1, frame2, times_to_interpolate, model_name):
    """Interpolate between two frames with a FILM model and render an MP4.

    Both frames are scaled to 1080 px wide. The second frame is then stretched
    to the first frame's exact size so both inputs match.

    Args:
        frame1: First frame as delivered by ``gr.Image`` (an array with the
            default settings), or None if nothing was uploaded.
        frame2: Second frame, same format as ``frame1``.
        times_to_interpolate: Recursion depth passed to
            ``util.interpolate_recursively_from_files``. Must be a
            non-negative whole number; ``gr.Number`` may deliver it as a float.
        model_name: Key into ``models`` (a Hub repo id from the dropdown).

    Returns:
        Path of the written 30 fps MP4. It lives in a fresh temp directory so
        it outlives this call and Gradio can serve it.

    Raises:
        gr.Error: on missing or invalid inputs; Gradio shows the message to the user.
    """
    if frame1 is None or frame2 is None:
        raise gr.Error("Please provide both frames.")
    if model_name not in models:
        raise gr.Error(f"Unknown model: {model_name!r}")
    try:
        depth = int(times_to_interpolate)
    except (TypeError, ValueError, OverflowError):
        raise gr.Error("Times to Interpolate must be a whole number.") from None
    # Fractional or negative depths presumably never reach the recursion's
    # base case in FILM's helper (TODO confirm in eval/util.py), so reject them.
    if depth != times_to_interpolate or depth < 0:
        raise gr.Error("Times to Interpolate must be a non-negative whole number.")

    model = models[model_name]
    first = resize(1080, frame1)
    # Same LANCZOS stretch to the first frame's size that the old
    # save/reopen/resize_img round-trip performed, done in memory.
    second = resize(1080, frame2).resize(first.size, Image.LANCZOS)

    # Per-request scratch directory: the previous fixed filenames in the CWD
    # could be overwritten by concurrent requests.
    with tempfile.TemporaryDirectory() as scratch:
        input_frames = [
            os.path.join(scratch, "frame1.png"),
            os.path.join(scratch, "frame2.png"),
        ]
        first.save(input_frames[0])
        second.save(input_frames[1])
        # Materialize while the input files still exist.
        frames = list(
            util.interpolate_recursively_from_files(input_frames, depth, model))

    # mkdtemp is not auto-deleted, so the video survives until Gradio serves it.
    out_path = os.path.join(tempfile.mkdtemp(prefix="film_"), "out.mp4")
    mediapy.write_video(out_path, frames, fps=30)
    return out_path
# Text shown in the Gradio UI.
title = "Sports model"
description = "Wechat:Liesle1"
article = ""
# Each example row: [first frame, second frame, times to interpolate, model].
examples = [
    ['cat3.jpeg', 'cat4.jpeg', 2, model_names[0]],
    ['cat1.jpeg', 'cat2.jpeg', 2, model_names[1]],
]
# Bound to a name so tools that look for a module-level `demo` (e.g. Gradio's
# reload mode) can find the app.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="First Frame"),
        gr.Image(label="Second Frame"),
        # precision=0 makes Gradio deliver an int rather than a float.
        gr.Number(label="Times to Interpolate", value=2, precision=0),
        # Preselect a model so predict() never receives model_name=None.
        gr.Dropdown(label="Model", choices=model_names, value=model_names[0]),
    ],
    outputs=gr.Video(label="Interpolated Frames"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()