import os import time from datetime import datetime, timezone, timedelta from concurrent.futures import ThreadPoolExecutor import spaces import torch import torchvision.models as models import numpy as np import gradio as gr from gradio_imageslider import ImageSlider from utils import preprocess_img, preprocess_img_from_path, postprocess_img from vgg.vgg19 import VGG_19 from inference import inference if torch.cuda.is_available(): device = 'cuda' elif torch.backends.mps.is_available(): device = 'mps' else: device = 'cpu' print('DEVICE:', device) if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name()) model = VGG_19().to(device).eval() for param in model.parameters(): param.requires_grad = False segmentation_model = models.segmentation.deeplabv3_resnet101( weights='DEFAULT' ).to(device).eval() style_files = os.listdir('./style_images') style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files} lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist() img_size = 512 cached_style_features = {} for style_name, style_img_path in style_options.items(): style_img = preprocess_img_from_path(style_img_path, img_size)[0].to(device) with torch.no_grad(): style_features = model(style_img) cached_style_features[style_name] = style_features @spaces.GPU(duration=12) def run(content_image, style_name, style_strength=5): yield [None] * 3 content_img, original_size = preprocess_img(content_image, img_size) content_img = content_img.to(device) print('-'*15) print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est print('STYLE:', style_name) print('CONTENT IMG SIZE:', original_size) print('STYLE STRENGTH:', style_strength, f'(lr={lrs[style_strength-1]:.3f})') style_features = cached_style_features[style_name] st = time.time() if device == 'cuda': stream_all = torch.cuda.Stream() stream_bg = torch.cuda.Stream() def run_inference_cuda(apply_to_background, stream): with torch.cuda.stream(stream): return run_inference(apply_to_background) def run_inference(apply_to_background): return inference( model=model, segmentation_model=segmentation_model, content_image=content_img, style_features=style_features, lr=lrs[style_strength-1], apply_to_background=apply_to_background ) with ThreadPoolExecutor() as executor: if device == 'cuda': future_all = executor.submit(run_inference_cuda, False, stream_all) future_bg = executor.submit(run_inference_cuda, True, stream_bg) else: future_all = executor.submit(run_inference, False) future_bg = executor.submit(run_inference, True) generated_img_all, _ = future_all.result() generated_img_bg, bg_ratio = future_bg.result() et = time.time() print('TIME TAKEN:', et-st) yield ( (content_image, postprocess_img(generated_img_all, original_size)), (content_image, postprocess_img(generated_img_bg, original_size)), f'{bg_ratio:.2f}' ) def set_slider(value): return gr.update(value=value) css = """ #container { margin: 0 auto; max-width: 1100px; } """ with gr.Blocks(css=css) as demo: gr.HTML("