p1atdev commited on
Commit
70f55b7
·
1 Parent(s): a4bee0b

chore: lpw, image size

Browse files
Files changed (1) hide show
  1. app.py +48 -12
app.py CHANGED
@@ -1,11 +1,16 @@
 
 
1
  import os
2
- import gradio as gr
3
- import numpy as np
4
  import random
5
- import spaces
 
 
 
 
6
  from diffusers import DiffusionPipeline
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
- import torch
 
9
 
10
  try:
11
  from dotenv import load_dotenv
@@ -38,12 +43,16 @@ dart = dart.requires_grad_(False)
38
  dart = torch.compile(dart)
39
  tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID)
40
 
41
- pipe = DiffusionPipeline.from_pretrained(IMAGE_MODEL_REPO_ID, torch_dtype=torch_dtype)
 
 
 
 
42
  pipe = pipe.to(device)
43
 
44
 
45
  MAX_SEED = np.iinfo(np.int32).max
46
- MAX_IMAGE_SIZE = 1024
47
 
48
  TEMPLATE = (
49
  "<|bos|>"
@@ -59,6 +68,20 @@ TEMPLATE = (
59
  "<general>"
60
  )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  @torch.inference_mode
64
  def generate_prompt(aspect_ratio: str):
@@ -66,7 +89,7 @@ def generate_prompt(aspect_ratio: str):
66
  TEMPLATE.format(aspect_ratio=aspect_ratio),
67
  return_tensors="pt",
68
  ).input_ids
69
- print("input_ids", input_ids)
70
 
