Spaces:
Runtime error
Runtime error
File size: 4,361 Bytes
b29a22b 1aad4f8 a962a1b aa0ffd0 0a2783f 1aad4f8 7c36275 0bbb523 7c36275 1aad4f8 6b22806 1aad4f8 7c36275 3d03a32 4104e3d d49a403 a962a1b 7c36275 4104e3d 7c36275 6b22806 d49a403 1aad4f8 0a2783f 7c36275 0a2783f 1aad4f8 7c36275 1aad4f8 7c36275 aa0ffd0 1aad4f8 7c36275 0a2783f 7c36275 f51d066 0a2783f 1aad4f8 7c36275 0a2783f c8f4f3d 1aad4f8 7c36275 1aad4f8 8ee6f02 1aad4f8 3d03a32 7817c6c 4104e3d 7817c6c 4104e3d 7817c6c 1aad4f8 7c36275 1aad4f8 7c36275 1aad4f8 7c36275 1aad4f8 7c36275 1aad4f8 9e6481f 1aad4f8 7c36275 43cc6f1 1aad4f8 7c36275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import os
import random
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
#import spaces
import uuid
# Markdown header for the demo page; rendered via gr.Markdown(DESCRIPTION).
DESCRIPTION = """# SPRIGHT T2I
[SPRIGHT T2I](https://spright-t2i.github.io/) is a framework to improve the spatial consistency of text-to-image models WITHOUT compromising their fidelity aspects.
"""
# Pick the compute backend: CUDA first, then Apple MPS, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Environment-driven switches; both fall back to the defaults given here.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))

# Without CUDA, use a smaller default resolution and full precision;
# on CUDA keep the larger default and half precision.
if device in ("cpu", "mps"):
    DEFAULT_IMAGE_SIZE = 512
    torch_dtype = torch.float32
else:
    DEFAULT_IMAGE_SIZE = 1024
    torch_dtype = torch.float16
# Load the SPRIGHT-T2I checkpoint once at import time and move it to the
# selected device so every request reuses the same pipeline object.
pipe_id = "SPRIGHT-T2I/spright-t2i-sd2"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch_dtype, use_safetensors=True)
pipe = pipe.to(device)
def save_image(img):
    """Save ``img`` under a fresh UUID-based ``.png`` name and return that name.

    The name is relative, so the file is written to the current working
    directory. ``img`` only needs to provide a ``save(path)`` method.
    """
    file_name = f"{uuid.uuid4()}.png"
    img.save(file_name)
    return file_name
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` unchanged, or a random seed when ``randomize_seed`` is set.

    A random seed is drawn uniformly from ``[0, MAX_SEED]`` (both inclusive).
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
#@spaces.GPU
def generate(
    prompt: str,
    seed: int = 0,
    width: int = 768,
    height: int = 768,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` and return its file path and seed.

    Args:
        prompt: Text prompt forwarded to the diffusion pipeline.
        seed: Seed for the random generator; ignored when
            ``randomize_seed`` is true.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        randomize_seed: When true, replace ``seed`` with a random value.
        progress: Gradio progress tracker (created with ``track_tqdm=True``);
            it is not referenced in the body.

    Returns:
        A tuple ``([image_path], seed)``: a one-element list holding the path
        of the saved image, and the seed that was actually used.
    """
    # Slider-backed inputs can arrive as floats (e.g. 42.0 from an API
    # client), but manual_seed() and the pipeline's size/step arguments
    # require ints. int() is a no-op for values that are already ints.
    seed = int(randomize_seed_fn(int(seed), randomize_seed))
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    ).images[0]
    image_path = save_image(image)
    print(image_path)
    return [image_path], seed
# Sample prompts offered through gr.Examples; each describes a spatial
# relation between two objects.
examples = [
    "A cat next to a suitcase",
    "A candle on the left of a mouse",
    "A bag on the right of a dog",
    "A mouse on the top of a bowl",
]
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened -- confirm the layout against the app.
with gr.Blocks(css="style.css") as demo:
    # Page header plus an optional "duplicate this Space" button, shown only
    # when SHOW_DUPLICATE_BUTTON is set to "1".
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    # Prompt box and run button on one row, with the result gallery below.
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", columns=1, show_label=False)

    # Collapsed panel holding the sampling controls.
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_IMAGE_SIZE
            )
            height = gr.Slider(
                label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_IMAGE_SIZE
            )
        with gr.Row():
            guidance_scale = gr.Slider(label="Guidance scale", minimum=1, maximum=20, step=0.1, value=7.5)
            num_inference_steps = gr.Slider(
                label="Number of inference steps", minimum=10, maximum=100, step=1, value=50
            )

    # Example prompts pass only the prompt, so generate() runs with its
    # own defaults for every other argument.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Submitting the prompt or clicking "Run" both trigger generation; the
    # seed actually used is written back to the seed slider.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=[prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed],
        outputs=[result, seed],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending jobs), then start the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|