Spaces: orrinin / (Runtime error)

orrinin committed · Commit 694468d · verified · 1 Parent(s): 4efab5c

Update app.py

Files changed (1):
  1. app.py +194 -88
app.py CHANGED
@@ -1,110 +1,216 @@
  import gradio as gr
- import torch
- from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
- from huggingface_hub import hf_hub_download
- import spaces
  from PIL import Image
- import requests
  from translatepy import Translator

- translator = Translator()
-
- # Constants
- base = "stabilityai/stable-diffusion-xl-base-1.0"
- repo = "tianweiy/DMD2"
- checkpoints = {
-     "1-Step" : ["dmd2_sdxl_1step_unet_fp16.bin", 1],
-     "4-Step" : ["dmd2_sdxl_4step_unet_fp16.bin", 4],
- }
- loaded = None
-
- CSS = """
- .gradio-container {
-     max-width: 690px !important;
- }
- footer {
-     visibility: hidden;
- }
  """

- JS = """function () {
-     gradioURL = window.location.href
-     if (!gradioURL.endsWith('?__theme=dark')) {
-         window.location.replace(gradioURL + '?__theme=dark');
-     }
- }"""
-
-
-
- # Ensure model and scheduler are initialized in GPU-enabled function
- if torch.cuda.is_available():
-     unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
-     pipe = DiffusionPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")


- # Function
- @spaces.GPU()
- def generate_image(prompt, ckpt="4-Step"):
-     global loaded
-
      prompt = str(translator.translate(prompt, 'English'))
-
      print(prompt)
-
-     checkpoint = checkpoints[ckpt][0]
-     num_inference_steps = checkpoints[ckpt][1]
-
-     if loaded != num_inference_steps:
-         pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-         pipe.unet.load_state_dict(torch.load(hf_hub_download(repo, checkpoint), map_location="cuda"))
-         loaded = num_inference_steps
-
-     if loaded == 1:
-         timesteps=[399]
-     else:
-         timesteps=[999, 749, 499, 249]
-
-     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0, timesteps=timesteps)
-     return results.images[0]
-

  examples = [
-     "a cat eating a piece of cheese",
-     "a ROBOT riding a BLUE horse on Mars, photorealistic",
-     "Ironman VS Hulk, ultrarealistic",
-     "a CUTE robot artist painting on an easel",
-     "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
-     "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
-     "Kids going to school, Anime style"
  ]

-
- # Gradio Interface
-
- with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
-     gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
-     gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center><br><center>Multi-Languages, 4-step is higher quality & 2X slower</center></p>")
      with gr.Group():
          with gr.Row():
-             prompt = gr.Textbox(label='Enter Your Prompt', scale=8)
-             ckpt = gr.Dropdown(label='Steps',choices=['1-Step', '4-Step'], value='4-Step', interactive=True)
-             submit = gr.Button(scale=1, variant='primary')
-     img = gr.Image(label='DMD2 Generated Image')
      gr.Examples(
          examples=examples,
          inputs=prompt,
-         outputs=img,
-         fn=generate_image,
          cache_examples="lazy",
      )
-
-     prompt.submit(fn=generate_image,
-                   inputs=[prompt, ckpt],
-                   outputs=img,
-                   )
-     submit.click(fn=generate_image,
-                  inputs=[prompt, ckpt],
-                  outputs=img,
-                  )

- demo.queue().launch()
+ import os
+ import random
+
  import gradio as gr
+ from gradio_client import Client, file
+ import numpy as np
  from PIL import Image
+ from typing import Tuple
  from translatepy import Translator

+ MODEL = os.environ.get("MODEL")
+ API_URL = "https://api-inference.huggingface.co/models/tianweiy/DMD2"
+ DESCRIPTION = """
+ # DMD2 文生图
  """
+ translator = Translator()



+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ client = Client(MODEL)
+
+ style_list = [
+     {
+         "name": "(无风格)",
+         "prompt": "{prompt}",
+     },
+     {
+         "name": "电影",
+         "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
+     },
+     {
+         "name": "摄影",
+         "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
+     },
+     {
+         "name": "动画",
+         "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
+     },
+     {
+         "name": "漫画",
+         "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
+     },
+     {
+         "name": "数绘",
+         "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
+     },
+     {
+         "name": "像素",
+         "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
+     },
+     {
+         "name": "幻想",
+         "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
+     },
+     {
+         "name": "朋克",
+         "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
+     },
+     {
+         "name": "三维",
+         "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
+     },
+ ]
+ styles = {k["name"]: (k["prompt"]) for k in style_list}
+ print(styles)
+ STYLE_NAMES = list(styles.keys())
+ DEFAULT_STYLE_NAME = "(无风格)"
+
+ def apply_style(style_name: str, positive: str) -> Tuple[str, str]:
+     p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+     return p.replace("{prompt}", positive), n
+
+ def generate(
+     prompt: str,
+     seed: int = 0,
+     width: int = 1024,
+     height: int = 1024,
+     style: str = DEFAULT_STYLE_NAME,
+     num_images: int = 2,
+     randomize_seed: bool = False,
+     progress=gr.Progress(track_tqdm=True),
+ ):
      prompt = str(translator.translate(prompt, 'English'))
      print(prompt)
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     # print(client.view_api())
+     result = client.predict(
+         prompt=prompt,
+         seed=seed,
+         height=height,
+         width=width,
+         num_images=num_images,
+         fast_vae_decode=True,
+         api_name="/inference"
+     )
+     images = result[0]
+     print(images)
+     image_paths = []
+     # List[Dict(image: filepath, caption: str | None)]
+     for img in images:
+         image_paths.append(img["image"])
+     print(image_paths)
+     return image_paths, seed

  examples = [
+     "镭射眼的秋田犬",
+     "一只吃起司的猫",
+     "太空中骑马的宇航员",
+     "放学回家的学生们,动画风格",
+     "一个可爱的机器人艺术家在画架上绘画,概念艺术",
+     "一位女士的特写,她戴着透明、棱柱形、精致的复仇女神头饰,摆出应有的姿势,棕色肤色"
  ]

+ CSS = '''
+ .gradio-container{max-width: 560px !important}
+ h1{text-align:center}
+ footer {
+     visibility: hidden
+ }
+ '''
+ with gr.Blocks(css=CSS, theme="soft") as demo:
+     gr.Markdown(DESCRIPTION)
      with gr.Group():
          with gr.Row():
+             prompt = gr.Text(
+                 label="描述",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="画什么好呢",
+                 container=False,
+                 scale=2,
+             )
+             run_button = gr.Button("生成", scale=1)
+     result = gr.Gallery(label="作品", columns=1, preview=True)
+     with gr.Accordion("高级选项", open=False):
+         with gr.Row():
+             num_images = gr.Slider(
+                 label="数量",
+                 minimum=1,
+                 maximum=5,
+                 step=1,
+                 value=2,
+             )
+             seed = gr.Slider(
+                 label="种子",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+                 visible=True
+             )
+             randomize_seed = gr.Checkbox(label="随机种子", value=True)
+         with gr.Row(visible=True):
+             width = gr.Slider(
+                 label="宽",
+                 minimum=512,
+                 maximum=2048,
+                 step=8,
+                 value=1024,
+             )
+             height = gr.Slider(
+                 label="高",
+                 minimum=512,
+                 maximum=2048,
+                 step=8,
+                 value=1024,
+             )
+         with gr.Row(visible=True):
+             style_selection = gr.Radio(
+                 show_label=True,
+                 container=True,
+                 interactive=True,
+                 choices=STYLE_NAMES,
+                 value=DEFAULT_STYLE_NAME,
+                 label="风格化",
+             )
+
+
      gr.Examples(
          examples=examples,
          inputs=prompt,
+         outputs=[result, seed],
+         fn=generate,
          cache_examples="lazy",
      )

+     gr.on(
+         triggers=[
+             prompt.submit,
+             run_button.click,
+         ],
+         fn=generate,
+         inputs=[
+             prompt,
+             seed,
+             width,
+             height,
+             style_selection,
+             num_images,
+             randomize_seed,
+         ],
+         outputs=[result, seed],
+         api_name="run",
+     )
+
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch(show_api=False, debug=False)
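
For reference, a minimal sketch of the remote-inference path the updated app.py relies on: instead of running the DMD2 diffusers pipeline locally, it forwards generation to another Space through gradio_client. The Space id below is a placeholder, since the real value is read from the MODEL environment variable and is not part of this commit; the keyword arguments mirror the client.predict(...) call in the diff and assume the upstream /inference endpoint accepts them.

import os

from gradio_client import Client

# Placeholder Space id; app.py reads the real one from the MODEL env var.
space_id = os.environ.get("MODEL", "someuser/some-dmd2-space")
client = Client(space_id)

# Same keyword arguments as the client.predict(...) call in app.py.
result = client.predict(
    prompt="an astronaut riding a horse in space",
    seed=0,
    height=1024,
    width=1024,
    num_images=1,
    fast_vae_decode=True,
    api_name="/inference",
)

# As in app.py, result[0] is expected to be a list of
# {"image": <filepath>, "caption": str | None} dicts.
for item in result[0]:
    print(item["image"])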