"""Gradio web demo that colorizes photos with DeOldify running on the CPU."""

import os
from functools import lru_cache

import gradio as gr
import PIL
import PIL.Image  # make sure the ``PIL.Image`` submodule used below is loaded
import cv2
import numpy as np

from src.deoldify import device
from src.deoldify.device_id import DeviceId
from src.deoldify.visualize import *
from src.app_utils import get_model_bin

device.set(device=DeviceId.CPU)

# Lower-cased option name -> (weights URL, weights file name, ``artistic`` flag
# handed to ``get_image_colorizer``).
_MODEL_SPECS = {
    "artistic": (
        "https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth",
        "ColorizeArtistic_gen.pth",
        True,
    ),
    "stable": (
        # NOTE(review): a Dropbox "?dl=0" link may serve an HTML preview page
        # to some HTTP clients -- confirm get_model_bin fetches the raw file.
        "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0",
        "ColorizeStable_gen.pth",
        False,
    ),
}


@lru_cache(maxsize=None)
def _build_colorizer(model_dir, model_key):
    """Fetch the weights for *model_key* and build its colorizer (cached).

    Cached per ``(model_dir, model_key)`` so the model is not rebuilt on every
    request. A call that raises is not cached, so a failed fetch is retried.
    """
    model_url, file_name, artistic = _MODEL_SPECS[model_key]
    # Presumably downloads the file when it is missing -- see src.app_utils.
    get_model_bin(model_url, os.path.join(model_dir, file_name))
    return get_image_colorizer(artistic=artistic)


def load_model(model_dir, option):
    """Return the DeOldify colorizer for *option*.

    Args:
        model_dir: Directory the weights file path is built under.
        option: ``"artistic"`` or ``"stable"`` (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer``.

    Raises:
        ValueError: If *option* is not one of the supported model names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    model_key = option.lower() if isinstance(option, str) else None
    if model_key not in _MODEL_SPECS:
        supported = ", ".join(sorted(_MODEL_SPECS))
        raise ValueError(
            f"Unknown model option {option!r}; expected one of: {supported}"
        )
    return _build_colorizer(model_dir, model_key)


def resize_img(input_img, max_size):
    """Downscale *input_img* so its longest side is at most *max_size*.

    The aspect ratio is preserved. Images that already fit are returned as an
    unmodified copy; images are never upscaled.

    Args:
        input_img: Image array whose first two axes are (height, width).
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        A new array; the input is never modified.
    """
    img_height, img_width = input_img.shape[0], input_img.shape[1]
    longest_side = max(img_height, img_width)
    if longest_side <= max_size:
        return input_img.copy()

    scale = max_size / longest_side
    # Pin the long side to max_size exactly (avoids float round-off) and scale
    # the short side; clamp to 1 px so extreme aspect ratios stay valid.
    if img_height > img_width:
        new_height = int(max_size)
        new_width = max(1, int(img_width * scale))
    else:
        new_width = int(max_size)
        new_height = max(1, int(img_height * scale))

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(input_img, (new_width, new_height))


def colorize_image(input_image, colorizer, img_size=800, render_factor=35):
    """Colorize a PIL image with *colorizer*.

    Args:
        input_image: PIL image to colorize; it is converted to RGB first.
        colorizer: Object exposing ``plot_transformed_pil_image``.
        img_size: Longest side, in pixels, the image is reduced to beforehand.
        render_factor: Forwarded to ``plot_transformed_pil_image``.

    Returns:
        Whatever ``colorizer.plot_transformed_pil_image`` returns.
    """
    rgb_array = np.array(input_image.convert("RGB"))
    resized_pil_img = PIL.Image.fromarray(resize_img(rgb_array, img_size))
    return colorizer.plot_transformed_pil_image(
        resized_pil_img, render_factor=render_factor, compare=False
    )


def app(input_image, model='stable'):
    """Gradio handler: colorize *input_image* with the selected *model*.

    Args:
        input_image: PIL image from the input component, or ``None`` when the
            user submitted without uploading one.
        model: Model name from the dropdown; ``None`` (nothing selected) falls
            back to ``'stable'``.

    Returns:
        The colorized image.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # The dropdown has no preselected value, so it can hand us None.
    colorizer = load_model('models/', model or 'stable')
    return colorize_image(input_image, colorizer)


title = "Aiconvert.online"

demo = gr.Interface(
    app,
    inputs=[
        gr.Image(type="pil", label="Input"),
        gr.Dropdown(["Artistic", "Stable"], label="Model"),
    ],
    outputs=gr.Image(type="pil", label="Output", show_share_button=False),
    title=title,
    css="footer{display:none !important;}",
    theme=gr.themes.Base(),
    enable_queue=True,
    allow_flagging=False,
)
demo.launch()