Kohaku-Blueleaf committed
Commit 7d4afe8 • 1 Parent(s): 9df83e1

first commit

Files changed (5)
  1. app.py +152 -0
  2. diff.py +107 -0
  3. dtg.py +92 -0
  4. meta.py +37 -0
  5. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,152 @@
+ import random
+ from time import time_ns
+
+ import torch
+
+ import spaces
+ import gradio as gr
+ from transformers import set_seed
+
+ from kgen import models
+ from diff import load_model, encode_prompts
+ from dtg import process
+ from meta import (
+     DEFAULT_STYLE_LIST,
+     MODEL_FORMAT_LIST,
+     MODEL_DEFAULT_QUALITY_LIST,
+     DEFAULT_NEGATIVE_PROMPT,
+ )
+
+
+ sdxl_pipe = load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda")
+ models.load_model(models.model_list[0])
+ models.text_model.cuda()
+
+ current_dtg_model = models.model_list[0]
+ current_sdxl_model = "KBlueLeaf/Kohaku-XL-Epsilon"
+
+
+ @spaces.GPU
+ def gen(
+     sdxl_model: str,
+     dtg_model: str,
+     style: str,
+     base_prompt: str,
+     addon_prompt: str = "",
+ ):
+     global current_dtg_model, current_sdxl_model, sdxl_pipe
+     # Reload the cached pipelines only when the dropdown selection changes.
+     if sdxl_model != current_sdxl_model:
+         sdxl_pipe = load_model(model_id=sdxl_model, device="cuda")
+         current_sdxl_model = sdxl_model
+     if dtg_model != current_dtg_model:
+         models.load_model(dtg_model)
+         models.text_model.cuda()
+         current_dtg_model = dtg_model
+
+     t0 = time_ns()
+     seed = random.randint(0, 2**31 - 1)
+
+     prompt = (
+         f"{base_prompt}, {addon_prompt}, "
+         f"{DEFAULT_STYLE_LIST[style]}, "
+         f"{MODEL_DEFAULT_QUALITY_LIST[sdxl_model]}, "
+     )
+     # Let DanTagGen upsample the short tag list into a full prompt.
+     full_prompt = process(
+         prompt,
+         aspect_ratio=1.0,
+         seed=seed,
+         tag_length="short",
+         ban_tags=".*alternate.*, character doll, multiple.*, .*cosplay.*, .*name, .*text.*",
+         format=MODEL_FORMAT_LIST[sdxl_model],
+         temperature=1.2,
+     )
+     torch.cuda.empty_cache()
+
+     prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
+         encode_prompts(sdxl_pipe, full_prompt, DEFAULT_NEGATIVE_PROMPT)
+     )
+     set_seed(seed)
+     with torch.autocast("cuda"):
+         result = sdxl_pipe(
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_embeds2,
+             negative_pooled_prompt_embeds=neg_pooled_embeds2,
+             num_inference_steps=24,
+             width=1024,
+             height=1024,
+             guidance_scale=6.0,
+         ).images[0]
+     torch.cuda.empty_cache()
+     t1 = time_ns()
+
+     return result.convert("RGB"), full_prompt, f"Cost: {(t1 - t0) / 1e9:.2f} sec"
+
+
+ if __name__ == "__main__":
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("""# This Cute Dragon Girl Doesn't Exist""")
+         with gr.Accordion("Introduction and Instructions", open=False):
+             gr.Markdown(
+                 """
+ ### What is this:
+ "This Cute Dragon Girl Doesn't Exist" is a demo for the KGen system (DanTagGen) with SDXL anime models.
+ It aims to show how DanTagGen can be used to "refine/upsample" a simple prompt to help the T2I model.
+
+ Since I already have several applications and demos for DanTagGen,
+ this one is designed to be simpler than before:
+ just one click, and you get a high-quality, high-diversity result.
+
+ ### How to use it:
+ Click the "Next" button until you get a dragon girl you like.
+
+ ### Resources:
+ - My anime model: [Kohaku XL Epsilon](https://huggingface.co/KBlueLeaf/Kohaku-XL-Epsilon)
+ - DanTagGen: [DanTagGen](https://huggingface.co/KBlueLeaf/DanTagGen-beta)
+ - DanTagGen extension: [z-a1111-sd-webui-dtg](https://github.com/KohakuBlueleaf/z-a1111-sd-webui-dtg)
+ """
+             )
+         with gr.Row():
+             with gr.Column(scale=3):
+                 with gr.Row():
+                     sdxl_model = gr.Dropdown(
+                         MODEL_FORMAT_LIST,
+                         label="SDXL Model",
+                         value=list(MODEL_FORMAT_LIST)[0],
+                     )
+                     dtg_model = gr.Dropdown(
+                         models.model_list,
+                         label="DTG Model",
+                         value=models.model_list[0],
+                     )
+                 base_prompt = gr.Textbox(
+                     label="Base prompt",
+                     lines=1,
+                     value="1girl, solo, dragon girl, dragon wings, dragon horns, dragon tail",
+                     interactive=False,
+                 )
+                 with gr.Row():
+                     addon_prompt = gr.Textbox(
+                         label="Addon prompt",
+                         lines=1,
+                         value="cowboy shot, loli",
+                     )
+                     style = gr.Dropdown(
+                         DEFAULT_STYLE_LIST,
+                         label="Style",
+                         value=list(DEFAULT_STYLE_LIST)[0],
+                     )
+                 submit = gr.Button("Next")
+                 dtg_output = gr.TextArea(label="DTG output", lines=9, show_copy_button=True)
+                 cost_time = gr.Markdown()
+             with gr.Column(scale=4):
+                 result = gr.Image(label="Result", type="numpy", interactive=False)
+
+         submit.click(
+             fn=gen,
+             inputs=[sdxl_model, dtg_model, style, base_prompt, addon_prompt],
+             outputs=[result, dtg_output, cost_time],
+         )
+
+     demo.launch()
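
A note on the `ban_tags` argument passed to `process()` above: entries such as "multiple.*" and ".*cosplay.*" read as regular expressions. Below is a minimal sketch of the intended filtering, assuming kgen full-matches each pattern against every candidate tag (the real matching lives inside kgen and may differ in detail):

import re

ban_tags = ".*alternate.*, character doll, multiple.*, .*cosplay.*, .*name, .*text.*"
patterns = [p.strip() for p in ban_tags.split(",")]

# Assumption: a generated tag is dropped when any pattern fully matches it.
for tag in ["alternate costume", "multiple girls", "character name", "dragon girl"]:
    banned = any(re.fullmatch(p, tag) for p in patterns)
    print(f"{tag!r}: {'banned' if banned else 'kept'}")
# 'alternate costume', 'multiple girls', 'character name' -> banned; 'dragon girl' -> kept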
diff.py ADDED
@@ -0,0 +1,107 @@
+ from functools import partial
+
+ import torch
+ from diffusers import StableDiffusionXLKDiffusionPipeline
+ from k_diffusion.sampling import get_sigmas_polyexponential
+ from k_diffusion.sampling import sample_dpmpp_2m_sde
+
+
+ def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
+     self.num_inference_steps = num_inference_steps
+
+     self.sigmas = get_sigmas_polyexponential(
+         num_inference_steps + 1,
+         sigma_min=orig_sigmas[-2],
+         sigma_max=orig_sigmas[0],
+         rho=0.666666,
+         device=device or "cpu",
+     )
+     self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])
+
+
+ def load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda"):
+     pipe: StableDiffusionXLKDiffusionPipeline
+     pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
+         model_id, torch_dtype=torch.float16
+     ).to(device)
+     # Swap in the polyexponential sigma schedule and use DPM++ 2M SDE
+     # (Heun variant) as the k-diffusion sampler.
+     pipe.scheduler.set_timesteps = partial(
+         set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
+     )
+     pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
+     return pipe
+
+
+ def encode_prompts(pipe: StableDiffusionXLKDiffusionPipeline, prompt, neg_prompt):
+     max_length = pipe.tokenizer.model_max_length
+
+     input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
+     input_ids2 = pipe.tokenizer_2(prompt, return_tensors="pt").input_ids.to("cuda")
+
+     negative_ids = pipe.tokenizer(
+         neg_prompt,
+         truncation=False,
+         padding="max_length",
+         max_length=input_ids.shape[-1],
+         return_tensors="pt",
+     ).input_ids.to("cuda")
+     negative_ids2 = pipe.tokenizer_2(
+         neg_prompt,
+         truncation=False,
+         padding="max_length",
+         max_length=input_ids.shape[-1],
+         return_tensors="pt",
+     ).input_ids.to("cuda")
+
+     # If the negative prompt is longer, re-pad the positive prompt so both
+     # sides end up with the same token length and chunk identically.
+     if negative_ids.shape[-1] > input_ids.shape[-1]:
+         input_ids = pipe.tokenizer(
+             prompt,
+             truncation=False,
+             padding="max_length",
+             max_length=negative_ids.shape[-1],
+             return_tensors="pt",
+         ).input_ids.to("cuda")
+         input_ids2 = pipe.tokenizer_2(
+             prompt,
+             truncation=False,
+             padding="max_length",
+             max_length=negative_ids.shape[-1],
+             return_tensors="pt",
+         ).input_ids.to("cuda")
+
+     # Encode in chunks of model_max_length (77) tokens and concatenate along
+     # the sequence axis, so prompts longer than CLIP's window are supported.
+     concat_embeds = []
+     neg_embeds = []
+     for i in range(0, input_ids.shape[-1], max_length):
+         concat_embeds.append(pipe.text_encoder(input_ids[:, i : i + max_length])[0])
+         neg_embeds.append(pipe.text_encoder(negative_ids[:, i : i + max_length])[0])
+
+     concat_embeds2 = []
+     neg_embeds2 = []
+     pooled_embeds2 = []
+     neg_pooled_embeds2 = []
+     for i in range(0, input_ids.shape[-1], max_length):
+         hidden_states = pipe.text_encoder_2(
+             input_ids2[:, i : i + max_length], output_hidden_states=True
+         )
+         # Penultimate hidden state for conditioning, first output for pooling.
+         concat_embeds2.append(hidden_states.hidden_states[-2])
+         pooled_embeds2.append(hidden_states[0])
+
+         hidden_states = pipe.text_encoder_2(
+             negative_ids2[:, i : i + max_length], output_hidden_states=True
+         )
+         neg_embeds2.append(hidden_states.hidden_states[-2])
+         neg_pooled_embeds2.append(hidden_states[0])
+
+     prompt_embeds = torch.cat(concat_embeds, dim=1)
+     negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
+     prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
+     negative_prompt_embeds2 = torch.cat(neg_embeds2, dim=1)
+     # SDXL conditions on both text encoders: join them along the feature axis.
+     prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
+     negative_prompt_embeds = torch.cat(
+         [negative_prompt_embeds, negative_prompt_embeds2], dim=-1
+     )
+
+     pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
+     neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)
+
+     return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2
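
A note on the scheduler swap in load_model() above: set_timesteps_polyexponential replaces the scheduler's default sigma schedule with a polyexponential one (rho = 0.666666). The sketch below reproduces the formula of k_diffusion's get_sigmas_polyexponential (minus the trailing zero it appends) so the schedule's shape can be inspected standalone; the sigma_min/sigma_max values are illustrative SDXL-like numbers, not taken from the pipeline:

import math
import torch

def polyexponential_sigmas(n, sigma_min, sigma_max, rho=0.666666):
    # Interpolate log-sigma along a linspace ramp raised to the power rho,
    # as in k_diffusion.sampling.get_sigmas_polyexponential.
    ramp = torch.linspace(1, 0, n) ** rho
    return torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))

# rho < 1 keeps sigma near sigma_max longer, so more of the 24 inference
# steps are spent on the high-noise part of the trajectory.
print(polyexponential_sigmas(8, 0.03, 14.6))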
dtg.py ADDED
@@ -0,0 +1,92 @@
+ import time
+ import pathlib
+
+ import kgen.models as models
+ from kgen.formatter import seperate_tags, apply_format, apply_dtg_prompt
+ from kgen.metainfo import TARGET
+ from kgen.generate import tag_gen
+ from kgen.logging import logger
+
+
+ SEED_MAX = 2**31 - 1
+ DEFAULT_FORMAT = """<|special|>,
+ <|characters|>, <|copyrights|>,
+ <|artist|>,
+
+ <|general|>,
+
+ <|quality|>, <|meta|>, <|rating|>"""
+
+
+ def process(
+     prompt: str,
+     aspect_ratio: float,
+     seed: int,
+     tag_length: str,
+     ban_tags: str,
+     format: str,
+     temperature: float,
+ ):
+     prompt_preview = prompt.replace("\n", " ")[:40]
+     logger.info(f"Processing prompt: {prompt_preview}...")
+     logger.info(f"Processing with seed: {seed}")
+     black_list = [tag.strip() for tag in ban_tags.split(",") if tag.strip()]
+     all_tags = [tag.strip() for tag in prompt.strip().split(",") if tag.strip()]
+
+     tag_length = tag_length.replace(" ", "_")
+     len_target = TARGET[tag_length]
+
+     tag_map = seperate_tags(all_tags)
+     dtg_prompt = apply_dtg_prompt(tag_map, tag_length, aspect_ratio)
+     # Exhaust the tag_gen generator; the last iteration's values are kept.
+     for _, extra_tokens, iter_count in tag_gen(
+         models.text_model,
+         models.tokenizer,
+         dtg_prompt,
+         tag_map["special"] + tag_map["general"],
+         len_target,
+         black_list,
+         temperature=temperature,
+         top_p=0.8,
+         top_k=80,
+         max_new_tokens=512,
+         max_retry=10,
+         max_same_output=5,
+         seed=seed % SEED_MAX,
+     ):
+         pass
+     tag_map["general"] += extra_tokens
+     prompt_by_dtg = apply_format(tag_map, format)
+     logger.info(
+         "Prompt processing done. General Tags Count: "
+         f"{len(tag_map['general'] + tag_map['special'])}"
+         f" | Total iterations: {iter_count}"
+     )
+     return prompt_by_dtg
+
+
+ if __name__ == "__main__":
+     models.model_dir = pathlib.Path(__file__).parent / "models"
+
+     # Download the GGUF weights, then load the most recent local file.
+     models.download_gguf()
+     files = models.list_gguf()
+     file = files[-1]
+     logger.info(f"Use gguf model from local file: {file}")
+     models.load_model(file, gguf=True)
+
+     prompt = """
+ 1girl, ask (askzy), masterpiece
+ """
+
+     t0 = time.time_ns()
+     result = process(
+         prompt,
+         aspect_ratio=1.0,
+         seed=1,
+         tag_length="long",
+         ban_tags="",
+         format=DEFAULT_FORMAT,
+         temperature=1.35,
+     )
+     t1 = time.time_ns()
+     logger.info(f"Result:\n{result}")
+     logger.info(f"Time cost: {(t1 - t0) / 10**6:.1f}ms")
meta.py ADDED
@@ -0,0 +1,37 @@
+ DEFAULT_STYLE_LIST = {
+     "style 1": "ask (askzy), torino aqua, migolu",
+     "style 2": "azuuru, torino aqua, kedama milk, fuzichoco, ask (askzy), chen bin, atdan, hito, mignon",
+     "style 3": "nou (nounknown), shikimi (yurakuru), namiki itsuki, lemon89h, satsuki (miicat), chon (chon33v), omutatsu, mochizuki kei",
+     "style 4": "ciloranko, maccha (mochancc), lobelia (saclia), migolu, ask (askzy), wanke, jiu ye sang, rumoon, mizumi zumi",
+     "style 5": "reoen, alchemaniac, rella, watercolor (medium)",
+     "no style": "",
+ }
+
+ MODEL_DEFAULT_QUALITY_LIST = {
+     "KBlueLeaf/Kohaku-XL-Epsilon": "masterpiece, newest, absurdres, safe",
+     "cagliostrolab/animagine-xl-3.1": "masterpiece, newest, very aesthetic, absurdres, safe",
+ }
+
+ MODEL_FORMAT_LIST = {
+     "KBlueLeaf/Kohaku-XL-Epsilon": """<|special|>,
+ <|characters|>, <|copyrights|>,
+ <|artist|>,
+
+ <|general|>,
+
+ <|quality|>, <|meta|>, <|rating|>""",
+     "cagliostrolab/animagine-xl-3.1": """<|special|>,
+ <|characters|>, <|copyrights|>,
+ <|artist|>,
+
+ <|general|>,
+
+ <|quality|>, <|meta|>, <|rating|>""",
+ }
+
+
+ DEFAULT_NEGATIVE_PROMPT = """
+ low quality, worst quality, normal quality, text, signature, jpeg artifacts,
+ bad anatomy, old, early, mini skirt, nsfw, chibi, multiple girls, multiple boys,
+ multiple tails, multiple views, copyright name, watermark, artist name, signature
+ """
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ diffusers
+ transformers
+ k_diffusion
+ requests
+ sentencepiece
+ tipo-kgen
+ spaces