import os import time from datetime import datetime, timezone, timedelta from concurrent.futures import ThreadPoolExecutor import spaces import torch import torchvision.models as models import numpy as np import gradio as gr from gradio_imageslider import ImageSlider from safetensors.torch import load_file from huggingface_hub import hf_hub_download from utils import preprocess_img, preprocess_img_from_path, postprocess_img from vgg.vgg19 import VGG_19 from u2net.model import U2Net from inference import inference if torch.cuda.is_available(): device = 'cuda' elif torch.backends.mps.is_available(): device = 'mps' else: device = 'cpu' print('DEVICE:', device) if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name()) def load_model_without_module(model, model_path): state_dict = load_file(model_path, device=device) new_state_dict = {} for k, v in state_dict.items(): name = k[7:] if k.startswith('module.') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) model = VGG_19().to(device).eval() for param in model.parameters(): param.requires_grad = False sod_model = U2Net().to(device).eval() local_model_path = hf_hub_download(repo_id='jamino30/u2net-saliency', filename='u2net-duts-msra.safetensors') load_model_without_module(sod_model, local_model_path) style_files = os.listdir('./style_images') style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files} lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist() img_size = 512 cached_style_features = {} for style_name, style_img_path in style_options.items(): style_img = preprocess_img_from_path(style_img_path, img_size)[0].to(device) with torch.no_grad(): style_features = model(style_img) cached_style_features[style_name] = style_features @spaces.GPU(duration=30) def run(content_image, style_name, style_strength=10): yield [None] * 3 content_img, original_size = preprocess_img(content_image, img_size) content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True) content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device) print('-'*15) print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est print('STYLE:', style_name) print('CONTENT IMG SIZE:', original_size) print('STYLE STRENGTH:', style_strength, f'(lr={lrs[style_strength-1]:.3f})') style_features = cached_style_features[style_name] st = time.time() if device == 'cuda': stream_all = torch.cuda.Stream() stream_bg = torch.cuda.Stream() def run_inference_cuda(apply_to_background, stream): with torch.cuda.stream(stream): return run_inference(apply_to_background) def run_inference(apply_to_background): return inference( model=model, sod_model=sod_model, content_image=content_img, content_image_norm=content_img_normalized, style_features=style_features, lr=lrs[style_strength-1], apply_to_background=apply_to_background ) with ThreadPoolExecutor() as executor: if device == 'cuda': future_all = executor.submit(run_inference_cuda, False, stream_all) future_bg = executor.submit(run_inference_cuda, True, stream_bg) else: future_all = executor.submit(run_inference, False) future_bg = executor.submit(run_inference, True) generated_img_all = future_all.result() generated_img_bg = future_bg.result() et = time.time() print('TIME TAKEN:', et-st) yield ( (content_image, postprocess_img(generated_img_all, original_size)), (content_image, postprocess_img(generated_img_bg, original_size)) ) def set_slider(value): return gr.update(value=value) css = """ #container { margin: 0 auto; max-width: 1200px; } """ with gr.Blocks(css=css) as demo: gr.HTML("