Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,569 Bytes
78c280f 7c0f531 2f4febc 7c0f531 2f4febc 7c0f531 2f4febc 15d65d6 2f4febc dfca54f 2f4febc 3d07a10 2f4febc 3d07a10 2f4febc 7c0f531 2f4febc 7c0f531 2f4febc 3d07a10 2f4febc 3d07a10 2f4febc 3d07a10 7c0f531 3d07a10 2f4febc c231df5 fd2a786 7c0f531 2f4febc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import spaces
import os
import requests
import yaml
import torch
import gradio as gr
from PIL import Image
import sys
sys.path.append(os.path.abspath('./'))
from inference.utils import *
from core.utils import load_or_fail
from train import WurstCoreB
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
from train import WurstCore_t2i as WurstCoreC
import torch.nn.functional as F
from core.utils import load_or_fail
import numpy as np
import random
import math
from einops import rearrange
from huggingface_hub import hf_hub_download
def download_file(url, folder_path, filename, chunk_size=1024 * 1024, timeout=60):
    """Download ``url`` to ``folder_path/filename`` unless that file already exists.

    The response body is streamed into a temporary ``<filename>.part`` file,
    which is renamed onto the final path only after the transfer finishes.
    An interrupted download therefore never leaves a truncated file that a
    later run would mistake for a complete one.

    A non-200 status is reported with ``print`` and does not raise. Network
    errors raised by ``requests`` (including a timeout) still propagate.

    Args:
        url: Source URL.
        folder_path: Destination directory; created if it does not exist.
        filename: File name inside ``folder_path``.
        chunk_size: Bytes per streamed chunk written to disk.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return

    tmp_path = file_path + ".part"
    # Using the response as a context manager releases the streamed connection.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return
        try:
            with open(tmp_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
        except BaseException:
            # Never leave a partial file behind, whatever interrupted the copy.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Atomic on the same filesystem: file_path is either absent or complete.
    os.replace(tmp_path, file_path)
    print(f"File successfully downloaded and saved: {file_path}")
def download_models():
    """Download every checkpoint this app loads into the ``models/`` folder.

    Each entry is handed to ``download_file``, which skips files that are
    already present on disk.
    """
    # Label -> (url, destination folder, filename). The labels are only for
    # readability; the loop below uses the tuples alone.
    checkpoints = {
        "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/", "stage_a.safetensors"),
        "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/", "previewer.safetensors"),
        "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/", "effnet_encoder.safetensors"),
        "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/", "stage_b_lite_bf16.safetensors"),
        "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/", "stage_c_bf16.safetensors"),
        "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/", "ultrapixel_t2i.safetensors"),
        "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/", "lora_cat.safetensors"),
    }
    for url, folder, filename in checkpoints.values():
        download_file(url, folder, filename)
# Fetch the model checkpoints before the setup code below loads them.
download_models()

# Global variables
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Reduced-precision dtype passed to the autocast regions in generate_images.
dtype = torch.bfloat16
# Load configs and setup models
# The stage C core (WurstCore_t2i, imported as WurstCoreC) reads the training
# config; the stage B core reads the inference config. Both use yaml.safe_load.
with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
    config_c = yaml.safe_load(file)
with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
    config_b = yaml.safe_load(file)
# Both cores are constructed for inference (training=False) on the global device.
core = WurstCoreC(config_dict=config_c, device=device, training=False)
core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
# Stage C extras and models; the generator is put in eval mode and frozen.
extras = core.setup_extras_pre()
models = core.setup_models(extras)
models.generator.eval().requires_grad_(False)
# Stage B models are set up with skip_clip=True ...
extras_b = core_b.setup_extras_pre()
models_b = core_b.setup_models(extras_b, skip_clip=True)
# ... and the Models container is rebuilt with stage C's tokenizer and
# text_model substituted in, so both stages share the same objects
# (presumably to avoid loading a second text model — TODO confirm).
# This must run after `models` above exists.
models_b = WurstCoreB.Models(
    **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
)
# Stage B generator: cast to bfloat16, eval mode, frozen.
models_b.generator.bfloat16().eval().requires_grad_(False)
# Load pretrained model
# UltraPixel checkpoint whose weights are loaded into models.train_norm.
pretrained_path = "models/ultrapixel_t2i.safetensors"
# NOTE(review): despite the .safetensors extension, this file is read with
# torch.load, which unpickles it. Unpickling can execute arbitrary code, so
# this path must only ever point at a trusted checkpoint. Consider
# weights_only=True or a safetensors loader if the file format allows it.
sdd = torch.load(pretrained_path, map_location='cpu')
# Strip the 7-character 'module.' key prefix (presumably added by a
# DataParallel/DDP wrapper at save time — TODO confirm). Only keys that
# actually carry the prefix are rewritten. An unprefixed key is therefore
# kept intact instead of silently losing its first 7 characters.
_STATE_DICT_PREFIX = 'module.'
collect_sd = {
    (k[len(_STATE_DICT_PREFIX):] if k.startswith(_STATE_DICT_PREFIX) else k): v
    for k, v in sdd.items()
}
models.train_norm.load_state_dict(collect_sd)
# load_state_dict copies the values into the module's own tensors, so the
# CPU-side checkpoint dicts need not stay alive as module globals.
del sdd, collect_sd
models.generator.eval()
models.train_norm.eval()
# Set up sampling configurations
# Stage C: 'cfg' 4 (presumably the classifier-free guidance scale), 20
# timesteps from t_start=1.0, sampled with a DDPMSampler over extras.gdf.
extras.sampling_configs.update(dict(
    cfg=4,
    shift=1,
    timesteps=20,
    t_start=1.0,
    sampler=DDPMSampler(extras.gdf),
))
# Stage B: lighter 'cfg' of 1.1 and 10 timesteps; no sampler is set here.
extras_b.sampling_configs.update(dict(
    cfg=1.1,
    shift=1,
    timesteps=10,
    t_start=1.0,
))
@spaces.GPU(duration=240)
def generate_images(prompt, height, width, seed, num_images):
    """Generate ``num_images`` images for ``prompt`` at ``height`` x ``width``.

    Pipeline:

    1. Stage C runs via ``generation_c`` with the generator temporarily
       moved to CUDA.
    2. Its output becomes the ``'effnet'`` condition for ``decode_b``
       (called with ``stage_a_tiled=True``).
    3. The result is passed through ``show_images`` for the Gradio gallery.

    Args:
        prompt: Text prompt, repeated once per image in the batch.
        height: Target height in pixels (coerced to int).
        width: Target width in pixels (coerced to int).
        seed: Seed for torch, ``random`` and NumPy. It is coerced to an int
            and reduced modulo 2**32, the range ``np.random.seed`` accepts.
        num_images: Number of images, used as the batch size (coerced to int).

    Returns:
        Whatever ``show_images`` returns for the sampled batch.
    """
    # Gradio can deliver numeric inputs as floats. np.random.seed rejects
    # non-integral values and anything outside [0, 2**32), and the caption
    # list repetition below needs an int, so normalise everything up front.
    seed = int(seed) % (2 ** 32)
    height = int(height)
    width = int(width)
    batch_size = int(num_images)

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
    batch = {'captions': [prompt] * batch_size}

    # No conditions are precomputed here. generation_c receives only the
    # batch and models, and the stage B conditions are built right before
    # decoding.
    with torch.no_grad():
        models.generator.cuda()
        try:
            with torch.cuda.amp.autocast(dtype=dtype):
                sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
        finally:
            # Free GPU memory for stage B even if stage C fails.
            models.generator.cpu()
            torch.cuda.empty_cache()

        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
        conditions_b['effnet'] = sampled_c
        unconditions_b['effnet'] = torch.zeros_like(sampled_c)
        with torch.cuda.amp.autocast(dtype=dtype):
            sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
        torch.cuda.empty_cache()
        imgs = show_images(sampled)
    return imgs
# Gradio UI wrapping generate_images. The example rows are cached
# (cache_examples=True).
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
        gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
        # precision=0 makes Gradio hand the seed over as an int, not a float.
        gr.Number(label="Seed", value=42, precision=0),
        gr.Slider(minimum=1, maximum=10, step=1, label="Number of Images", value=1),
    ],
    # Up to 10 images (the slider maximum) fit a 5 x 2 grid.
    outputs=gr.Gallery(label="Generated Images", columns=5, rows=2),
    title="UltraPixel Image Generation",
    description="Generate high-resolution images using UltraPixel model.",
    theme='bethecloud/storj_theme',
    examples=[
        ["A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening in the early morning light.", 1024, 1024, 42, 1],
        ["A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a clear blue sky.", 1024, 1024, 42, 1],
        ["A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing a vintage floral dress and standing in front of a blooming garden.", 1024, 1024, 42, 1],
        ["The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.", 1024, 1024, 42, 1],
    ],
    cache_examples=True,
)
iface.launch()