radames commited on
Commit
7f871a4
·
1 Parent(s): cb20fc8
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +254 -0
  3. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv
2
+ gradio_cached_examples
app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import gradio as gr
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ from typing import List
8
+ from diffusers.utils import numpy_to_pil
9
+ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
10
+ from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
11
+ from fastapi import FastAPI
12
+ import uvicorn
13
+ from pydantic import BaseModel
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import RedirectResponse
16
+
17
+
18
+ class GenerateRequest(BaseModel):
19
+ prompt: str
20
+ negative_prompt: str = ""
21
+ seed: int = 0
22
+
23
+
24
+ app = FastAPI()
25
+ origins = [
26
+ "http://localhost.tiangolo.com",
27
+ "https://localhost.tiangolo.com",
28
+ "http://localhost",
29
+ "http://localhost:8080",
30
+ ]
31
+
32
+ app.add_middleware(
33
+ CORSMiddleware,
34
+ allow_origins=origins,
35
+ allow_credentials=True,
36
+ allow_methods=["*"],
37
+ allow_headers=["*"],
38
+ )
39
+
40
+
41
+ @app.get("/")
42
+ async def main():
43
+ # redirect to https://huggingface.co/spaces/multimodalart/stable-cascade
44
+ return RedirectResponse("https://huggingface.co/spaces/multimodalart/stable-cascade")
45
+
46
+
47
+ if __name__ == "__main__":
48
+ uvicorn.run(app, host="0.0.0.0", port=7860)
49
+
50
+ # MAX_SEED = np.iinfo(np.int32).max
51
+ # USE_TORCH_COMPILE = False
52
+
53
+ # dtype = torch.bfloat16
54
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
55
+ # if torch.cuda.is_available():
56
+ # prior_pipeline = StableCascadePriorPipeline.from_pretrained(
57
+ # "stabilityai/stable-cascade-prior", torch_dtype=dtype) # .to(device)
58
+ # decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
59
+ # "stabilityai/stable-cascade", torch_dtype=dtype) # .to(device)
60
+ # prior_pipeline.to(device)
61
+ # decoder_pipeline.to(device)
62
+
63
+ # if USE_TORCH_COMPILE:
64
+ # prior_pipeline.prior = torch.compile(
65
+ # prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
66
+ # decoder_pipeline.decoder = torch.compile(
67
+ # decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
68
+
69
+
70
+ # else:
71
+ # prior_pipeline = None
72
+ # decoder_pipeline = None
73
+
74
+
75
+ # def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
76
+ # if randomize_seed:
77
+ # seed = random.randint(0, MAX_SEED)
78
+ # return seed
79
+
80
+
81
+ # def generate(
82
+ # prompt: str,
83
+ # negative_prompt: str = "",
84
+ # seed: int = 0,
85
+ # width: int = 1024,
86
+ # height: int = 1024,
87
+ # prior_num_inference_steps: int = 30,
88
+ # # prior_timesteps: List[float] = None,
89
+ # prior_guidance_scale: float = 4.0,
90
+ # decoder_num_inference_steps: int = 12,
91
+ # # decoder_timesteps: List[float] = None,
92
+ # decoder_guidance_scale: float = 0.0,
93
+ # num_images_per_prompt: int = 2,
94
+ # progress=gr.Progress(track_tqdm=True),
95
+ # ) -> PIL.Image.Image:
96
+
97
+ # generator = torch.Generator().manual_seed(seed)
98
+ # prior_output = prior_pipeline(
99
+ # prompt=prompt,
100
+ # height=height,
101
+ # width=width,
102
+ # num_inference_steps=prior_num_inference_steps,
103
+ # timesteps=DEFAULT_STAGE_C_TIMESTEPS,
104
+ # negative_prompt=negative_prompt,
105
+ # guidance_scale=prior_guidance_scale,
106
+ # num_images_per_prompt=num_images_per_prompt,
107
+ # generator=generator,
108
+ # )
109
+ # decoder_output = decoder_pipeline(
110
+ # image_embeddings=prior_output.image_embeddings,
111
+ # prompt=prompt,
112
+ # num_inference_steps=decoder_num_inference_steps,
113
+ # # timesteps=decoder_timesteps,
114
+ # guidance_scale=decoder_guidance_scale,
115
+ # negative_prompt=negative_prompt,
116
+ # generator=generator,
117
+ # output_type="pil",
118
+ # ).images
119
+
120
+ # return decoder_output[0]
121
+
122
+
123
+ # examples = [
124
+ # "An astronaut riding a green horse",
125
+ # "A mecha robot in a favela by Tarsila do Amaral",
126
+ # "The sprirt of a Tamagotchi wandering in the city of Los Angeles",
127
+ # "A delicious feijoada ramen dish"
128
+ # ]
129
+
130
+ # with gr.Blocks() as demo:
131
+ # gr.Markdown(DESCRIPTION)
132
+ # gr.DuplicateButton(
133
+ # value="Duplicate Space for private use",
134
+ # elem_id="duplicate-button",
135
+ # visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
136
+ # )
137
+ # with gr.Group():
138
+ # with gr.Row():
139
+ # prompt = gr.Text(
140
+ # label="Prompt",
141
+ # show_label=False,
142
+ # max_lines=1,
143
+ # placeholder="Enter your prompt",
144
+ # container=False,
145
+ # )
146
+ # run_button = gr.Button("Run", scale=0)
147
+ # result = gr.Image(label="Result", show_label=False)
148
+ # with gr.Accordion("Advanced options", open=False):
149
+ # negative_prompt = gr.Text(
150
+ # label="Negative prompt",
151
+ # max_lines=1,
152
+ # placeholder="Enter a Negative Prompt",
153
+ # )
154
+
155
+ # seed = gr.Slider(
156
+ # label="Seed",
157
+ # minimum=0,
158
+ # maximum=MAX_SEED,
159
+ # step=1,
160
+ # value=0,
161
+ # )
162
+ # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
163
+ # with gr.Row():
164
+ # width = gr.Slider(
165
+ # label="Width",
166
+ # minimum=1024,
167
+ # maximum=1536,
168
+ # step=512,
169
+ # value=1024,
170
+ # )
171
+ # height = gr.Slider(
172
+ # label="Height",
173
+ # minimum=1024,
174
+ # maximum=1536,
175
+ # step=512,
176
+ # value=1024,
177
+ # )
178
+ # num_images_per_prompt = gr.Slider(
179
+ # label="Number of Images",
180
+ # minimum=1,
181
+ # maximum=2,
182
+ # step=1,
183
+ # value=1,
184
+ # )
185
+ # with gr.Row():
186
+ # prior_guidance_scale = gr.Slider(
187
+ # label="Prior Guidance Scale",
188
+ # minimum=0,
189
+ # maximum=20,
190
+ # step=0.1,
191
+ # value=4.0,
192
+ # )
193
+ # prior_num_inference_steps = gr.Slider(
194
+ # label="Prior Inference Steps",
195
+ # minimum=10,
196
+ # maximum=30,
197
+ # step=1,
198
+ # value=20,
199
+ # )
200
+
201
+ # decoder_guidance_scale = gr.Slider(
202
+ # label="Decoder Guidance Scale",
203
+ # minimum=0,
204
+ # maximum=0,
205
+ # step=0.1,
206
+ # value=0.0,
207
+ # )
208
+ # decoder_num_inference_steps = gr.Slider(
209
+ # label="Decoder Inference Steps",
210
+ # minimum=4,
211
+ # maximum=12,
212
+ # step=1,
213
+ # value=10,
214
+ # )
215
+
216
+ # gr.Examples(
217
+ # examples=examples,
218
+ # inputs=prompt,
219
+ # outputs=result,
220
+ # fn=generate,
221
+ # cache_examples=False,
222
+ # )
223
+
224
+ # inputs = [
225
+ # prompt,
226
+ # negative_prompt,
227
+ # seed,
228
+ # width,
229
+ # height,
230
+ # prior_num_inference_steps,
231
+ # # prior_timesteps,
232
+ # prior_guidance_scale,
233
+ # decoder_num_inference_steps,
234
+ # # decoder_timesteps,
235
+ # decoder_guidance_scale,
236
+ # num_images_per_prompt,
237
+ # ]
238
+ # gr.on(
239
+ # triggers=[prompt.submit, negative_prompt.submit, run_button.click],
240
+ # fn=randomize_seed_fn,
241
+ # inputs=[seed, randomize_seed],
242
+ # outputs=seed,
243
+ # queue=False,
244
+ # api_name=False,
245
+ # ).then(
246
+ # fn=generate,
247
+ # inputs=inputs,
248
+ # outputs=result,
249
+ # api_name="run",
250
+ # )
251
+
252
+
253
+ # if __name__ == "__main__":
254
+ # demo.queue(max_size=20).launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/kashif/diffusers.git@wuerstchen-v3
2
+ accelerate
3
+ safetensors
4
+ transformers
5
+ gradio
6
+ fastapi
7
+ pydantic
8
+ uvicorn