71
  output_ids = dart.generate(
72
  input_ids,
@@ -80,10 +103,13 @@ def generate_prompt(aspect_ratio: str):
80
 
81
  generated = output_ids[len(input_ids) :]
82
  decoded = ", ".join([token for token in tokenizer.batch_decode(generated, skip_special_tokens=True) if token.strip() != ""])
83
- print("decoded", decoded)
84
 
85
  return decoded
86
 
 
 
 
87
  @spaces.GPU
88
  def generate_image(
89
  prompt: str,
@@ -93,7 +119,6 @@ def generate_image(
93
  height: int,
94
  guidance_scale: float,
95
  num_inference_steps: int,
96
- progress=gr.Progress(track_tqdm=True),
97
  ):
98
  image = pipe(
99
  prompt=prompt,
@@ -108,6 +133,7 @@ def generate_image(
108
  return image
109
 
110
  def on_generate(
 
111
  negative_prompt: str,
112
  seed,
113
  randomize_seed,
@@ -115,12 +141,15 @@ def on_generate(
115
  height,
116
  guidance_scale,
117
  num_inference_steps,
 
118
  ):
119
  if randomize_seed:
120
  seed = random.randint(0, MAX_SEED)
121
  generator = torch.Generator().manual_seed(seed)
122
 
123
- prompt = generate_prompt("<|aspect_ratio:square|>")
 
 
124
  print(prompt)
125
 
126
  image = generate_image(
@@ -155,15 +184,21 @@ with gr.Blocks(css=css) as demo:
155
  result = gr.Image(label="Result", show_label=False)
156
 
157
  with gr.Accordion("Generation details", open=False):
158
- prompt_txt = gr.Textbox("Generated prompt", interactive=False)
159
 
160
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
 
 
 
161
  negative_prompt = gr.Text(
162
  label="Negative prompt",
163
  max_lines=1,
164
  placeholder="Enter a negative prompt",
165
  visible=False,
166
- value=" worst quality, comic, multiple views, bad quality, low quality, lowres, displeasing, very displeasing, bad anatomy, bad hands, scan artifacts, monochrome, greyscale, signature, twitter username, jpeg artifacts, 2koma, 4koma, guro, extra digits, fewer digits",
167
  )
168
 
169
  seed = gr.Slider(
@@ -214,6 +249,7 @@ with gr.Blocks(css=css) as demo:
214
  triggers=[run_button.click],
215
  fn=on_generate,
216
  inputs=[
 
217
  negative_prompt,
218
  seed,
219
  randomize_seed,
 
1
+ import spaces
2
+
3
  import os
 
 
4
  import random
5
+ import math
6
+
7
+ import torch
8
+ import numpy as np
9
+
10
  from diffusers import DiffusionPipeline
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
+
13
+ import gradio as gr
14
 
15
  try:
16
  from dotenv import load_dotenv
 
43
  dart = torch.compile(dart)
44
  tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID)
45
 
46
+ pipe = DiffusionPipeline.from_pretrained(
47
+ IMAGE_MODEL_REPO_ID,
48
+ torch_dtype=torch_dtype,
49
+ custom_pipeline="lpw_stable_diffusion_xl"
50
+ )
51
  pipe = pipe.to(device)
52
 
53
 
54
  MAX_SEED = np.iinfo(np.int32).max
55
+ MAX_IMAGE_SIZE = 2048
56
 
57
  TEMPLATE = (
58
  "<|bos|>"
 
68
  "<general>"
69
  )
70
 
71
+ def get_aspect_ratio(width: int, height: int) -> str:
72
+ ar = width / height
73
+
74
+ if ar <= 1 / math.sqrt(3):
75
+ return "<|aspect_ratio:ultra_wide|>"
76
+ elif ar <= 8 / 9: #
77
+ return "<|aspect_ratio:wide|>"
78
+ elif ar < 9 / 8:
79
+ return "<|aspect_ratio:square|>"
80
+ elif ar < math.sqrt(3):
81
+ return "<|aspect_ratio:tall|>"
82
+ else:
83
+ return "<|aspect_ratio:ultra_tall|>"
84
+
85
 
86
  @torch.inference_mode
87
  def generate_prompt(aspect_ratio: str):
 
89
  TEMPLATE.format(aspect_ratio=aspect_ratio),
90
  return_tensors="pt",
91
  ).input_ids
92
+ print("input_ids:", input_ids)
93
 
94
  output_ids = dart.generate(
95
  input_ids,
 
103
 
104
  generated = output_ids[len(input_ids) :]
105
  decoded = ", ".join([token for token in tokenizer.batch_decode(generated, skip_special_tokens=True) if token.strip() != ""])
106
+ print("decoded:", decoded)
107
 
108
  return decoded
109
 
110
+ def format_prompt(prompt: str, prompt_suffix: str):
111
+ return f"{prompt}, {prompt_suffix}"
112
+
113
  @spaces.GPU
114
  def generate_image(
115
  prompt: str,
 
119
  height: int,
120
  guidance_scale: float,
121
  num_inference_steps: int,
 
122
  ):
123
  image = pipe(
124
  prompt=prompt,
 
133
  return image
134
 
135
  def on_generate(
136
+ suffix: str,
137
  negative_prompt: str,
138
  seed,
139
  randomize_seed,
 
141
  height,
142
  guidance_scale,
143
  num_inference_steps,
144
+ progress=gr.Progress(track_tqdm=True),
145
  ):
146
  if randomize_seed:
147
  seed = random.randint(0, MAX_SEED)
148
  generator = torch.Generator().manual_seed(seed)
149
 
150
+ ar = get_aspect_ratio(width, height)
151
+ prompt = generate_prompt(ar)
152
+ prompt = format_prompt(prompt, suffix)
153
  print(prompt)
154
 
155
  image = generate_image(
 
184
  result = gr.Image(label="Result", show_label=False)
185
 
186
  with gr.Accordion("Generation details", open=False):
187
+ prompt_txt = gr.Textbox(label="Generated prompt", interactive=False)
188
 
189
  with gr.Accordion("Advanced Settings", open=False):
190
+ prompt_suffix = gr.Text(
191
+ label="Prompt suffix",
192
+ max_lines=1,
193
+ visible=False,
194
+ value="masterpiece, best quality",
195
+ )
196
  negative_prompt = gr.Text(
197
  label="Negative prompt",
198
  max_lines=1,
199
  placeholder="Enter a negative prompt",
200
  visible=False,
201
+ value="worst quality, bad quality, low quality, lowres, displeasing, very displeasing, bad anatomy, bad hands, scan artifacts, signature, username, jpeg artifacts, guro, extra digits, fewer digits",
202
  )
203
 
204
  seed = gr.Slider(
 
249
  triggers=[run_button.click],
250
  fn=on_generate,
251
  inputs=[
252
+ prompt_suffix,
253
  negative_prompt,
254
  seed,
255
  randomize_seed,