gokaygokay's picture
full_files
2f4febc
raw
history blame
6.6 kB
import spaces
import os
import requests
import yaml
import torch
import gradio as gr
from PIL import Image
import sys
sys.path.append(os.path.abspath('./'))
from inference.utils import *
from core.utils import load_or_fail
from train import WurstCoreB
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
from train import WurstCore_t2i as WurstCoreC
import torch.nn.functional as F
from core.utils import load_or_fail
import numpy as np
import random
import math
from einops import rearrange
from huggingface_hub import hf_hub_download
def download_file(url, folder_path, filename):
    """Download ``url`` to ``folder_path/filename`` unless the file already exists.

    The payload is streamed into a temporary ``<filename>.part`` file and only
    renamed to the final name once the transfer has finished, so an interrupted
    download is never mistaken for a complete file on the next run.

    Args:
        url: HTTP(S) address of the file to fetch.
        folder_path: Directory to store the file in; created if missing.
        filename: Name of the file inside ``folder_path``.

    Returns:
        None. Success, "already exists" and non-200 responses are reported
        with ``print``.

    Raises:
        requests.RequestException: If the connection fails or times out.
        OSError: If the directory or the file cannot be written.
    """
    if folder_path:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(folder_path, exist_ok=True)

    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return

    partial_path = file_path + ".part"
    # (connect, read) timeouts keep a stalled server from hanging start-up;
    # the context manager releases the streamed connection in every case.
    with requests.get(url, stream=True, timeout=(10, 120)) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return
        try:
            with open(partial_path, 'wb') as file:
                # 1 MiB chunks: the checkpoints are large, so tiny chunks only
                # add per-iteration overhead.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
        except BaseException:
            # Remove the truncated temp file (also on Ctrl-C), then re-raise.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    # Publish the finished file under its final name in one step.
    os.replace(partial_path, file_path)
    print(f"File successfully downloaded and saved: {file_path}")
def download_models():
    """Fetch every checkpoint listed below by calling ``download_file``.

    Each entry maps a label to ``(url, target folder, target filename)``.
    The labels are informational only; the three tuple fields are forwarded
    to ``download_file`` in that order.
    """
    checkpoints = {
        "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/StableWurst", "stage_a.safetensors"),
        "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/StableWurst", "previewer.safetensors"),
        "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/StableWurst", "effnet_encoder.safetensors"),
        "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/StableWurst", "stage_b_lite_bf16.safetensors"),
        "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/StableWurst", "stage_c_bf16.safetensors"),
        "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/UltraPixel", "ultrapixel_t2i.safetensors"),
        "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/UltraPixel", "lora_cat.safetensors"),
    }
    # Only the tuples are needed, so walk the values in insertion order.
    for source_url, target_folder, target_name in checkpoints.values():
        download_file(source_url, target_folder, target_name)
# Fetch the checkpoints at import time, before any model object is built.
download_models()

# Global variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16  # autocast dtype used inside generate_image

# Load configs and setup models
with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
    config_c = yaml.safe_load(file)
with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
    config_b = yaml.safe_load(file)

# Both cores are created with training=False.
core = WurstCoreC(config_dict=config_c, device=device, training=False)
core_b = WurstCoreB(config_dict=config_b, device=device, training=False)

# Stage C: extras and models; the generator is put in eval mode with
# gradients disabled.
extras = core.setup_extras_pre()
models = core.setup_models(extras)
models.generator.eval().requires_grad_(False)

# Stage B: set up with skip_clip=True, then rebuild the model bundle so that
# its tokenizer and text model are the ones already held by `models`.
extras_b = core_b.setup_extras_pre()
models_b = core_b.setup_models(extras_b, skip_clip=True)
models_b = WurstCoreB.Models(
    **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
)
# Stage B generator: cast to bfloat16, eval mode, gradients disabled.
models_b.generator.bfloat16().eval().requires_grad_(False)
# Load pretrained model
# download_models() saves the UltraPixel checkpoint under models/UltraPixel/,
# whereas this script reads it from models/. Prefer the original location and
# fall back to the download folder, so whichever layout exists on disk works.
# If neither file exists the original path is kept and torch.load reports it.
_pretrained_candidates = (
    "models/ultrapixel_t2i.safetensors",
    "models/UltraPixel/ultrapixel_t2i.safetensors",
)
pretrained_path = next(
    (path for path in _pretrained_candidates if os.path.isfile(path)),
    _pretrained_candidates[0],
)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
sdd = torch.load(pretrained_path, map_location='cpu')
# Drop the first seven characters of every key
# (presumably a "module." prefix - confirm against the checkpoint).
collect_sd = {k[7:]: v for k, v in sdd.items()}
models.train_norm.load_state_dict(collect_sd)
models.generator.eval()
models.train_norm.eval()
# Set up sampling configurations
# Stage C settings: guidance 4, 20 timesteps, with a DDPM sampler built on
# the stage-C gdf object.
extras.sampling_configs.update(
    dict(cfg=4, shift=1, timesteps=20, t_start=1.0, sampler=DDPMSampler(extras.gdf))
)
# Stage B settings: guidance 1.1 and 10 timesteps; no sampler is set here.
extras_b.sampling_configs.update(
    dict(cfg=1.1, shift=1, timesteps=10, t_start=1.0)
)
@spaces.GPU
def generate_image(prompt, height, width, seed):
    """Generate one image for ``prompt`` with the stage C -> stage B pipeline.

    Args:
        prompt: Caption used to build the text conditions.
        height: Requested output height, forwarded to ``calculate_latent_sizes``.
        width: Requested output width, forwarded to ``calculate_latent_sizes``.
        seed: Seed for torch, ``random`` and NumPy. It is coerced to an
            integer in ``[0, 2**32)`` because ``np.random.seed`` rejects
            floats (which ``gr.Number`` can deliver) and out-of-range values.

    Returns:
        The first element of ``show_images(sampled)``.
    """
    # Normalise the seed once so all three generators get the same valid int.
    seed = int(seed) % (2 ** 32)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    batch_size = 1
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    # Only the stage-C shape of the low-resolution pass is used below.
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    batch = {'captions': [prompt] * batch_size}
    # NOTE(review): these stage-C conditions are not passed to generation_c
    # below; the calls are kept in case get_conditions has side effects that
    # generation relies on - confirm and drop them if not.
    conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
    unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)

    with torch.no_grad():
        # Stage C: move the generator to the GPU only while it is sampling.
        models.generator.cuda()
        with torch.cuda.amp.autocast(dtype=dtype):
            sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
        models.generator.cpu()
        torch.cuda.empty_cache()

        # Stage B conditions are built once, here. The original also built
        # them before stage C and then overwrote that pair without using it.
        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
        conditions_b['effnet'] = sampled_c
        unconditions_b['effnet'] = torch.zeros_like(sampled_c)

        with torch.cuda.amp.autocast(dtype=dtype):
            sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
        torch.cuda.empty_cache()

    imgs = show_images(sampled)
    return imgs[0]
# Gradio UI: one prompt box, height/width sliders in steps of 32, and a seed
# field, wired to generate_image; the result is shown as a PIL image.
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
        gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
        # precision=0 makes Gradio hand the seed over as an integer; the
        # seeding calls in generate_image (np.random.seed in particular)
        # do not accept floats.
        gr.Number(label="Seed", value=42, precision=0),
    ],
    outputs=gr.Image(type="pil"),
    title="UltraPixel Image Generation",
    description="Generate high-resolution images using UltraPixel model.",
    theme='bethecloud/storj_theme'
)
iface.launch()