# Hugging Face Space status: Running on Zero
import spaces | |
import os | |
import requests | |
import yaml | |
import torch | |
import gradio as gr | |
from PIL import Image | |
import sys | |
sys.path.append(os.path.abspath('./')) | |
from inference.utils import * | |
from core.utils import load_or_fail | |
from train import WurstCoreB | |
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight | |
from train import WurstCore_t2i as WurstCoreC | |
import torch.nn.functional as F | |
from core.utils import load_or_fail | |
import numpy as np | |
import random | |
import math | |
from einops import rearrange | |
from huggingface_hub import hf_hub_download | |
def download_file(url, folder_path, filename):
    """Download ``url`` to ``folder_path/filename`` unless that file exists.

    The payload is streamed into a sibling ``.part`` file and renamed into
    place only after the transfer completes, so an interrupted download is
    never mistaken for a finished one on the next start.

    Args:
        url: HTTP(S) address of the file to fetch.
        folder_path: Destination directory; created if it is missing.
        filename: Name of the file inside ``folder_path``.

    Returns:
        None. Progress and non-200 responses are reported with ``print``;
        a non-200 status does not raise.

    Raises:
        requests.RequestException: On connection failures or timeouts.
        OSError: If the destination cannot be created or written.
    """
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return

    partial_path = file_path + ".part"
    # (connect, read) timeouts keep a stalled connection from hanging startup;
    # the context manager releases the connection even on early return.
    with requests.get(url, stream=True, timeout=(10, 120)) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return
        try:
            with open(partial_path, 'wb') as file:
                # 1 MiB chunks: far fewer Python-level writes than 1 KiB
                # for multi-gigabyte checkpoints.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
            os.replace(partial_path, file_path)
        finally:
            # After a successful rename the partial file is gone; otherwise
            # remove the truncated leftovers.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print(f"File successfully downloaded and saved: {file_path}")
def download_models():
    """Fetch every checkpoint listed below via ``download_file``.

    Each entry maps a label to a ``(url, folder, filename)`` triple; the
    label is informational only and is not passed on.
    """
    checkpoint_specs = {
        "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/StableWurst", "stage_a.safetensors"),
        "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/StableWurst", "previewer.safetensors"),
        "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/StableWurst", "effnet_encoder.safetensors"),
        "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/StableWurst", "stage_b_lite_bf16.safetensors"),
        "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/StableWurst", "stage_c_bf16.safetensors"),
        "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/UltraPixel", "ultrapixel_t2i.safetensors"),
        "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/UltraPixel", "lora_cat.safetensors"),
    }
    # Dicts preserve insertion order, so files are fetched in the order listed.
    for spec in checkpoint_specs.values():
        download_file(*spec)
# Fetch all checkpoints before any model is constructed.
download_models()

# Global variables
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.bfloat16

# Parse the YAML configs for stage C (text-to-image) and stage B.
with open("configs/training/t2i.yaml", mode="r", encoding="utf-8") as file:
    config_c = yaml.safe_load(file)
with open("configs/inference/stage_b_1b.yaml", mode="r", encoding="utf-8") as file:
    config_b = yaml.safe_load(file)

# Build both cores with training disabled.
core = WurstCoreC(config_dict=config_c, device=device, training=False)
core_b = WurstCoreB(config_dict=config_b, device=device, training=False)

# Stage C: extras, models, and a frozen generator in eval mode.
extras = core.setup_extras_pre()
models = core.setup_models(extras)
models.generator.eval()
models.generator.requires_grad_(False)

# Stage B: set up with skip_clip=True, then rebuild the Models container so
# it carries stage C's tokenizer and text model.
extras_b = core_b.setup_extras_pre()
models_b = core_b.setup_models(extras_b, skip_clip=True)
stage_b_fields = dict(models_b.to_dict())
stage_b_fields['tokenizer'] = models.tokenizer
stage_b_fields['text_model'] = models.text_model
models_b = WurstCoreB.Models(**stage_b_fields)

