# Scraped from the Hugging Face Space file view: author gokaygokay,
# commit fd2a786 "Update app.py" (verified), file size 7.57 kB.
import spaces
import os
import requests
import yaml
import torch
import gradio as gr
from PIL import Image
import sys
sys.path.append(os.path.abspath('./'))
from inference.utils import *
from core.utils import load_or_fail
from train import WurstCoreB
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
from train import WurstCore_t2i as WurstCoreC
import torch.nn.functional as F
from core.utils import load_or_fail
import numpy as np
import random
import math
from einops import rearrange
from huggingface_hub import hf_hub_download
def download_file(url, folder_path, filename, timeout=60, chunk_size=1024 * 1024):
    """Download ``url`` to ``folder_path/filename`` unless that file already exists.

    The payload is streamed into a temporary ``<filename>.part`` file and renamed
    to the final name only after the transfer completes. An interrupted download
    therefore never leaves behind a truncated file that a later run would mistake
    for a finished one.

    Args:
        url: HTTP(S) URL to fetch.
        folder_path: Destination directory; created if it does not exist.
        filename: Name of the file inside ``folder_path``.
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        The destination path if the file already existed or was downloaded.
        ``None`` if the server answered with a non-200 status (reported via print).
        Network errors raised by ``requests`` propagate to the caller.
    """
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return file_path

    tmp_path = file_path + ".part"
    # Using the response as a context manager releases the connection even on error.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return None
        try:
            with open(tmp_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        file.write(chunk)
            # Atomic rename: the final name only ever refers to a complete file.
            os.replace(tmp_path, file_path)
        finally:
            # Only reached with tmp_path present if the transfer failed part-way.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    print(f"File successfully downloaded and saved: {file_path}")
    return file_path
def download_models():
    """Fetch every checkpoint this app needs into ``models/``.

    Each entry maps a descriptive key to ``(url, folder, filename)`` and is passed
    to ``download_file``, which skips files that already exist. A non-200 response
    is only printed by ``download_file``, so it does not stop the remaining downloads.
    """
    # Named `checkpoints` (not `models`) to avoid shadowing the module-level
    # `models` object created during model setup.
    checkpoints = {
        "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/", "stage_a.safetensors"),
        "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/", "previewer.safetensors"),
        "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/", "effnet_encoder.safetensors"),
        "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/", "stage_b_lite_bf16.safetensors"),
        "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/", "stage_c_bf16.safetensors"),
        "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/", "ultrapixel_t2i.safetensors"),
        "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/", "lora_cat.safetensors"),
    }
    # The keys are labels only; iterate the values directly.
    for url, folder, filename in checkpoints.values():
        download_file(url, folder, filename)
# Make sure every checkpoint file is on disk before any config/model code runs.
download_models()

# Globals shared by model setup and generate_images().
dtype = torch.bfloat16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load configs and setup models
# Read the stage C (t2i training config) and stage B (1b inference config) YAML files;
# yaml.safe_load is used, so no arbitrary Python objects are constructed.
with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
    config_c = yaml.safe_load(file)
with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
    config_b = yaml.safe_load(file)
# Build both cores in inference mode (training=False) on the selected device.
core = WurstCoreC(config_dict=config_c, device=device, training=False)
core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
extras = core.setup_extras_pre()
models = core.setup_models(extras)
# Freeze the stage C generator: eval mode, no gradients.
models.generator.eval().requires_grad_(False)
extras_b = core_b.setup_extras_pre()
# Stage B is set up with skip_clip=True. Its Models container is then rebuilt so that
# 'tokenizer' and 'text_model' are the stage C objects, meaning both stages share one
# tokenizer/text model. Presumably skip_clip avoids loading a second CLIP; confirm in train.py.
models_b = core_b.setup_models(extras_b, skip_clip=True)
models_b = WurstCoreB.Models(
    **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
)
# Cast the stage B generator to bfloat16 and freeze it.
models_b.generator.bfloat16().eval().requires_grad_(False)
# Load pretrained model
# Load the UltraPixel checkpoint (tensors mapped to CPU) into models.train_norm.
# NOTE(review): torch.load unpickles the file, so this is only as safe as the download
# source. Despite the .safetensors extension, the file is read with torch.load. Presumably
# the checkpoint is actually a torch pickle; confirm before switching loaders. If the
# installed torch supports it, weights_only=True would restrict what unpickling can run.
pretrained_path = "models/ultrapixel_t2i.safetensors"
sdd = torch.load(pretrained_path, map_location='cpu')
# Strip the first 7 characters of every key. Presumably this removes a "module." prefix
# left by a DataParallel/DDP wrapper at save time (TODO confirm). Any key without that
# prefix would be mangled.
collect_sd = {k[7:]: v for k, v in sdd.items()}
models.train_norm.load_state_dict(collect_sd)
# Switch both modules to eval mode for inference.
models.generator.eval()
models.train_norm.eval()
# Sampling settings for both stages.
# Stage C: 20 steps, CFG 4, with a DDPM sampler built on the stage C GDF (extras.gdf).
stage_c_sampling = dict(
    cfg=4,
    shift=1,
    timesteps=20,
    t_start=1.0,
    sampler=DDPMSampler(extras.gdf),
)
extras.sampling_configs.update(stage_c_sampling)

