"""Gradio demo for UltraPixel high-resolution text-to-image generation.

At import time this script downloads the required checkpoints into
``models/``, builds the stage C (UltraPixel / WurstCore_t2i) and stage B
(WurstCoreB) models, and then serves a Gradio interface whose handler is
``generate_images``.
"""
# NOTE(review): `spaces` is kept as the first import; presumably the ZeroGPU
# package must be imported before torch initialises CUDA - confirm before reordering.
import spaces
import os
import requests
import yaml
import torch
import gradio as gr
from PIL import Image
import sys
sys.path.append(os.path.abspath('./'))
# Presumably provides get_target_lr_size, calculate_latent_sizes, generation_c,
# decode_b and show_images used below - TODO confirm (star import).
from inference.utils import *
from core.utils import load_or_fail
from train import WurstCoreB
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
from train import WurstCore_t2i as WurstCoreC
import torch.nn.functional as F
from core.utils import load_or_fail
import numpy as np
import random
import math
from einops import rearrange
from huggingface_hub import hf_hub_download


def download_file(url, folder_path, filename):
    """Download ``url`` to ``folder_path/filename`` unless that file already exists.

    The response body is streamed into a temporary ``<filename>.part`` file
    that is renamed into place only once the whole body has been written, so
    an interrupted download never leaves a truncated file behind that a later
    run would mistake for a complete one (the existence check below would
    otherwise skip it forever).

    A non-200 response is reported with ``print`` and nothing is written.
    Network errors from ``requests`` propagate to the caller.
    """
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)

    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return

    # `with` releases the streamed connection on every exit path; the timeout
    # stops a stalled server from hanging start-up indefinitely.
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return

        partial_path = file_path + ".part"
        try:
            with open(partial_path, 'wb') as file:
                # 1 MiB chunks: the checkpoints are large, tiny chunks just add overhead.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
            # Atomic on the same filesystem: the final path only ever holds a complete file.
            os.replace(partial_path, file_path)
        except BaseException:
            # Remove the half-written file so it cannot be picked up later, then re-raise.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    print(f"File successfully downloaded and saved: {file_path}")


def download_models():
    """Download every checkpoint this demo loads into ``models/`` (existing files are skipped)."""
    model_files = {
        "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/", "stage_a.safetensors"),
        "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/", "previewer.safetensors"),
        "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/", "effnet_encoder.safetensors"),
        "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/", "stage_b_lite_bf16.safetensors"),
        "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/", "stage_c_bf16.safetensors"),
        "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/", "ultrapixel_t2i.safetensors"),
        "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/", "lora_cat.safetensors"),
    }
    # The dict keys are only labels; the download needs (url, folder, filename).
    for url, folder, filename in model_files.values():
        download_file(url, folder, filename)


download_models()

# Global variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16

# Load configs and setup models
with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
    config_c = yaml.safe_load(file)
with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
    config_b = yaml.safe_load(file)

core = WurstCoreC(config_dict=config_c, device=device, training=False)
core_b = WurstCoreB(config_dict=config_b, device=device, training=False)

extras = core.setup_extras_pre()
models = core.setup_models(extras)
models.generator.eval().requires_grad_(False)

extras_b = core_b.setup_extras_pre()
# Stage B is built with skip_clip=True; its tokenizer/text_model entries are
# then filled in with the stage C instances so both stages share one encoder.
models_b = core_b.setup_models(extras_b, skip_clip=True)
models_b = WurstCoreB.Models(
    **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
)
models_b.generator.bfloat16().eval().requires_grad_(False)

# Load pretrained model
pretrained_path = "models/ultrapixel_t2i.safetensors"
# NOTE(review): despite its .safetensors extension this file is read with
# torch.load, i.e. pickle deserialisation of a downloaded file. Consider
# weights_only=True if the checkpoint holds only tensors - confirm first.
sdd = torch.load(pretrained_path, map_location='cpu')
# Drops the first 7 characters of every key - presumably a "module." prefix
# left by a DataParallel/DDP wrapper at save time (TODO confirm).
collect_sd = {k[7:]: v for k, v in sdd.items()}
models.train_norm.load_state_dict(collect_sd)
models.generator.eval()
models.train_norm.eval()

# Set up sampling configurations
extras.sampling_configs.update({
    'cfg': 4,
    'shift': 1,
    'timesteps': 20,
    't_start': 1.0,
    'sampler': DDPMSampler(extras.gdf)
})

extras_b.sampling_configs.update({
    'cfg': 1.1,
    'shift': 1,
    'timesteps': 10,
    't_start': 1.0
})


@spaces.GPU(duration=240)
def generate_images(prompt, height, width, seed, num_images):
    """Generate ``num_images`` images for ``prompt`` at ``height`` x ``width``.

    Stage C (``generation_c``) runs first; its output is then used as the
    ``effnet`` condition for stage B decoding (``decode_b``), with an all-zero
    ``effnet`` for the unconditional branch.

    Args:
        prompt: Text prompt, replicated once per image in the batch.
        height, width: Target size as chosen on the "Height"/"Width" sliders.
        seed: Seed applied to torch, ``random`` and NumPy.
        num_images: Batch size.

    Returns:
        Whatever ``show_images`` produces for the decoded batch (presumably a
        list of images for the gallery - TODO confirm).
    """
    # Gradio Number/Slider components can deliver floats (e.g. 42.0). NumPy's
    # seeding rejects floats and list repetition needs an int, so normalise here.
    seed = int(seed)
    height = int(height)
    width = int(width)
    batch_size = int(num_images)

    torch.manual_seed(seed)
    random.seed(seed)
    # NumPy's legacy seeding only accepts values in [0, 2**32).
    np.random.seed(seed % (2 ** 32))

    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    batch = {'captions': [prompt] * batch_size}

    with torch.no_grad():
        models.generator.cuda()
        try:
            with torch.autocast("cuda", dtype=dtype):
                sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
        finally:
            # Always move the stage C generator off the GPU, even if sampling failed.
            models.generator.cpu()
            torch.cuda.empty_cache()

        # Computed once, right before use (previously also computed, unused, up front).
        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
        conditions_b['effnet'] = sampled_c
        unconditions_b['effnet'] = torch.zeros_like(sampled_c)

        with torch.autocast("cuda", dtype=dtype):
            sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)

    torch.cuda.empty_cache()
    imgs = show_images(sampled)
    return imgs


iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
        gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
        # precision=0 makes Gradio hand the handler an int instead of a float.
        gr.Number(label="Seed", value=42, precision=0),
        gr.Slider(minimum=1, maximum=10, step=1, label="Number of Images", value=1)
    ],
    outputs=gr.Gallery(label="Generated Images", columns=5, rows=2),
    title="UltraPixel Image Generation",
    description="Generate high-resolution images using UltraPixel model.",
    theme='bethecloud/storj_theme',
    examples=[
        ["A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening in the early morning light.", 1024, 1024, 42, 1],
        ["A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a clear blue sky.", 1024, 1024, 42, 1],
        ["A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing a vintage floral dress and standing in front of a blooming garden.", 1024, 1024, 42, 1],
        ["The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.", 1024, 1024, 42, 1]
    ],
    cache_examples=True
)

iface.launch()