# Cast the stage B generator to bfloat16, switch to eval mode, freeze it.
models_b.generator.bfloat16()
models_b.generator.eval()
models_b.generator.requires_grad_(False)
# Load pretrained model
# download_models() stores this checkpoint under models/UltraPixel/, while the
# path below points at models/ directly. Keep the original location first and
# fall back to the downloaded one so the load works in either layout.
pretrained_path = "models/ultrapixel_t2i.safetensors"
if not os.path.isfile(pretrained_path):
    pretrained_path = "models/UltraPixel/ultrapixel_t2i.safetensors"

# NOTE(review): torch.load unpickles its input and this file is fetched over
# the network; consider weights_only=True (or a safetensors loader if the file
# really is in that format) once the supported torch version is confirmed.
sdd = torch.load(pretrained_path, map_location='cpu')

# Strip the first 7 characters from every key before loading
# (presumably a "module." wrapper prefix — TODO confirm).
collect_sd = {k[7:]: v for k, v in sdd.items()}
models.train_norm.load_state_dict(collect_sd)

models.generator.eval()
models.train_norm.eval()
# Set up sampling configurations
# Stage C overrides; the sampler is built from stage C's gdf.
stage_c_sampling = {
    'cfg': 4,
    'shift': 1,
    'timesteps': 20,
    't_start': 1.0,
    'sampler': DDPMSampler(extras.gdf),
}
extras.sampling_configs.update(stage_c_sampling)

# Stage B overrides; no sampler entry is set here.
stage_b_sampling = {
    'cfg': 1.1,
    'shift': 1,
    'timesteps': 10,
    't_start': 1.0,
}
extras_b.sampling_configs.update(stage_b_sampling)
# ZeroGPU only attaches a GPU while a function decorated with spaces.GPU is
# running; the decorator has no effect on other hardware.
@spaces.GPU
def generate_image(prompt, height, width, seed):
    """Generate one image for ``prompt`` at the requested resolution.

    Stage C is sampled first (with the generator moved to the GPU only for
    that step), then its output is used as the ``effnet`` condition for the
    stage B decode.

    Args:
        prompt: Caption text used for both stages.
        height: Target image height in pixels.
        width: Target image width in pixels.
        seed: Seed for the torch, ``random`` and NumPy generators; coerced
            to ``int``.

    Returns:
        The first element of ``show_images(sampled)`` (presumably a PIL
        image, as the interface output is declared with ``type="pil"``).
    """
    # Gradio numeric widgets can deliver floats; np.random.seed rejects them.
    seed = int(seed)
    height, width = int(height), int(width)

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    batch_size = 1
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    # Only the stage C shape of the low-resolution pass is used below.
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    # generation_c receives the raw batch rather than precomputed conditions,
    # and the stage B conditions are built after stage C, so nothing is
    # precomputed here.
    batch = {'captions': [prompt] * batch_size}

    with torch.no_grad():
        models.generator.cuda()
        try:
            with torch.cuda.amp.autocast(dtype=dtype):
                sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
        finally:
            # Free the GPU for stage B even if stage C sampling fails.
            models.generator.cpu()
            torch.cuda.empty_cache()

        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
        # Stage B is conditioned on the stage C sample; the unconditional
        # branch gets zeros of the same shape.
        conditions_b['effnet'] = sampled_c
        unconditions_b['effnet'] = torch.zeros_like(sampled_c)

        with torch.cuda.amp.autocast(dtype=dtype):
            sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)

    torch.cuda.empty_cache()
    imgs = show_images(sampled)
    return imgs[0]
# Gradio UI: prompt, output size and seed in; one PIL image out.
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
        gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
        # precision=0 makes the component deliver an integer, which the
        # seeding calls in generate_image expect.
        gr.Number(label="Seed", value=42, precision=0),
    ],
    outputs=gr.Image(type="pil"),
    title="UltraPixel Image Generation",
    description="Generate high-resolution images using UltraPixel model.",
    theme='bethecloud/storj_theme',
)

iface.launch()