feat: dartrs backend

- app.py +64 -75
- diffusion.py +4 -1
- lpw_pipeline_xl.py +0 -0
- requirements.txt +2 -1
- utils.py +15 -0
- v2.py +57 -76
app.py
CHANGED
@@ -6,17 +6,28 @@ import gradio as gr
 from v2 import V2UI
 from diffusion import ImageGenerator
 from output import UpsamplingOutput
-from utils import QUALITY_TAGS, NEGATIVE_PROMPT, IMAGE_SIZE_OPTIONS, IMAGE_SIZES
+from utils import QUALITY_TAGS, NEGATIVE_PROMPT, IMAGE_SIZE_OPTIONS, PEOPLE_TAGS
 
 
 def animagine_xl_v3_1(output: UpsamplingOutput):
+    # separate people tags (e.g. 1girl)
+    people_tags = []
+    other_general_tags = []
+    for tag in output.general_tags.split(","):
+        tag = tag.strip()
+        if tag in PEOPLE_TAGS:
+            people_tags.append(tag)
+        else:
+            other_general_tags.append(tag)
+
     return ", ".join(
         [
             part.strip()
             for part in [
+                *people_tags,
                 output.character_tags,
                 output.copyright_tags,
-                output.general_tags,
+                *other_general_tags,
                 output.upsampled_tags,
                 (
                     output.rating_tag
@@ -35,59 +46,29 @@ def elapsed_time_format(elapsed_time: float) -> str:
 
 def parse_upsampling_output(
     upsampler: Callable[..., UpsamplingOutput],
-    image_generator: Callable[..., Image.Image],
 ):
-    def _parse_upsampling_output(
-        generate_image: bool,
-        *args,
-    ):
+    def _parse_upsampling_output(*args) -> tuple[
+        str,
+        str,
+        dict,
+    ]:
         output = upsampler(*args)
 
         print(output)
 
-        if not generate_image:
-            return (
-                animagine_xl_v3_1(output),
-                elapsed_time_format(output.elapsed_time),
-                None,
-            )
-
-        # generate image
-        [
-            image_size_option,
-            quality_tags,
-            negative_prompt,
-            num_inference_steps,
-            guidance_scale,
-        ] = args[
-            7:
-        ]  # remove the first 7 arguments for upsampler
-        width, height = IMAGE_SIZES[image_size_option]
-        image = image_generator(
-            ", ".join([animagine_xl_v3_1(output), quality_tags]),
-            negative_prompt,
-            height,
-            width,
-            num_inference_steps,
-            guidance_scale,
-        )
-
         return (
             animagine_xl_v3_1(output),
             elapsed_time_format(output.elapsed_time),
-            [image],
+            gr.update(
+                interactive=True,
+            ),
         )
 
     return _parse_upsampling_output
 
 
-def toggle_visible_output_image(generate_image: bool):
-    return gr.update(
-        visible=generate_image,
-    )
-
-
 def image_generation_config_ui():
-    with gr.Accordion(label="Image generation config", open=True) as accordion:
+    with gr.Accordion(label="Image generation config", open=False) as accordion:
         image_size = gr.Radio(
             label="Image size",
             choices=list(IMAGE_SIZE_OPTIONS.keys()),
@@ -142,7 +123,7 @@ def main():
     v2 = V2UI()
 
     print("Loading diffusion model...")
-    image_generator = ImageGenerator()
+    # image_generator = ImageGenerator()
     print("Loaded.")
 
     with gr.Blocks() as ui:
@@ -152,25 +133,25 @@ def main():
         with gr.Column():
             v2.ui()
 
-            generate_image_check = gr.Checkbox(
-                …
+        with gr.Column():
+            output_text = gr.TextArea(label="Output tags", interactive=False)
+
+            elapsed_time_md = gr.Markdown(label="Elapsed time", value="")
+
+            generate_image_btn = gr.Button(
+                value="Generate image with this prompt!",
             )
 
             accordion, image_generation_config_components = (
                 image_generation_config_ui()
            )
 
-        with gr.Column():
-            output_text = gr.TextArea(label="Output tags", interactive=False)
-
-            elapsed_time_md = gr.Markdown(label="Elapsed time", value="")
-
             output_image = gr.Gallery(
                 label="Output image",
                 columns=1,
                 preview=True,
                 show_label=False,
-                visible=…,
+                visible=False,
             )
 
             gr.Examples(
@@ -179,78 +160,86 @@ def main():
                         "original",
                         "",
                         "1girl, solo, blue theme, limited palette",
-                        "…",
+                        "sfw",
+                        "ultra_wide",
                         "long",
-                        "…",
                    ],
                    [
                        "",
                        "",
                        "4girls",
-                        "…",
+                        "sfw",
+                        "tall",
                         "very_long",
-                        "…",
+                        "lax",
+                    ],
+                    [
+                        "original",
+                        "",
+                        "1girl, solo, upper body, looking at viewer, profile picture",
+                        "sfw",
+                        "square",
+                        "medium",
+                        "none",
                     ],
                     [
                         "",
                         "",
                         "no humans, scenery, spring (season)",
-                        "…",
+                        "general",
+                        "ultra_wide",
                         "medium",
-                        "…",
+                        "lax",
                     ],
                     [
                         "sousou no frieren",
                         "frieren",
                         "1girl, solo",
-                        "…",
+                        "general",
+                        "tall",
                         "long",
-                        "…",
+                        "lax",
                     ],
                     [
                         "honkai: star rail",
                         "silver wolf (honkai: star rail)",
                         "1girl, solo, annoyed",
-                        "…",
+                        "sfw",
+                        "tall",
                         "long",
-                        "…",
+                        "lax",
                     ],
                     [
                         "bocchi the rock!",
                         "gotoh hitori, kita ikuyo, ijichi nijika, yamada ryo",
                         "4girls, multiple girls",
-                        "…",
+                        "sfw",
+                        "ultra_wide",
                         "very_long",
-                        "…",
+                        "lax",
                     ],
                     [
                         "chuunibyou demo koi ga shitai!",
                         "takanashi rikka",
                         "1girl, solo",
-                        "…",
+                        "sfw",
+                        "ultra_tall",
                         "long",
-                        "…",
+                        "lax",
                     ],
                 ],
                 inputs=[
-                    *v2.get_inputs()[1:…],
-                    image_generation_config_components[0],  # image size
+                    *v2.get_inputs()[1:8],
                 ],
             )
 
         v2.get_generate_btn().click(
-            parse_upsampling_output(v2.on_generate, …),
+            parse_upsampling_output(v2.on_generate),
            inputs=[
-                generate_image_check,
                 *v2.get_inputs(),
-                *image_generation_config_components,
             ],
-            outputs=[output_text, elapsed_time_md, output_image],
-        )
-        generate_image_check.change(
-            toggle_visible_output_image,
-            inputs=[generate_image_check],
-            outputs=[output_image],
+            outputs=[output_text, elapsed_time_md, generate_image_btn],
         )
 
     ui.launch()
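The people-tag split above is easiest to see on a concrete input. A minimal sketch of the same reordering in isolation (the Output dataclass is a hypothetical stand-in for output.UpsamplingOutput, and PEOPLE_TAGS is abbreviated here; the full list lives in utils.py):

from dataclasses import dataclass

PEOPLE_TAGS = ["1girl", "4girls", "no humans"]  # abbreviated for the example

@dataclass
class Output:  # hypothetical stand-in for output.UpsamplingOutput
    general_tags: str = "1girl, blue theme, limited palette"
    character_tags: str = "frieren"
    copyright_tags: str = "sousou no frieren"
    upsampled_tags: str = "pointy ears, white hair"

output = Output()
people_tags, other_general_tags = [], []
for tag in output.general_tags.split(","):
    tag = tag.strip()
    (people_tags if tag in PEOPLE_TAGS else other_general_tags).append(tag)

print(
    ", ".join(
        [*people_tags, output.character_tags, output.copyright_tags,
         *other_general_tags, output.upsampled_tags]
    )
)
# 1girl, frieren, sousou no frieren, blue theme, limited palette, pointy ears, white hair

Pulling the person-count tags to the front appears to follow Animagine XL 3.1's recommended tag order (people first, then character, then series).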
diffusion.py
CHANGED
@@ -22,6 +22,9 @@ except ImportError:
 from utils import NEGATIVE_PROMPT
 
 
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
 class ImageGenerator:
     pipe: StableDiffusionXLPipeline
 
@@ -41,7 +44,7 @@ class ImageGenerator:
         # sdpa
         self.pipe.unet.set_attn_processor(AttnProcessor2_0())
 
-        self.pipe.to("cuda")
+        self.pipe.to(device)
 
         try:
             self.pipe = torch.compile(self.pipe)
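The new module-level device makes pipeline placement explicit instead of assuming CUDA is present; the same fallback pattern in isolation:

import torch

# Prefer the first CUDA device, fall back to CPU when no GPU is visible.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch.zeros(2, 2).to(device)
print(x.device)  # cuda:0 with a GPU, cpu otherwise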
lpw_pipeline_xl.py
ADDED
The diff for this file is too large to render. See raw diff.
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ accelerate==0.29.2
 transformers==4.38.2
 optimum[onnxruntime]==1.19.1
 diffusers==0.27.2
-spaces==0.26.2
+spaces==0.26.2
+git+https://github.com/p1atdev/dartrs.git@33cdcfe77f236ba286ad60e10db8a5650e150fd2
utils.py
CHANGED
@@ -22,6 +22,13 @@ IMAGE_SIZES = {
     "640x1536": (640, 1536),
 }
 
+ASPECT_RATIO_OPTIONS = {
+    "ultra_wide": "<|aspect_ratio:ultra_wide|>",
+    "wide": "<|aspect_ratio:wide|>",
+    "square": "<|aspect_ratio:square|>",
+    "tall": "<|aspect_ratio:tall|>",
+    "ultra_tall": "<|aspect_ratio:ultra_tall|>",
+}
 RATING_OPTIONS = {
     "sfw": "<|rating:sfw|>",
     "general": "<|rating:general|>",
@@ -42,3 +49,11 @@ IDENTITY_OPTIONS = {
     "lax": "<|identity:lax|>",
     "strict": "<|identity:strict|>",
 }
+
+
+PEOPLE_TAGS = [
+    *[f"1{x}" for x in ["girl", "boy", "other"]],
+    *[f"{i}girls" for i in range(2, 6)],
+    *[f"6+{x}s" for x in ["girl", "boy", "other"]],
+    "no humans",
+]
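The PEOPLE_TAGS comprehensions expand to a fixed set; printing the list makes the matched tags explicit (note the source only enumerates counts 2 through 5 for girls):

PEOPLE_TAGS = [
    *[f"1{x}" for x in ["girl", "boy", "other"]],
    *[f"{i}girls" for i in range(2, 6)],
    *[f"6+{x}s" for x in ["girl", "boy", "other"]],
    "no humans",
]
print(PEOPLE_TAGS)
# ['1girl', '1boy', '1other', '2girls', '3girls', '4girls', '5girls',
#  '6+girls', '6+boys', '6+others', 'no humans']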
v2.py
CHANGED
@@ -1,7 +1,11 @@
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
+
+from dartrs.v2 import V2Model, MixtralModel
+from dartrs.dartrs import DartTokenizer
+from dartrs.utils import get_generation_config
+
 
 import gradio as gr
 from gradio.components import Component
@@ -16,31 +20,26 @@ except ImportError:
 
 
 from output import UpsamplingOutput
-from utils import RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
+from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
 
 ALL_MODELS = {
-    "dart-v2-…": {
-        "repo": "p1atdev/dart-v2-…",
-        "type": "sft",
-    },
-    "dart-v2-mistral-100m-sft": {
-        "repo": "p1atdev/dart-v2-mistral-100m-sft",
+    "dart-v2-mixtral-160m-sft-6": {
+        "repo": "p1atdev/dart-v2-mixtral-160m-sft-6",
         "type": "sft",
+        "class": MixtralModel,
     },
-    "dart-v2-mixtral-160m-sft": {
-        "repo": "p1atdev/dart-v2-mixtral-160m-sft",
+    "dart-v2-mixtral-160m-sft-8": {
+        "repo": "p1atdev/dart-v2-mixtral-160m-sft-8",
         "type": "sft",
+        "class": MixtralModel,
     },
 }
 
 
-def prepare_models(model_name: str):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-    )
+def prepare_models(model_config: dict):
+    model_name = model_config["repo"]
+    tokenizer = DartTokenizer.from_pretrained(model_name)
+    model = model_config["class"].from_pretrained(model_name)
 
     return {
         "tokenizer": tokenizer,
@@ -48,21 +47,21 @@ def prepare_models(model_name: str):
     }
 
 
-def normalize_tags(tokenizer: PreTrainedTokenizerBase, tags: str):
-    """Just remove unk tokens."""
-    return ", ".join(
-        tokenizer.batch_decode(
-            [
-                token
-                for token in tokenizer.encode_plus(
-                    tags.strip(),
-                    return_tensors="pt",
-                ).input_ids[0]
-                if int(token) != tokenizer.unk_token_id
-            ],
-            skip_special_tokens=True,
-        )
-    )
+# def normalize_tags(tokenizer: PreTrainedTokenizerBase, tags: str):
+#     """Just remove unk tokens."""
+#     return ", ".join(
+#         tokenizer.batch_decode(
+#             [
+#                 token
+#                 for token in tokenizer.encode_plus(
+#                     tags.strip(),
+#                     return_tensors="pt",
+#                 ).input_ids[0]
+#                 if int(token) != tokenizer.unk_token_id
+#             ],
+#             skip_special_tokens=True,
+#         )
+#     )
 
 
 def compose_prompt(
@@ -88,46 +87,28 @@ def compose_prompt(
 @torch.no_grad()
 @spaces.GPU(duration=5)
 def generate_tags(
-    model,
-    tokenizer: PreTrainedTokenizerBase,
+    model: V2Model,
+    tokenizer: DartTokenizer,
     prompt: str,
 ):
-    print(  # debug
-        tokenizer.tokenize(
-            prompt,
-            add_special_tokens=False,
-        )
-    )
-    input_ids = tokenizer.encode_plus(prompt, return_tensors="pt").input_ids
     output = model.generate(
-        …
+        get_generation_config(
+            prompt,
+            tokenizer=tokenizer,
+            temperature=1,
+            top_p=0.9,
+            top_k=100,
+            max_new_tokens=256,
+        ),
     )
 
-
-    pure_output_ids = output[0][len(input_ids[0]) :]
-
-    return ", ".join(
-        [
-            token
-            for token in tokenizer.batch_decode(
-                pure_output_ids, skip_special_tokens=True
-            )
-            if token.strip() != ""
-        ]
-    )
+    return output
 
 
 class V2UI:
     model_name: str | None = None
-    model: …
-    tokenizer: …
+    model: V2Model
+    tokenizer: DartTokenizer
 
     input_components: list[Component] = []
     generate_btn: gr.Button
@@ -139,25 +120,25 @@ class V2UI:
         character_tags: str,
         general_tags: str,
         rating_option: str,
-        …
+        aspect_ratio_option: str,
         length_option: str,
         identity_option: str,
-        image_size: str,  # this is from image generation config
+        # image_size: str,  # this is from image generation config
         *args,
     ) -> UpsamplingOutput:
         if self.model_name is None or self.model_name != model_name:
-            models = prepare_models(ALL_MODELS[model_name]["repo"])
+            models = prepare_models(ALL_MODELS[model_name])
             self.model = models["model"]
             self.tokenizer = models["tokenizer"]
             self.model_name = model_name
 
         # normalize tags
-        copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
-        character_tags = normalize_tags(self.tokenizer, character_tags)
-        general_tags = normalize_tags(self.tokenizer, general_tags)
+        # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
+        # character_tags = normalize_tags(self.tokenizer, character_tags)
+        # general_tags = normalize_tags(self.tokenizer, general_tags)
 
         rating_tag = RATING_OPTIONS[rating_option]
-        aspect_ratio_tag = …
+        aspect_ratio_tag = ASPECT_RATIO_OPTIONS[aspect_ratio_option]
         length_tag = LENGTH_OPTIONS[length_option]
         identity_tag = IDENTITY_OPTIONS[identity_option]
 
@@ -212,11 +193,11 @@ class V2UI:
             choices=list(RATING_OPTIONS.keys()),
             value="general",
         )
-        …
+        input_aspect_ratio = gr.Radio(
+            label="Aspect ratio",
+            choices=["ultra_wide", "wide", "square", "tall", "ultra_tall"],
+            value="tall",
+        )
         input_length = gr.Radio(
             label="Length",
             choices=list(LENGTH_OPTIONS.keys()),
@@ -242,7 +223,7 @@ class V2UI:
             input_character,
             input_general,
             input_rating,
-            …
+            input_aspect_ratio,
             input_length,
             input_identity,
         ]
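End to end, the dartrs backend reduces tag upsampling to the calls below. This is a sketch assembled from the calls in this diff; the repo id is one of the two registered in ALL_MODELS, and the prompt string is a placeholder, since the real control-token template is built by compose_prompt():

from dartrs.dartrs import DartTokenizer
from dartrs.utils import get_generation_config
from dartrs.v2 import MixtralModel

repo = "p1atdev/dart-v2-mixtral-160m-sft-8"
tokenizer = DartTokenizer.from_pretrained(repo)
model = MixtralModel.from_pretrained(repo)

prompt = "1girl, solo"  # placeholder; the app builds this with compose_prompt()

output = model.generate(
    get_generation_config(
        prompt,
        tokenizer=tokenizer,
        temperature=1,
        top_p=0.9,
        top_k=100,
        max_new_tokens=256,
    ),
)
print(output)  # the upsampled tag string, returned as-is by generate_tags()

Decoding now happens inside dartrs, which is presumably why the manual batch_decode post-processing and the normalize_tags unk-stripping from the transformers path are removed or commented out.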