gokaygokay's picture
Update app.py
3d07a10 verified
raw
history blame
6.68 kB
import spaces
import os
import requests
import yaml
import torch
import gradio as gr
from PIL import Image
import sys
sys.path.append(os.path.abspath('./'))
from inference.utils import *
from core.utils import load_or_fail
from train import WurstCoreB
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
from train import WurstCore_t2i as WurstCoreC
import torch.nn.functional as F
from core.utils import load_or_fail
import numpy as np
import random
import math
from einops import rearrange
from huggingface_hub import hf_hub_download
def download_file(url, folder_path, filename, timeout=60, chunk_size=1024 * 1024):
    """Download ``url`` to ``folder_path/filename`` unless that file already exists.

    The body is streamed into a temporary ``<filename>.part`` file that is only
    renamed onto the final path once every chunk has been written. An
    interrupted download therefore never leaves a truncated file behind that a
    later run would mistake for a complete one.

    Args:
        url: HTTP(S) URL to fetch.
        folder_path: Directory to save into; created (with parents) if missing.
        filename: Name of the file inside ``folder_path``.
        timeout: Seconds to wait for the connection and between received bytes
            (forwarded to ``requests.get``); not a limit on total download time.
        chunk_size: Number of bytes read per streamed chunk.

    A non-200 status is reported with a printed message rather than raised.
    Network or disk errors propagate after the partial file has been removed.
    """
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return

    tmp_path = file_path + ".part"
    # Using the response as a context manager releases the streamed connection
    # on every path, including the error branch.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return
        try:
            with open(tmp_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
            # Atomic rename: file_path appears only once it is complete.
            os.replace(tmp_path, file_path)
        finally:
            # tmp_path still exists only if writing or renaming failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    print(f"File successfully downloaded and saved: {file_path}")
def download_models(folder_path="models/"):
    """Ensure every checkpoint used by this app exists in ``folder_path``.

    Each file is fetched with ``download_file``, which skips files that are
    already present. Calling this on every start-up therefore only downloads
    what is missing.

    Args:
        folder_path: Directory the checkpoints are saved into. The rest of the
            app loads them from ``models/``, which is the default.
    """
    stablewurst_base = "https://huggingface.co/stabilityai/StableWurst/resolve/main/"
    ultrapixel_base = "https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/"
    # (repository base URL, checkpoint filename), in download order. The saved
    # file name is the same as the name in the URL.
    checkpoints = [
        (stablewurst_base, "stage_a.safetensors"),
        (stablewurst_base, "previewer.safetensors"),
        (stablewurst_base, "effnet_encoder.safetensors"),
        (stablewurst_base, "stage_b_lite_bf16.safetensors"),
        (stablewurst_base, "stage_c_bf16.safetensors"),
        (ultrapixel_base, "ultrapixel_t2i.safetensors"),
        (ultrapixel_base, "lora_cat.safetensors"),
    ]
    for base_url, filename in checkpoints:
        download_file(f"{base_url}{filename}?download=true", folder_path, filename)
# Fetch any missing checkpoints before the models below are constructed.
download_models()

# Module-wide settings: run on the GPU when one is visible, else on the CPU;
# bfloat16 is the working (autocast) precision.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.bfloat16
# Load configs and set up models.
# Stage C (WurstCore_t2i, imported as WurstCoreC) reads the training config and
# stage B reads the inference config. Both cores are built with training=False.
with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
    config_c = yaml.safe_load(file)
with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
    config_b = yaml.safe_load(file)
core = WurstCoreC(config_dict=config_c, device=device, training=False)
core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
# Stage C extras and models; the generator is put in eval mode and frozen.
extras = core.setup_extras_pre()
models = core.setup_models(extras)
models.generator.eval().requires_grad_(False)
extras_b = core_b.setup_extras_pre()
# skip_clip=True presumably skips loading stage B's own CLIP tokenizer/text
# encoder (TODO confirm in WurstCoreB.setup_models). The Models object is then
# rebuilt with stage C's tokenizer and text_model, so both stages share the
# same text-encoder instance.
models_b = core_b.setup_models(extras_b, skip_clip=True)
models_b = WurstCoreB.Models(
    **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
)
# Stage B generator weights are cast to bfloat16, set to eval mode and frozen.
models_b.generator.bfloat16().eval().requires_grad_(False)
# Load the UltraPixel weights into the stage C `train_norm` module.
pretrained_path = "models/ultrapixel_t2i.safetensors"
# NOTE(review): despite the .safetensors extension, this file is read with
# torch.load, i.e. as a pickle-based torch checkpoint. Presumably the file
# really is a torch pickle; confirm. torch.load can execute code embedded in
# the file, so only point this at trusted checkpoints. Consider
# weights_only=True if the installed torch supports it and the checkpoint
# holds only tensors.
sdd = torch.load(pretrained_path, map_location='cpu')
# Drop the first 7 characters of every key. This presumably strips a
# "module." prefix left by DataParallel/DDP wrapping (TODO confirm). A key
# without that prefix would be mangled and likely rejected by the
# (strict-by-default) load_state_dict below.
collect_sd = {k[7:]: v for k, v in sdd.items()}
models.train_norm.load_state_dict(collect_sd)
# Make sure both stage C modules are in eval mode for inference.
models.generator.eval()
models.train_norm.eval()
# Sampling configurations.
# Stage C: 20 steps with an explicit DDPM sampler built on the stage C gdf.
# 'cfg' is presumably the classifier-free guidance scale (TODO confirm in the
# sampling code).
extras.sampling_configs.update(
    dict(
        cfg=4,
        shift=1,
        timesteps=20,
        t_start=1.0,
        sampler=DDPMSampler(extras.gdf),
    )
)
# Stage B: 10 steps and a lower 'cfg'. No 'sampler' key is written here, so
# any sampler entry extras_b already has is left untouched.
extras_b.sampling_configs.update(dict(cfg=1.1, shift=1, timesteps=10, t_start=1.0))
@spaces.GPU(duration=240)
def generate_images(prompt, height, width, seed, num_images):
    """Generate ``num_images`` images for ``prompt`` at ``height`` x ``width`` pixels.

    Stage C (``generation_c``) is given both the full-size latent shape and the
    shape for a low-resolution size chosen by ``get_target_lr_size``. Its output
    is then fed to stage B (``decode_b``) as the ``'effnet'`` condition, and the
    decoded batch is converted with ``show_images``.

    Args:
        prompt: Text prompt, repeated once per image in the batch.
        height: Output height in pixels.
        width: Output width in pixels.
        seed: RNG seed for torch, ``random`` and NumPy. Gradio's ``Number``
            component may deliver it as a float, so it is coerced to ``int``.
        num_images: Batch size.

    Returns:
        The value of ``show_images`` for the decoded batch (shown in the gallery).
    """
    # np.random.seed rejects floats and values outside [0, 2**32), so normalise
    # once. For non-negative seeds below 2**32 all three RNGs get the same value.
    seed = int(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed % 2**32)

    height, width = int(height), int(width)
    batch_size = int(num_images)
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    batch = {'captions': [prompt] * batch_size}

    with torch.no_grad():
        # Stage C: the generator lives on the GPU only while sampling. It is
        # moved back in `finally` so a failed run does not keep it resident.
        models.generator.cuda()
        try:
            with torch.autocast(device_type="cuda", dtype=dtype):
                sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
        finally:
            models.generator.cpu()
            torch.cuda.empty_cache()

        # Stage B conditions: text conditions plus the stage C result under
        # 'effnet'; the unconditional branch gets zeros instead. (Computed once
        # here; the previous version also computed them earlier and discarded them.)
        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
        conditions_b['effnet'] = sampled_c
        unconditions_b['effnet'] = torch.zeros_like(sampled_c)

        with torch.autocast(device_type="cuda", dtype=dtype):
            sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)

    torch.cuda.empty_cache()
    return show_images(sampled)
# Gradio UI wiring the inputs to generate_images, shown in a gallery.
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
        gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
        # precision=0 makes Gradio pass the seed as an int. With the default
        # (no rounding) it arrives as a float, which np.random.seed rejects.
        gr.Number(label="Seed", value=42, precision=0),
        gr.Slider(minimum=1, maximum=10, step=1, label="Number of Images", value=1),
    ],
    # 5 columns x 2 rows fits the slider's maximum of 10 images.
    outputs=gr.Gallery(label="Generated Images", columns=5, rows=2),
    title="UltraPixel Image Generation",
    description="Generate high-resolution images using UltraPixel model.",
    theme='bethecloud/storj_theme',
)
iface.launch()