File size: 4,361 Bytes
b29a22b
1aad4f8
 
 
 
 
a962a1b
aa0ffd0
0a2783f
1aad4f8
7c36275
0bbb523
7c36275
1aad4f8
6b22806
 
 
 
 
 
 
1aad4f8
7c36275
3d03a32
4104e3d
 
 
 
 
d49a403
 
a962a1b
7c36275
 
4104e3d
7c36275
6b22806
d49a403
1aad4f8
0a2783f
7c36275
0a2783f
 
1aad4f8
7c36275
1aad4f8
 
 
 
 
7c36275
aa0ffd0
1aad4f8
 
 
7c36275
 
 
 
0a2783f
7c36275
f51d066
0a2783f
1aad4f8
 
7c36275
 
 
 
 
 
 
 
 
0a2783f
c8f4f3d
 
1aad4f8
 
7c36275
 
 
 
 
 
1aad4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee6f02
1aad4f8
 
 
 
 
 
 
 
 
3d03a32
7817c6c
 
 
 
 
4104e3d
7817c6c
 
 
 
 
 
4104e3d
7817c6c
1aad4f8
7c36275
 
1aad4f8
 
 
7c36275
1aad4f8
7c36275
 
1aad4f8
 
 
7c36275
1aad4f8
 
 
 
 
9e6481f
1aad4f8
 
 
 
 
 
 
 
 
 
7c36275
43cc6f1
1aad4f8
 
 
 
7c36275
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import random

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
#import spaces
import uuid

# Markdown header rendered at the top of the page via gr.Markdown(DESCRIPTION).
DESCRIPTION = """# SPRIGHT T2I
[SPRIGHT T2I](https://spright-t2i.github.io/) is a framework to improve the spatial consistency of text-to-image models WITHOUT compromising their fidelity aspects.
"""

# Pick the compute device: CUDA if present, else Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Largest seed offered by the seed slider and the random seed picker
# (the int32 maximum, 2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
# gr.Examples output caching is on unless CACHE_EXAMPLES is set to
# something other than "1".
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "1") == "1"
# Upper bound of the width/height sliders; overridable via the environment.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
DEFAULT_IMAGE_SIZE = 1024
torch_dtype = torch.float16
if device in ("cpu", "mps"):
    # Off CUDA: smaller default resolution and full float32 precision
    # (presumably because fp16 is slow or poorly supported there).
    DEFAULT_IMAGE_SIZE = 512
    torch_dtype = torch.float32
# DEFAULT_IMAGE_SIZE is the sliders' initial value, so it must never exceed
# the slider maximum when MAX_IMAGE_SIZE is lowered via the environment.
DEFAULT_IMAGE_SIZE = min(DEFAULT_IMAGE_SIZE, MAX_IMAGE_SIZE)


# Load the `SPRIGHT-T2I/spright-t2i-sd2` checkpoint (safetensors weights only)
# in the dtype chosen above, then move it onto the selected device.
pipe_id = "SPRIGHT-T2I/spright-t2i-sd2"
pipe = DiffusionPipeline.from_pretrained(
    pipe_id, torch_dtype=torch_dtype, use_safetensors=True
)
pipe = pipe.to(device)


def save_image(img):
    """Save *img* as ``<random uuid4>.png`` in the current working directory.

    Returns the file name that was written (a relative path).
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed*, or a fresh random seed in ``[0, MAX_SEED]`` when requested."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed


#@spaces.GPU
def generate(
    prompt: str,
    seed: int = 0,
    width: int = 768,
    height: int = 768,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for *prompt* with the diffusion pipeline and save it.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the torch generator; replaced by a random one when
            *randomize_seed* is true.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.
        randomize_seed: Draw a new random seed instead of using *seed*.
        progress: Gradio progress tracker; with ``track_tqdm=True`` Gradio
            mirrors tqdm progress bars (such as the pipeline's) in the UI.

    Returns:
        ``([image_path], seed)``: a one-element list of saved file paths for the
        gallery, plus the seed actually used so the UI can display it.
    """
    # Gradio sliders pass their values as floats. torch.Generator.manual_seed
    # rejects floats, and the pipeline expects integer sizes and step counts.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    # A CPU generator is used regardless of the pipeline's device.
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    image_path = save_image(image)
    print(image_path)
    return [image_path], seed


# Spatial-relationship prompts offered as clickable examples in the UI.
examples = [
    "A cat next to a suitcase",
    "A candle on the left of a mouse",
    "A bag on the right of a dog",
    "A mouse on the top of a bowl",
]

# Build the Gradio UI. Components are placed in the order they are created
# inside each `with` context, so statement order here determines the layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Visible only when SHOW_DUPLICATE_BUTTON=1 is set in the environment.
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        # generate() returns a one-element list of image paths for this gallery.
        result = gr.Gallery(label="Result", columns=1, show_label=False)
    with gr.Accordion("Advanced options", open=False):
        # Input and output: generate() writes back the seed it actually used.
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            # Starting at 256 with step 32 keeps every size a multiple of 32.
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=DEFAULT_IMAGE_SIZE,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=DEFAULT_IMAGE_SIZE,
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=7.5,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=50,
            )

    # Examples call generate() with only the prompt as input, so every other
    # parameter takes generate()'s Python default rather than the slider value.
    # When CACHE_EXAMPLES is true, Gradio caches the example outputs.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Pressing Enter in the prompt box or clicking Run both call generate();
    # the handler is also exposed as the "run" API endpoint.
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed],
        outputs=[result, seed],
        api_name="run",
    )

if __name__ == "__main__":
    # Enable the request queue (holding at most 20 events) and start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()