p1atdev committed
Commit a1db0e9
1 Parent(s): 7081246

feat: dartrs backend

Files changed (6)
  1. app.py +64 -75
  2. diffusion.py +4 -1
  3. lpw_pipeline_xl.py +0 -0
  4. requirements.txt +2 -1
  5. utils.py +15 -0
  6. v2.py +57 -76
app.py CHANGED
@@ -6,17 +6,28 @@ import gradio as gr
 from v2 import V2UI
 from diffusion import ImageGenerator
 from output import UpsamplingOutput
-from utils import QUALITY_TAGS, NEGATIVE_PROMPT, IMAGE_SIZE_OPTIONS, IMAGE_SIZES
+from utils import QUALITY_TAGS, NEGATIVE_PROMPT, IMAGE_SIZE_OPTIONS, PEOPLE_TAGS
 
 
 def animagine_xl_v3_1(output: UpsamplingOutput):
+    # separate people tags (e.g. 1girl)
+    people_tags = []
+    other_general_tags = []
+    for tag in output.general_tags.split(","):
+        tag = tag.strip()
+        if tag in PEOPLE_TAGS:
+            people_tags.append(tag)
+        else:
+            other_general_tags.append(tag)
+
     return ", ".join(
         [
             part.strip()
             for part in [
+                *people_tags,
                 output.character_tags,
                 output.copyright_tags,
+                *other_general_tags,
                 output.upsampled_tags,
                 (
                     output.rating_tag
@@ -35,59 +46,29 @@ def elapsed_time_format(elapsed_time: float) -> str:
 
 def parse_upsampling_output(
     upsampler: Callable[..., UpsamplingOutput],
-    image_generator: Callable[..., Image.Image],
 ):
-    def _parse_upsampling_output(
-        generate_image: bool, *args
-    ) -> tuple[str, str, Image.Image | None]:
+    def _parse_upsampling_output(*args) -> tuple[
+        str,
+        str,
+        dict,
+    ]:
         output = upsampler(*args)
 
         print(output)
 
-        if not generate_image:
-            return (
-                animagine_xl_v3_1(output),
-                elapsed_time_format(output.elapsed_time),
-                None,
-            )
-
-        # generate image
-        [
-            image_size_option,
-            quality_tags,
-            negative_prompt,
-            num_inference_steps,
-            guidance_scale,
-        ] = args[
-            7:
-        ]  # remove the first 7 arguments for upsampler
-        width, height = IMAGE_SIZES[image_size_option]
-        image = image_generator(
-            ", ".join([animagine_xl_v3_1(output), quality_tags]),
-            negative_prompt,
-            height,
-            width,
-            num_inference_steps,
-            guidance_scale,
-        )
-
         return (
             animagine_xl_v3_1(output),
             elapsed_time_format(output.elapsed_time),
-            image,
+            gr.update(
+                interactive=True,
+            ),
         )
 
     return _parse_upsampling_output
 
 
-def toggle_visible_output_image(generate_image: bool):
-    return gr.update(
-        visible=generate_image,
-    )
-
-
 def image_generation_config_ui():
-    with gr.Accordion(label="Image generation config", open=True) as accordion:
+    with gr.Accordion(label="Image generation config", open=False) as accordion:
         image_size = gr.Radio(
             label="Image size",
             choices=list(IMAGE_SIZE_OPTIONS.keys()),
@@ -142,7 +123,7 @@ def main():
     v2 = V2UI()
 
     print("Loading diffusion model...")
-    image_generator = ImageGenerator()
+    # image_generator = ImageGenerator()
     print("Loaded.")
 
     with gr.Blocks() as ui:
@@ -152,25 +133,25 @@ def main():
         with gr.Column():
             v2.ui()
 
-            generate_image_check = gr.Checkbox(
-                label="Also generate image", value=True
+        with gr.Column():
+            output_text = gr.TextArea(label="Output tags", interactive=False)
+
+            elapsed_time_md = gr.Markdown(label="Elapsed time", value="")
+
+            generate_image_btn = gr.Button(
+                value="Generate image with this prompt!",
             )
 
             accordion, image_generation_config_components = (
                 image_generation_config_ui()
             )
 
-        with gr.Column():
-            output_text = gr.TextArea(label="Output tags", interactive=False)
-
-            elapsed_time_md = gr.Markdown(label="Elapsed time", value="")
-
             output_image = gr.Gallery(
                 label="Output image",
                 columns=1,
                 preview=True,
                 show_label=False,
-                visible=True,
+                visible=False,
             )
 
             gr.Examples(
@@ -179,78 +160,86 @@ def main():
                     "original",
                     "",
                     "1girl, solo, blue theme, limited palette",
-                    "lax",
+                    "sfw",
+                    "ultra_wide",
                     "long",
-                    "1536x640",
+                    "lax",
                 ],
                 [
                     "",
                     "",
                     "4girls",
-                    "none",
+                    "sfw",
+                    "tall",
                     "very_long",
-                    "768x1344",
+                    "lax",
+                ],
+                [
+                    "original",
+                    "",
+                    "1girl, solo, upper body, looking at viewer, profile picture",
+                    "sfw",
+                    "square",
+                    "medium",
+                    "none",
                 ],
                 [
                     "",
                     "",
                     "no humans, scenery, spring (season)",
-                    "none",
+                    "general",
+                    "ultra_wide",
                     "medium",
-                    "1536x640",
+                    "lax",
                 ],
                 [
                     "sousou no frieren",
                     "frieren",
                     "1girl, solo",
-                    "none",
+                    "general",
+                    "tall",
                     "long",
-                    "768x1344",
+                    "lax",
                 ],
                 [
                     "honkai: star rail",
                     "silver wolf (honkai: star rail)",
                     "1girl, solo, annoyed",
-                    "none",
+                    "sfw",
+                    "tall",
                     "long",
-                    "768x1344",
+                    "lax",
                 ],
                 [
                     "bocchi the rock!",
                     "gotoh hitori, kita ikuyo, ijichi nijika, yamada ryo",
                     "4girls, multiple girls",
-                    "none",
+                    "sfw",
+                    "ultra_wide",
                     "very_long",
-                    "1344x768",
+                    "lax",
                 ],
                 [
                     "chuunibyou demo koi ga shitai!",
                     "takanashi rikka",
                     "1girl, solo",
-                    "none",
+                    "sfw",
+                    "ultra_tall",
                     "long",
-                    "640x1536",
+                    "lax",
                 ],
             ],
             inputs=[
-                *v2.get_inputs()[1:6],
-                image_generation_config_components[0],  # image size
+                *v2.get_inputs()[1:8],
             ],
         )
 
         v2.get_generate_btn().click(
-            parse_upsampling_output(v2.on_generate, image_generator.generate),
+            parse_upsampling_output(v2.on_generate),
             inputs=[
-                generate_image_check,
                 *v2.get_inputs(),
-                *image_generation_config_components,
             ],
-            outputs=[output_text, elapsed_time_md, output_image],
-        )
-        generate_image_check.change(
-            toggle_visible_output_image,
-            inputs=[generate_image_check],
-            outputs=[output_image],
+            outputs=[output_text, elapsed_time_md, generate_image_btn],
         )
 
     ui.launch()
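
The reworked animagine_xl_v3_1 above pulls people-count tags (e.g. "1girl") out of general_tags and places them at the front of the final prompt, ahead of the character and copyright tags, with the remaining general tags following. A minimal runnable sketch of that ordering, using made-up sample values and a hypothetical FakeOutput stand-in for output.UpsamplingOutput:

from dataclasses import dataclass

PEOPLE_TAGS = ["1girl"]  # one-element stand-in for utils.PEOPLE_TAGS

@dataclass
class FakeOutput:  # hypothetical stand-in for output.UpsamplingOutput
    general_tags: str = "1girl, solo"
    character_tags: str = "frieren"
    copyright_tags: str = "sousou no frieren"
    upsampled_tags: str = "elf, pointy ears"

output = FakeOutput()
people_tags, other_general_tags = [], []
for tag in output.general_tags.split(","):
    tag = tag.strip()
    (people_tags if tag in PEOPLE_TAGS else other_general_tags).append(tag)

# Same join order as the new animagine_xl_v3_1 (rating tag omitted here):
print(", ".join([*people_tags, output.character_tags, output.copyright_tags,
                 *other_general_tags, output.upsampled_tags]))
# -> 1girl, frieren, sousou no frieren, solo, elf, pointy ears
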
diffusion.py CHANGED
@@ -22,6 +22,9 @@ except ImportError:
 from utils import NEGATIVE_PROMPT
 
 
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
 class ImageGenerator:
     pipe: StableDiffusionXLPipeline
 
@@ -41,7 +44,7 @@
         # sdpa
         self.pipe.unet.set_attn_processor(AttnProcessor2_0())
 
-        self.pipe.to("cuda")
+        self.pipe.to(device)
 
         try:
             self.pipe = torch.compile(self.pipe)
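
The only behavioral change in diffusion.py is device selection: instead of unconditionally calling self.pipe.to("cuda"), the pipeline moves to a module-level device that falls back to CPU when CUDA is unavailable. A minimal sketch of the same pattern, assuming diffusers is installed (the checkpoint name below is illustrative, not taken from this repo):

import torch
from diffusers import StableDiffusionXLPipeline

# Prefer the first CUDA device, fall back to CPU (same expression as diffusion.py).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.to(device)  # runs on CPU-only machines instead of raising on .to("cuda")
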
lpw_pipeline_xl.py ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -3,4 +3,5 @@ accelerate==0.29.2
 transformers==4.38.2
 optimum[onnxruntime]==1.19.1
 diffusers==0.27.2
-spaces==0.26.2
+spaces==0.26.2
+git+https://github.com/p1atdev/dartrs.git@33cdcfe77f236ba286ad60e10db8a5650e150fd2
utils.py CHANGED
@@ -22,6 +22,13 @@ IMAGE_SIZES = {
     "640x1536": (640, 1536),
 }
 
+ASPECT_RATIO_OPTIONS = {
+    "ultra_wide": "<|aspect_ratio:ultra_wide|>",
+    "wide": "<|aspect_ratio:wide|>",
+    "square": "<|aspect_ratio:square|>",
+    "tall": "<|aspect_ratio:tall|>",
+    "ultra_tall": "<|aspect_ratio:ultra_tall|>",
+}
 RATING_OPTIONS = {
     "sfw": "<|rating:sfw|>",
     "general": "<|rating:general|>",
@@ -42,3 +49,11 @@ IDENTITY_OPTIONS = {
     "lax": "<|identity:lax|>",
     "strict": "<|identity:strict|>",
 }
+
+
+PEOPLE_TAGS = [
+    *[f"1{x}" for x in ["girl", "boy", "other"]],
+    *[f"{i}girls" for i in range(2, 6)],
+    *[f"6+{x}s" for x in ["girl", "boy", "other"]],
+    "no humans",
+]
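
For reference, the three comprehensions in PEOPLE_TAGS evaluate to the flat list below; note that the 2-5 counts are generated only in their "girls" form:

>>> PEOPLE_TAGS
['1girl', '1boy', '1other',
 '2girls', '3girls', '4girls', '5girls',
 '6+girls', '6+boys', '6+others',
 'no humans']
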
v2.py CHANGED
@@ -1,7 +1,11 @@
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
+
+from dartrs.v2 import V2Model, MixtralModel
+from dartrs.dartrs import DartTokenizer
+from dartrs.utils import get_generation_config
+
 
 import gradio as gr
 from gradio.components import Component
@@ -16,31 +20,26 @@ except ImportError:
 
 
 from output import UpsamplingOutput
-from utils import IMAGE_SIZE_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
+from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
 
 ALL_MODELS = {
-    "dart-v2-llama-100m-sft": {
-        "repo": "p1atdev/dart-v2-llama-100m-sft",
-        "type": "sft",
-    },
-    "dart-v2-mistral-100m-sft": {
-        "repo": "p1atdev/dart-v2-mistral-100m-sft",
+    "dart-v2-mixtral-160m-sft-6": {
+        "repo": "p1atdev/dart-v2-mixtral-160m-sft-6",
         "type": "sft",
+        "class": MixtralModel,
     },
-    "dart-v2-mixtral-160m-sft": {
-        "repo": "p1atdev/dart-v2-mixtral-160m-sft",
+    "dart-v2-mixtral-160m-sft-8": {
+        "repo": "p1atdev/dart-v2-mixtral-160m-sft-8",
         "type": "sft",
+        "class": MixtralModel,
     },
 }
 
 
-def prepare_models(model_name: str):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-    )
+def prepare_models(model_config: dict):
+    model_name = model_config["repo"]
+    tokenizer = DartTokenizer.from_pretrained(model_name)
+    model = model_config["class"].from_pretrained(model_name)
 
     return {
         "tokenizer": tokenizer,
@@ -48,21 +47,21 @@ def prepare_models(model_name: str):
 }
 
 
-def normalize_tags(tokenizer: PreTrainedTokenizerBase, tags: str):
-    """Just remove unk tokens."""
-    return ", ".join(
-        tokenizer.batch_decode(
-            [
-                token
-                for token in tokenizer.encode_plus(
-                    tags.strip(),
-                    return_tensors="pt",
-                ).input_ids[0]
-                if int(token) != tokenizer.unk_token_id
-            ],
-            skip_special_tokens=True,
-        )
-    )
+# def normalize_tags(tokenizer: PreTrainedTokenizerBase, tags: str):
+#     """Just remove unk tokens."""
+#     return ", ".join(
+#         tokenizer.batch_decode(
+#             [
+#                 token
+#                 for token in tokenizer.encode_plus(
+#                     tags.strip(),
+#                     return_tensors="pt",
+#                 ).input_ids[0]
+#                 if int(token) != tokenizer.unk_token_id
+#             ],
+#             skip_special_tokens=True,
+#         )
+#     )
 
 
 def compose_prompt(
@@ -88,46 +87,28 @@ def compose_prompt(
 @torch.no_grad()
 @spaces.GPU(duration=5)
 def generate_tags(
-    model,
-    tokenizer: PreTrainedTokenizerBase,
+    model: V2Model,
+    tokenizer: DartTokenizer,
     prompt: str,
 ):
-    print(  # debug
-        tokenizer.tokenize(
-            prompt,
-            add_special_tokens=False,
-        )
-    )
-    input_ids = tokenizer.encode_plus(prompt, return_tensors="pt").input_ids
     output = model.generate(
-        input_ids.to(model.device),
-        do_sample=True,
-        temperature=1,
-        top_p=0.9,
-        top_k=100,
-        num_beams=1,
-        num_return_sequences=1,
-        max_length=256,
+        get_generation_config(
+            prompt,
+            tokenizer=tokenizer,
+            temperature=1,
+            top_p=0.9,
+            top_k=100,
+            max_new_tokens=256,
+        ),
     )
 
-    # remove input tokens
-    pure_output_ids = output[0][len(input_ids[0]) :]
-
-    return ", ".join(
-        [
-            token
-            for token in tokenizer.batch_decode(
-                pure_output_ids, skip_special_tokens=True
-            )
-            if token.strip() != ""
-        ]
-    )
+    return output
 
 
 class V2UI:
     model_name: str | None = None
-    model: AutoModelForCausalLM
-    tokenizer: PreTrainedTokenizerBase
+    model: V2Model
+    tokenizer: DartTokenizer
 
     input_components: list[Component] = []
     generate_btn: gr.Button
@@ -139,25 +120,25 @@ class V2UI:
         character_tags: str,
         general_tags: str,
         rating_option: str,
-        # aspect_ratio_option: str,
+        aspect_ratio_option: str,
         length_option: str,
         identity_option: str,
-        image_size: str,  # this is from image generation config
+        # image_size: str,  # this is from image generation config
         *args,
     ) -> UpsamplingOutput:
         if self.model_name is None or self.model_name != model_name:
-            models = prepare_models(ALL_MODELS[model_name]["repo"])
+            models = prepare_models(ALL_MODELS[model_name])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            self.model_name = model_name
 
        # normalize tags
-        copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
-        character_tags = normalize_tags(self.tokenizer, character_tags)
-        general_tags = normalize_tags(self.tokenizer, general_tags)
+        # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
+        # character_tags = normalize_tags(self.tokenizer, character_tags)
+        # general_tags = normalize_tags(self.tokenizer, general_tags)
 
        rating_tag = RATING_OPTIONS[rating_option]
-        aspect_ratio_tag = IMAGE_SIZE_OPTIONS[image_size]
+        aspect_ratio_tag = ASPECT_RATIO_OPTIONS[aspect_ratio_option]
        length_tag = LENGTH_OPTIONS[length_option]
        identity_tag = IDENTITY_OPTIONS[identity_option]
 
@@ -212,11 +193,11 @@ class V2UI:
             choices=list(RATING_OPTIONS.keys()),
             value="general",
         )
-        # input_aspect_ratio = gr.Radio(
-        #     label="Aspect ratio",
-        #     choices=["ultra_wide", "wide", "square", "tall", "ultra_tall"],
-        #     value="tall",
-        # )
+        input_aspect_ratio = gr.Radio(
+            label="Aspect ratio",
+            choices=["ultra_wide", "wide", "square", "tall", "ultra_tall"],
+            value="tall",
+        )
         input_length = gr.Radio(
             label="Length",
             choices=list(LENGTH_OPTIONS.keys()),
@@ -242,7 +223,7 @@ class V2UI:
             input_character,
             input_general,
             input_rating,
-            # input_aspect_ratio,
+            input_aspect_ratio,
             input_length,
             input_identity,
         ]
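
Taken together, v2.py swaps the transformers backend for dartrs: prepare_models() now loads a DartTokenizer plus the model class named in ALL_MODELS, and generate_tags() passes its sampling parameters through get_generation_config(). A condensed sketch assembled from those two functions (the prompt placeholder stands in for the string compose_prompt() builds):

from dartrs.v2 import MixtralModel
from dartrs.dartrs import DartTokenizer
from dartrs.utils import get_generation_config

# Condensed from prepare_models() and generate_tags() in this diff.
repo = "p1atdev/dart-v2-mixtral-160m-sft-8"
tokenizer = DartTokenizer.from_pretrained(repo)
model = MixtralModel.from_pretrained(repo)

prompt = "..."  # in the app this string comes from compose_prompt()

output = model.generate(
    get_generation_config(
        prompt,
        tokenizer=tokenizer,
        temperature=1,
        top_p=0.9,
        top_k=100,
        max_new_tokens=256,
    ),
)
print(output)  # generate_tags() returns this value as-is
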