import os
import random

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
import spaces
import uuid

DESCRIPTION = """# SPRIGHT T2I
#### [SPRIGHT T2I](https://spright.github.io/) is a framework to improve the spatial consistency of text-to-image models WITHOUT compromising their fidelity aspects.
"""

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "768"))
TOKEN = os.getenv("HF_TOKEN")


pipe_id = "SPRIGHT-T2I/spright-t2i-v1"
unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet_ema", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    pipe_id,
    unet=unet,
    torch_dtype=torch.float16,
    use_safetensors=True,
    token=TOKEN,
).to("cuda")


def save_image(img):
    """Save the generated image under a collision-free random filename."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


# On ZeroGPU hardware, `spaces.GPU` allocates a GPU for the duration of each call.
@spaces.GPU
def generate(
    prompt: str,
    seed: int = 0,
    width: int = 768,
    height: int = 768,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    seed = randomize_seed_fn(seed, randomize_seed)
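    # A default (CPU) torch.Generator is used deliberately: diffusers accepts it
    # even though the pipeline runs on CUDA, and it keeps a given seed
    # reproducible across GPU types.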
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    image_path = save_image(image)
    print(image_path)
    return [image_path], seed


# Example prompts that showcase spatial relationships (next to, left, right, top).
examples = [
    "A cat next to a suitcase",
    "A candle on the left of a mouse",
    "A bag on the right of a dog",
    "A mouse on the top of a bowl",
]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", columns=1, show_label=False)
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # Width and height are fixed at the 768x768 defaults, so this row is hidden.
        with gr.Row(visible=False):
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=768,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=768,
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=7.5,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=50,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed],
        outputs=[result, seed],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
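
# The gr.on(..., api_name="run") wiring above also exposes `generate` over the
# Gradio API. A minimal client-side sketch, assuming the `gradio_client`
# package and a hypothetical Space id (substitute the real deployment):
#
#   from gradio_client import Client
#
#   client = Client("SPRIGHT-T2I/spright-t2i-demo")  # hypothetical Space id
#   image_paths, used_seed = client.predict(
#       "A cat next to a suitcase",  # prompt
#       0,                           # seed
#       768,                         # width
#       768,                         # height
#       7.5,                         # guidance_scale
#       50,                          # num_inference_steps
#       True,                        # randomize_seed
#       api_name="/run",
#   )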