Spaces:
Runtime error
Runtime error
File size: 4,242 Bytes
b29a22b 1aad4f8 7c36275 0a2783f 1aad4f8 7c36275 1aad4f8 7c36275 d49a403 7c36275 d49a403 1aad4f8 0a2783f 7c36275 0a2783f 1aad4f8 7c36275 1aad4f8 7c36275 1aad4f8 7c36275 0a2783f 7c36275 f51d066 0a2783f 1aad4f8 7c36275 0a2783f c8f4f3d 1aad4f8 7c36275 1aad4f8 8ee6f02 1aad4f8 7817c6c 1aad4f8 7c36275 1aad4f8 7c36275 1aad4f8 7c36275 1aad4f8 7c36275 1aad4f8 9e6481f 1aad4f8 7c36275 43cc6f1 1aad4f8 7c36275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import os
import random
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
import spaces
import uuid
# Markdown header rendered at the top of the demo page.
DESCRIPTION = """# SPRIGHT T2I
#### [SPRIGHT T2I](https://spright.github.io/) is a framework to improve the spatial consistency of text-to-image models WITHOUT compromising their fidelity aspects.
"""

# Largest seed offered by the UI: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Deployment knobs, all read from the environment once at import time.
TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "768"))
CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES", "1") == "1"
# Hub repository that hosts the pipeline weights.
pipe_id = "SPRIGHT-T2I/spright-t2i-v1"

# The UNet is loaded separately from the "unet_ema" subfolder and injected
# into the pipeline below. It receives the same access token as the pipeline
# so both downloads succeed (or fail) together when the repository requires
# authentication.
unet = UNet2DConditionModel.from_pretrained(
    pipe_id,
    subfolder="unet_ema",
    torch_dtype=torch.float16,
    token=TOKEN,
)

# Build the full pipeline around the custom UNet in half precision and move
# it to the GPU once, at import time, so requests reuse the loaded model.
pipe = DiffusionPipeline.from_pretrained(
    pipe_id,
    unet=unet,
    torch_dtype=torch.float16,
    use_safetensors=True,
    token=TOKEN,
).to("cuda")
def save_image(img):
    """Save *img* as a PNG with a random UUID-based file name.

    The file is written relative to the current working directory via
    ``img.save``. Returns the generated file name.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in ``[0, MAX_SEED]`` if requested, else *seed*."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
# NOTE: the ZeroGPU decorator is spelled ``spaces.GPU``; the lowercase
# ``spaces.gpu`` attribute does not exist and fails at import time.
@spaces.GPU
def generate(
    prompt: str,
    seed: int = 0,
    width: int = 768,
    height: int = 768,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for *prompt* and return it with the seed used.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        seed: Seed for the random generator; replaced by a random value
            when *randomize_seed* is true.
        width: Output width in pixels, forwarded to the pipeline.
        height: Output height in pixels, forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the
            pipeline.
        randomize_seed: When true, ignore *seed* and draw a fresh one.
        progress: Gradio progress tracker (mirrors tqdm progress in the UI).

    Returns:
        A tuple ``([image_path], seed)``: a one-element list holding the
        path of the saved PNG, and the seed that was actually used.
    """
    # Coerce to a plain int: ``manual_seed`` rejects non-integer values, and
    # UI widgets may deliver the number as a float.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    image_path = save_image(image)
    print(image_path)
    return [image_path], seed
# Sample prompts offered in the UI through ``gr.Examples``.
examples = [
    "A cat next to a suitcase",
    "A candle on the left of a mouse",
    "A bag on the right of a dog",
    "A mouse on the top of a bowl",
]
# Gradio UI: a prompt box with a run button, a result gallery, and an
# "Advanced options" accordion holding the generation parameters.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", columns=1, show_label=False)
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # The size sliders are hidden, so their defaults are what every UI
        # request sends to ``generate``. The default is clamped to the
        # configured maximum (and to 768, matching ``generate``'s defaults)
        # instead of the previous out-of-range 1024.
        with gr.Row(visible=False):
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=min(768, MAX_IMAGE_SIZE),
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=min(768, MAX_IMAGE_SIZE),
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=7.5,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=50,
            )

    # Example prompts only feed the prompt box; ``generate`` falls back to
    # its own defaults for every other argument.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Pressing Enter in the prompt box and clicking "Run" trigger the same
    # generation call, which is also exposed under the API name "run".
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed],
        outputs=[result, seed],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting requests), then start
    # the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|