# Stage B: 10 steps, CFG 1.1. No 'sampler' key is set here, so stage B keeps
# whatever sampler extras_b already carries (presumably from setup_extras_pre; confirm).
stage_b_sampling = dict(cfg=1.1, shift=1, timesteps=10, t_start=1.0)
extras_b.sampling_configs.update(stage_b_sampling)
@spaces.GPU(duration=240)
def generate_images(prompt, height, width, seed, num_images):
    """Generate ``num_images`` images for ``prompt``.

    Stage C generation runs first, and its output is then decoded by stage B/A.

    Args:
        prompt: Text prompt; repeated once per image in the batch.
        height: Target output height in pixels.
        width: Target output width in pixels.
        seed: Seed for torch, ``random`` and NumPy. Coerced to ``int`` and reduced
            modulo 2**32, because Gradio number inputs may arrive as floats and
            ``np.random.seed`` only accepts integers in ``[0, 2**32)``.
        num_images: Batch size (number of images to generate).

    Returns:
        The result of ``show_images`` on the decoded batch, which is displayed by the
        ``gr.Gallery`` output.

    Note:
        The generator is moved with ``.cuda()`` and autocast targets CUDA, so this
        function assumes a CUDA device regardless of the module-level ``device``.
    """
    # Normalise UI inputs: floats such as 42.0 would make np.random.seed raise, and
    # `[prompt] * batch_size` requires an int.
    seed = int(seed) % (2 ** 32)
    height, width = int(height), int(width)
    batch_size = int(num_images)

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Low-resolution size derived from the aspect ratio (presumably the first-pass size
    # used inside generation_c; confirm in inference.utils).
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    batch = {'captions': [prompt] * batch_size}

    # The original function also called core.get_conditions / core_b.get_conditions here.
    # Those results were never used: generation_c does not receive them, and the stage B
    # conditions are recomputed below. The dead calls are removed.
    with torch.no_grad():
        try:
            models.generator.cuda()
            with torch.autocast(device_type="cuda", dtype=dtype):
                sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
        finally:
            # Offload the stage C generator even if generation failed, so the next
            # request does not find it still occupying GPU memory.
            models.generator.cpu()
            torch.cuda.empty_cache()

        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
        # Stage B is conditioned on the stage C output; the unconditional branch gets zeros.
        conditions_b['effnet'] = sampled_c
        unconditions_b['effnet'] = torch.zeros_like(sampled_c)
        with torch.autocast(device_type="cuda", dtype=dtype):
            sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
        torch.cuda.empty_cache()
        imgs = show_images(sampled)
    return imgs
# Gradio UI wiring generate_images to prompt/size/seed/count inputs and a gallery output.
# The examples are pre-computed at startup (cache_examples=True).
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
        gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
        # precision=0 makes Gradio round the seed and pass it as an int;
        # np.random.seed() inside generate_images rejects floats.
        gr.Number(label="Seed", value=42, precision=0),
        gr.Slider(minimum=1, maximum=10, step=1, label="Number of Images", value=1),
    ],
    outputs=gr.Gallery(label="Generated Images", columns=5, rows=2),
    title="UltraPixel Image Generation",
    description="Generate high-resolution images using UltraPixel model.",
    theme='bethecloud/storj_theme',
    examples=[
        ["A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening in the early morning light.", 1024, 1024, 42, 1],
        ["A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a clear blue sky.", 1024, 1024, 42, 1],
        ["A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing a vintage floral dress and standing in front of a blooming garden.", 1024, 1024, 42, 1],
        ["The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.", 1024, 1024, 42, 1]
    ],
    cache_examples=True
)
iface.launch()