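# Gradio demo: monocular depth estimation with MiDaS (loaded via torch.hub),
# plus a placeholder "wigglegram" step that reuses the predicted depth map.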
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

def calculate_depth(model_type, img):
    """Estimate a depth map for the uploaded image with the selected MiDaS model."""

    os.makedirs("temp", exist_ok=True)

    # Save the uploaded PIL image to disk; convert to RGB first so that
    # RGBA or grayscale uploads can still be written as JPEG.
    filename = "temp/image.jpg"
    img.convert("RGB").save(filename, "JPEG")

    # Fetch the requested MiDaS model from torch.hub (cached after the first download).
    midas = torch.hub.load("intel-isl/MiDaS", model_type)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    midas.to(device)
    midas.eval()

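    # Load the matching input transforms from the MiDaS hub repo: the DPT models
    # use dpt_transform, while MiDaS_small uses small_transform.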
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

    if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform

    # Re-read the saved image with OpenCV and convert BGR -> RGB, since the
    # transforms expect an RGB numpy array.
    img = cv2.imread(filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    input_batch = transform(img).to(device)

    with torch.no_grad():
        prediction = midas(input_batch)

        # Upsample the prediction back to the original image resolution.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    output = prediction.cpu().numpy()

    # Scale the depth values to 0-255 and save them as an 8-bit grayscale JPEG.
    formatted = (output * 255 / np.max(output)).astype("uint8")
    out_im = Image.fromarray(formatted)
    out_im.save("temp/image_depth.jpeg", "JPEG")

    return "temp/image_depth.jpeg"

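# Placeholder wigglegram step: the slider value is currently ignored and the
# previously saved depth map is simply returned for both outputs.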
def wiggle_effect(slider):
    return ["temp/image_depth.jpeg", "temp/image_depth.jpeg"]

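# UI layout: a model-type dropdown and an input image feed the "Calculate depth"
# button; a step-cycle slider and two image outputs feed "Generate Wigglegram".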
with gr.Blocks() as demo:
    gr.Markdown("Upload an image and click **Calculate depth** to estimate a depth map with MiDaS.")

    midas_models = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]

    inp = [gr.Dropdown(choices=midas_models, value="MiDaS_small", label="Depth estimation model type")]
    with gr.Row():
        inp.append(gr.Image(type="pil", label="Input"))
        out = gr.Image(type="filepath", label="depth_estimation")
    btn = gr.Button("Calculate depth")
    btn.click(fn=calculate_depth, inputs=inp, outputs=out)

    inp = [gr.Slider(1, 15, value=2, step=1, label="StepCycles")]
    with gr.Row():
        out = [
            gr.Image(type="filepath", label="Output_images"),
            gr.Image(type="filepath", label="Output_wiggle"),
        ]
    btn = gr.Button("Generate Wigglegram")
    btn.click(fn=wiggle_effect, inputs=inp, outputs=out)

demo.launch()