Spaces:
Running
on
Zero
Running
on
Zero
Upload 11 files
Browse files- app.py +4 -3
- genimage.py +44 -11
- llmdolphin.py +218 -196
- requirements.txt +4 -5
- tagger/character_series_dict.csv +0 -0
- tagger/danbooru_e621.csv +0 -0
- tagger/tag_group.csv +0 -0
- tagger/tagger.py +556 -0
- tagger/utils.py +50 -0
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
-
from utils import gradio_copy_text, COPY_ACTION_JS
|
4 |
-
from tagger import convert_danbooru_to_e621_prompt, insert_recom_prompt
|
5 |
from genimage import generate_image
|
6 |
from llmdolphin import (get_llm_formats, get_dolphin_model_format,
|
7 |
get_dolphin_models, get_dolphin_model_info, select_dolphin_model,
|
@@ -59,7 +59,8 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
|
|
59 |
recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
|
60 |
generate_image_btn = gr.Button(value="GENERATE IMAGE", size="lg", variant="primary")
|
61 |
with gr.Row():
|
62 |
-
result_image = gr.Gallery(label="Generated images", columns=1, object_fit="contain", container=True, preview=True,
|
|
|
63 |
with gr.Tab("GGUF-Playground"):
|
64 |
gr.Markdown("""# Chat with lots of Models and LLMs using llama.cpp
|
65 |
This tab is copy of [CaioXapelaum/GGUF-Playground](https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground).<br>
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
+
from tagger.utils import gradio_copy_text, COPY_ACTION_JS
|
4 |
+
from tagger.tagger import convert_danbooru_to_e621_prompt, insert_recom_prompt
|
5 |
from genimage import generate_image
|
6 |
from llmdolphin import (get_llm_formats, get_dolphin_model_format,
|
7 |
get_dolphin_models, get_dolphin_model_info, select_dolphin_model,
|
|
|
59 |
recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
|
60 |
generate_image_btn = gr.Button(value="GENERATE IMAGE", size="lg", variant="primary")
|
61 |
with gr.Row():
|
62 |
+
result_image = gr.Gallery(label="Generated images", columns=1, object_fit="contain", container=True, preview=True, height=512,
|
63 |
+
show_label=False, show_share_button=False, show_download_button=True, interactive=False, visible=True, format="png")
|
64 |
with gr.Tab("GGUF-Playground"):
|
65 |
gr.Markdown("""# Chat with lots of Models and LLMs using llama.cpp
|
66 |
This tab is copy of [CaioXapelaum/GGUF-Playground](https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground).<br>
|
genimage.py
CHANGED
@@ -1,20 +1,49 @@
|
|
1 |
import spaces
|
|
|
2 |
|
3 |
|
4 |
def load_pipeline():
|
5 |
-
from diffusers import
|
6 |
-
import torch
|
7 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
8 |
-
pipe =
|
9 |
"John6666/rae-diffusion-xl-v2-sdxl-spo-pcm",
|
10 |
-
|
11 |
-
custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
|
12 |
torch_dtype=torch.float16,
|
13 |
)
|
14 |
pipe.to(device)
|
15 |
return pipe
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def save_image(image, metadata, output_dir):
|
19 |
import os
|
20 |
import uuid
|
@@ -33,26 +62,30 @@ def save_image(image, metadata, output_dir):
|
|
33 |
pipe = load_pipeline()
|
34 |
|
35 |
|
|
|
36 |
@spaces.GPU
|
37 |
def generate_image(prompt, neg_prompt):
|
|
|
|
|
38 |
metadata = {
|
39 |
-
"prompt": prompt
|
40 |
-
"negative_prompt": neg_prompt
|
41 |
"resolution": f"{1024} x {1024}",
|
42 |
"guidance_scale": 7.0,
|
43 |
"num_inference_steps": 28,
|
44 |
"sampler": "Euler",
|
45 |
}
|
46 |
try:
|
|
|
47 |
images = pipe(
|
48 |
-
prompt=prompt
|
49 |
-
negative_prompt=neg_prompt
|
50 |
width=1024,
|
51 |
height=1024,
|
52 |
-
guidance_scale=7.0
|
53 |
num_inference_steps=28,
|
54 |
output_type="pil",
|
55 |
-
|
56 |
).images
|
57 |
if images:
|
58 |
image_paths = [
|
|
|
1 |
import spaces
|
2 |
+
import torch
|
3 |
|
4 |
|
5 |
def load_pipeline():
|
6 |
+
from diffusers import DiffusionPipeline
|
|
|
7 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
8 |
+
pipe = DiffusionPipeline.from_pretrained(
|
9 |
"John6666/rae-diffusion-xl-v2-sdxl-spo-pcm",
|
10 |
+
custom_pipeline="lpw_stable_diffusion_xl",
|
11 |
+
#custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
|
12 |
torch_dtype=torch.float16,
|
13 |
)
|
14 |
pipe.to(device)
|
15 |
return pipe
|
16 |
|
17 |
|
18 |
+
def token_auto_concat_embeds(pipe, positive, negative):
|
19 |
+
max_length = pipe.tokenizer.model_max_length
|
20 |
+
positive_length = pipe.tokenizer(positive, return_tensors="pt").input_ids.shape[-1]
|
21 |
+
negative_length = pipe.tokenizer(negative, return_tensors="pt").input_ids.shape[-1]
|
22 |
+
|
23 |
+
print(f'Token length is model maximum: {max_length}, positive length: {positive_length}, negative length: {negative_length}.')
|
24 |
+
if max_length < positive_length or max_length < negative_length:
|
25 |
+
print('Concatenated embedding.')
|
26 |
+
if positive_length > negative_length:
|
27 |
+
positive_ids = pipe.tokenizer(positive, return_tensors="pt").input_ids.to("cuda")
|
28 |
+
negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=positive_ids.shape[-1], return_tensors="pt").input_ids.to("cuda")
|
29 |
+
else:
|
30 |
+
negative_ids = pipe.tokenizer(negative, return_tensors="pt").input_ids.to("cuda")
|
31 |
+
positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=negative_ids.shape[-1], return_tensors="pt").input_ids.to("cuda")
|
32 |
+
else:
|
33 |
+
positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to("cuda")
|
34 |
+
negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to("cuda")
|
35 |
+
|
36 |
+
positive_concat_embeds = []
|
37 |
+
negative_concat_embeds = []
|
38 |
+
for i in range(0, positive_ids.shape[-1], max_length):
|
39 |
+
positive_concat_embeds.append(pipe.text_encoder(positive_ids[:, i: i + max_length])[0])
|
40 |
+
negative_concat_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])
|
41 |
+
|
42 |
+
positive_prompt_embeds = torch.cat(positive_concat_embeds, dim=1)
|
43 |
+
negative_prompt_embeds = torch.cat(negative_concat_embeds, dim=1)
|
44 |
+
return positive_prompt_embeds, negative_prompt_embeds
|
45 |
+
|
46 |
+
|
47 |
def save_image(image, metadata, output_dir):
|
48 |
import os
|
49 |
import uuid
|
|
|
62 |
pipe = load_pipeline()
|
63 |
|
64 |
|
65 |
+
@torch.inference_mode()
|
66 |
@spaces.GPU
|
67 |
def generate_image(prompt, neg_prompt):
|
68 |
+
prompt += ", anime, masterpiece, best quality, very aesthetic, absurdres"
|
69 |
+
neg_prompt += ", bad hands, bad feet, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast, photo, deformed, disfigured, low contrast"
|
70 |
metadata = {
|
71 |
+
"prompt": prompt,
|
72 |
+
"negative_prompt": neg_prompt,
|
73 |
"resolution": f"{1024} x {1024}",
|
74 |
"guidance_scale": 7.0,
|
75 |
"num_inference_steps": 28,
|
76 |
"sampler": "Euler",
|
77 |
}
|
78 |
try:
|
79 |
+
#positive_embeds, negative_embeds = token_auto_concat_embeds(pipe, prompt, neg_prompt)
|
80 |
images = pipe(
|
81 |
+
prompt=prompt,
|
82 |
+
negative_prompt=neg_prompt,
|
83 |
width=1024,
|
84 |
height=1024,
|
85 |
+
guidance_scale=7.0,# seg_scale=3.0, seg_applied_layers=["mid"],
|
86 |
num_inference_steps=28,
|
87 |
output_type="pil",
|
88 |
+
clip_skip=1,
|
89 |
).images
|
90 |
if images:
|
91 |
image_paths = [
|
llmdolphin.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
|
|
|
|
|
|
|
|
3 |
from llama_cpp import Llama
|
4 |
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
|
5 |
from llama_cpp_agent.providers import LlamaCppPythonProvider
|
@@ -7,7 +11,6 @@ from llama_cpp_agent.chat_history import BasicChatHistory
|
|
7 |
from llama_cpp_agent.chat_history.messages import Roles
|
8 |
from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
|
9 |
import wrapt_timeout_decorator
|
10 |
-
from pathlib import Path
|
11 |
from llama_cpp_agent.messages_formatter import MessagesFormatter
|
12 |
from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
|
13 |
|
@@ -19,6 +22,7 @@ llm_models = {
|
|
19 |
#"": ["", MessagesFormatterType.OPEN_CHAT],
|
20 |
#"": ["", MessagesFormatterType.CHATML],
|
21 |
#"": ["", MessagesFormatterType.PHI_3],
|
|
|
22 |
"mn-12b-lyra-v2a1-q5_k_m.gguf": ["HalleyStarbun/MN-12B-Lyra-v2a1-Q5_K_M-GGUF", MessagesFormatterType.CHATML],
|
23 |
"L3-8B-Tamamo-v1.i1-Q5_K_M.gguf": ["mradermacher/L3-8B-Tamamo-v1-i1-GGUF", MessagesFormatterType.LLAMA_3],
|
24 |
"MN-Chinofun-12B-2.i1-Q4_K_M.gguf": ["mradermacher/MN-Chinofun-12B-2-i1-GGUF", MessagesFormatterType.MISTRAL],
|
@@ -68,6 +72,19 @@ llm_models = {
|
|
68 |
"ChatWaifu_22B_v2.0_preview.Q4_K_S.gguf": ["mradermacher/ChatWaifu_22B_v2.0_preview-GGUF", MessagesFormatterType.MISTRAL],
|
69 |
"ChatWaifu_v1.4.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.4-GGUF", MessagesFormatterType.MISTRAL],
|
70 |
"ChatWaifu_v1.3.1.Q4_K_M.gguf": ["mradermacher/ChatWaifu_v1.3.1-GGUF", MessagesFormatterType.MISTRAL],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
"hermes-llama3-roleplay-1000-v2.Q5_K_M.gguf": ["mradermacher/hermes-llama3-roleplay-1000-v2-GGUF", MessagesFormatterType.LLAMA_3],
|
72 |
"hermes-stheno-8B-v0.1.i1-Q5_K_M.gguf": ["mradermacher/hermes-stheno-8B-v0.1-i1-GGUF", MessagesFormatterType.LLAMA_3],
|
73 |
"qwen-carpmuscle-r-v0.3.Q4_K_M.gguf": ["mradermacher/qwen-carpmuscle-r-v0.3-GGUF", MessagesFormatterType.OPEN_CHAT],
|
@@ -832,6 +849,7 @@ llm_languages = ["English", "Japanese", "Chinese", "Korean", "Spanish", "Portugu
|
|
832 |
llm_models_tupled_list = []
|
833 |
default_llm_model_filename = list(llm_models.keys())[0]
|
834 |
override_llm_format = None
|
|
|
835 |
|
836 |
|
837 |
def to_list(s):
|
@@ -844,7 +862,6 @@ def list_uniq(l):
|
|
844 |
|
845 |
@wrapt_timeout_decorator.timeout(dec_timeout=3.5)
|
846 |
def to_list_ja(s):
|
847 |
-
import re
|
848 |
s = re.sub(r'[、。]', ',', s)
|
849 |
return [x.strip() for x in s.split(",") if not s == ""]
|
850 |
|
@@ -859,7 +876,6 @@ def is_japanese(s):
|
|
859 |
|
860 |
|
861 |
def update_llm_model_tupled_list():
|
862 |
-
from pathlib import Path
|
863 |
global llm_models_tupled_list
|
864 |
llm_models_tupled_list = []
|
865 |
for k, v in llm_models.items():
|
@@ -876,7 +892,6 @@ def update_llm_model_tupled_list():
|
|
876 |
|
877 |
|
878 |
def download_llm_models():
|
879 |
-
from huggingface_hub import hf_hub_download
|
880 |
global llm_models_tupled_list
|
881 |
llm_models_tupled_list = []
|
882 |
for k, v in llm_models.items():
|
@@ -890,7 +905,6 @@ def download_llm_models():
|
|
890 |
|
891 |
|
892 |
def download_llm_model(filename):
|
893 |
-
from huggingface_hub import hf_hub_download
|
894 |
if not filename in llm_models.keys(): return default_llm_model_filename
|
895 |
try:
|
896 |
hf_hub_download(repo_id = llm_models[filename][0], filename = filename, local_dir = llm_models_dir)
|
@@ -951,8 +965,6 @@ def get_dolphin_model_format(filename):
|
|
951 |
|
952 |
|
953 |
def add_dolphin_models(query, format_name):
|
954 |
-
import re
|
955 |
-
from huggingface_hub import HfApi
|
956 |
global llm_models
|
957 |
api = HfApi()
|
958 |
add_models = {}
|
@@ -964,20 +976,19 @@ def add_dolphin_models(query, format_name):
|
|
964 |
if s and "" in s: s.remove("")
|
965 |
if len(s) == 1:
|
966 |
repo = s[0]
|
967 |
-
if not api.repo_exists(repo_id = repo): return gr.update(
|
968 |
files = api.list_repo_files(repo_id = repo)
|
969 |
for file in files:
|
970 |
if str(file).endswith(".gguf"): add_models[filename] = [repo, format]
|
971 |
elif len(s) >= 2:
|
972 |
repo = s[0]
|
973 |
filename = s[1]
|
974 |
-
if not api.repo_exists(repo_id = repo) or not api.file_exists(repo_id = repo, filename = filename): return gr.update(
|
975 |
add_models[filename] = [repo, format]
|
976 |
-
else: return gr.update(
|
977 |
except Exception as e:
|
978 |
print(e)
|
979 |
-
return gr.update(
|
980 |
-
#print(add_models)
|
981 |
llm_models = (llm_models | add_models).copy()
|
982 |
update_llm_model_tupled_list()
|
983 |
choices = get_dolphin_models()
|
@@ -1177,7 +1188,6 @@ Output should be enclosed in //GENBEGIN//:// and //://GENEND//. The text to be g
|
|
1177 |
|
1178 |
|
1179 |
def get_dolphin_sysprompt():
|
1180 |
-
import re
|
1181 |
prompt = re.sub('<LANGUAGE>', dolphin_output_language, dolphin_system_prompt.get(dolphin_sysprompt_mode, ""))
|
1182 |
return prompt
|
1183 |
|
@@ -1207,11 +1217,11 @@ def select_dolphin_language(lang: str):
|
|
1207 |
|
1208 |
@wrapt_timeout_decorator.timeout(dec_timeout=5.0)
|
1209 |
def get_raw_prompt(msg: str):
|
1210 |
-
import re
|
1211 |
m = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
|
1212 |
return re.sub(r'[*/:_"#\n]', ' ', ", ".join(m)).lower() if m else ""
|
1213 |
|
1214 |
|
|
|
1215 |
@spaces.GPU(duration=60)
|
1216 |
def dolphin_respond(
|
1217 |
message: str,
|
@@ -1225,87 +1235,92 @@ def dolphin_respond(
|
|
1225 |
repeat_penalty: float = 1.1,
|
1226 |
progress=gr.Progress(track_tqdm=True),
|
1227 |
):
|
1228 |
-
|
1229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1230 |
|
1231 |
-
|
1232 |
-
|
1233 |
-
|
1234 |
-
|
1235 |
-
|
1236 |
-
|
1237 |
-
|
1238 |
-
flash_attn=True,
|
1239 |
-
n_gpu_layers=81, # 81
|
1240 |
-
n_batch=1024,
|
1241 |
-
n_ctx=8192, #8192
|
1242 |
-
)
|
1243 |
-
provider = LlamaCppPythonProvider(llm)
|
1244 |
-
|
1245 |
-
agent = LlamaCppAgent(
|
1246 |
-
provider,
|
1247 |
-
system_prompt=f"{system_message}",
|
1248 |
-
predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
|
1249 |
-
custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
|
1250 |
-
debug_output=False
|
1251 |
-
)
|
1252 |
-
|
1253 |
-
settings = provider.get_provider_default_settings()
|
1254 |
-
settings.temperature = temperature
|
1255 |
-
settings.top_k = top_k
|
1256 |
-
settings.top_p = top_p
|
1257 |
-
settings.max_tokens = max_tokens
|
1258 |
-
settings.repeat_penalty = repeat_penalty
|
1259 |
-
settings.stream = True
|
1260 |
-
|
1261 |
-
messages = BasicChatHistory()
|
1262 |
-
|
1263 |
-
for msn in history:
|
1264 |
-
user = {
|
1265 |
-
'role': Roles.user,
|
1266 |
-
'content': msn[0]
|
1267 |
-
}
|
1268 |
-
assistant = {
|
1269 |
-
'role': Roles.assistant,
|
1270 |
-
'content': msn[1]
|
1271 |
-
}
|
1272 |
-
messages.add_message(user)
|
1273 |
-
messages.add_message(assistant)
|
1274 |
-
|
1275 |
-
stream = agent.get_chat_response(
|
1276 |
-
message,
|
1277 |
-
llm_sampling_settings=settings,
|
1278 |
-
chat_history=messages,
|
1279 |
-
returns_streaming_generator=True,
|
1280 |
-
print_output=False
|
1281 |
-
)
|
1282 |
-
|
1283 |
-
progress(0.5, desc="Processing...")
|
1284 |
-
|
1285 |
-
outputs = ""
|
1286 |
-
for output in stream:
|
1287 |
-
outputs += output
|
1288 |
-
yield [(outputs, None)]
|
1289 |
|
1290 |
|
1291 |
def dolphin_parse(
|
1292 |
history: list[tuple[str, str]],
|
1293 |
):
|
1294 |
-
if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1:
|
1295 |
-
return "", gr.update(visible=True), gr.update(visible=True)
|
1296 |
try:
|
|
|
|
|
1297 |
msg = history[-1][0]
|
1298 |
raw_prompt = get_raw_prompt(msg)
|
1299 |
-
|
1300 |
-
|
1301 |
-
|
1302 |
-
|
1303 |
-
|
1304 |
-
|
1305 |
-
|
1306 |
-
|
|
|
1307 |
|
1308 |
|
|
|
1309 |
@spaces.GPU(duration=60)
|
1310 |
def dolphin_respond_auto(
|
1311 |
message: str,
|
@@ -1319,94 +1334,100 @@ def dolphin_respond_auto(
|
|
1319 |
repeat_penalty: float = 1.1,
|
1320 |
progress=gr.Progress(track_tqdm=True),
|
1321 |
):
|
1322 |
-
|
1323 |
-
|
1324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1325 |
|
1326 |
-
|
1327 |
-
|
1328 |
-
|
1329 |
-
|
1330 |
-
|
1331 |
-
|
1332 |
-
|
1333 |
-
|
1334 |
-
|
1335 |
-
n_batch=1024,
|
1336 |
-
n_ctx=8192, #8192
|
1337 |
-
)
|
1338 |
-
provider = LlamaCppPythonProvider(llm)
|
1339 |
-
|
1340 |
-
agent = LlamaCppAgent(
|
1341 |
-
provider,
|
1342 |
-
system_prompt=f"{system_message}",
|
1343 |
-
predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
|
1344 |
-
custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
|
1345 |
-
debug_output=False
|
1346 |
-
)
|
1347 |
-
|
1348 |
-
settings = provider.get_provider_default_settings()
|
1349 |
-
settings.temperature = temperature
|
1350 |
-
settings.top_k = top_k
|
1351 |
-
settings.top_p = top_p
|
1352 |
-
settings.max_tokens = max_tokens
|
1353 |
-
settings.repeat_penalty = repeat_penalty
|
1354 |
-
settings.stream = True
|
1355 |
-
|
1356 |
-
messages = BasicChatHistory()
|
1357 |
-
|
1358 |
-
for msn in history:
|
1359 |
-
user = {
|
1360 |
-
'role': Roles.user,
|
1361 |
-
'content': msn[0]
|
1362 |
-
}
|
1363 |
-
assistant = {
|
1364 |
-
'role': Roles.assistant,
|
1365 |
-
'content': msn[1]
|
1366 |
-
}
|
1367 |
-
messages.add_message(user)
|
1368 |
-
messages.add_message(assistant)
|
1369 |
-
|
1370 |
-
progress(0, desc="Translating...")
|
1371 |
-
stream = agent.get_chat_response(
|
1372 |
-
message,
|
1373 |
-
llm_sampling_settings=settings,
|
1374 |
-
chat_history=messages,
|
1375 |
-
returns_streaming_generator=True,
|
1376 |
-
print_output=False
|
1377 |
-
)
|
1378 |
-
|
1379 |
-
progress(0.5, desc="Processing...")
|
1380 |
-
|
1381 |
-
outputs = ""
|
1382 |
-
for output in stream:
|
1383 |
-
outputs += output
|
1384 |
-
yield [(outputs, None)]
|
1385 |
|
1386 |
|
1387 |
def dolphin_parse_simple(
|
1388 |
message: str,
|
1389 |
history: list[tuple[str, str]],
|
1390 |
):
|
1391 |
-
#if not is_japanese(message): return message
|
1392 |
-
if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1: return message
|
1393 |
try:
|
|
|
|
|
1394 |
msg = history[-1][0]
|
1395 |
raw_prompt = get_raw_prompt(msg)
|
1396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1397 |
return ""
|
1398 |
-
prompts = []
|
1399 |
-
if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
|
1400 |
-
prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit", "rating_explicit"])
|
1401 |
-
else:
|
1402 |
-
prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit", "rating_explicit"])
|
1403 |
-
return ", ".join(prompts)
|
1404 |
|
1405 |
|
1406 |
# https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
|
1407 |
import cv2
|
1408 |
cv2.setNumThreads(1)
|
1409 |
|
|
|
|
|
1410 |
@spaces.GPU()
|
1411 |
def respond_playground(
|
1412 |
message,
|
@@ -1419,47 +1440,47 @@ def respond_playground(
|
|
1419 |
top_k,
|
1420 |
repeat_penalty,
|
1421 |
):
|
1422 |
-
if override_llm_format:
|
1423 |
-
chat_template = override_llm_format
|
1424 |
-
else:
|
1425 |
-
chat_template = llm_models[model][1]
|
1426 |
-
|
1427 |
-
llm = Llama(
|
1428 |
-
model_path=str(Path(f"{llm_models_dir}/{model}")),
|
1429 |
-
flash_attn=True,
|
1430 |
-
n_gpu_layers=81, # 81
|
1431 |
-
n_batch=1024,
|
1432 |
-
n_ctx=8192, #8192
|
1433 |
-
)
|
1434 |
-
provider = LlamaCppPythonProvider(llm)
|
1435 |
-
|
1436 |
-
agent = LlamaCppAgent(
|
1437 |
-
provider,
|
1438 |
-
system_prompt=f"{system_message}",
|
1439 |
-
predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
|
1440 |
-
custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
|
1441 |
-
debug_output=False
|
1442 |
-
)
|
1443 |
-
|
1444 |
-
settings = provider.get_provider_default_settings()
|
1445 |
-
settings.temperature = temperature
|
1446 |
-
settings.top_k = top_k
|
1447 |
-
settings.top_p = top_p
|
1448 |
-
settings.max_tokens = max_tokens
|
1449 |
-
settings.repeat_penalty = repeat_penalty
|
1450 |
-
settings.stream = True
|
1451 |
-
|
1452 |
-
messages = BasicChatHistory()
|
1453 |
-
|
1454 |
-
# Add user and assistant messages to the history
|
1455 |
-
for msn in history:
|
1456 |
-
user = {'role': Roles.user, 'content': msn[0]}
|
1457 |
-
assistant = {'role': Roles.assistant, 'content': msn[1]}
|
1458 |
-
messages.add_message(user)
|
1459 |
-
messages.add_message(assistant)
|
1460 |
-
|
1461 |
-
# Stream the response
|
1462 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1463 |
stream = agent.get_chat_response(
|
1464 |
message,
|
1465 |
llm_sampling_settings=settings,
|
@@ -1473,4 +1494,5 @@ def respond_playground(
|
|
1473 |
outputs += output
|
1474 |
yield outputs
|
1475 |
except Exception as e:
|
1476 |
-
|
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
+
from pathlib import Path
|
4 |
+
import re
|
5 |
+
import torch
|
6 |
+
from huggingface_hub import hf_hub_download, HfApi
|
7 |
from llama_cpp import Llama
|
8 |
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
|
9 |
from llama_cpp_agent.providers import LlamaCppPythonProvider
|
|
|
11 |
from llama_cpp_agent.chat_history.messages import Roles
|
12 |
from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
|
13 |
import wrapt_timeout_decorator
|
|
|
14 |
from llama_cpp_agent.messages_formatter import MessagesFormatter
|
15 |
from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
|
16 |
|
|
|
22 |
#"": ["", MessagesFormatterType.OPEN_CHAT],
|
23 |
#"": ["", MessagesFormatterType.CHATML],
|
24 |
#"": ["", MessagesFormatterType.PHI_3],
|
25 |
+
#"": ["", MessagesFormatterType.GEMMA_2],
|
26 |
"mn-12b-lyra-v2a1-q5_k_m.gguf": ["HalleyStarbun/MN-12B-Lyra-v2a1-Q5_K_M-GGUF", MessagesFormatterType.CHATML],
|
27 |
"L3-8B-Tamamo-v1.i1-Q5_K_M.gguf": ["mradermacher/L3-8B-Tamamo-v1-i1-GGUF", MessagesFormatterType.LLAMA_3],
|
28 |
"MN-Chinofun-12B-2.i1-Q4_K_M.gguf": ["mradermacher/MN-Chinofun-12B-2-i1-GGUF", MessagesFormatterType.MISTRAL],
|
|
|
72 |
"ChatWaifu_22B_v2.0_preview.Q4_K_S.gguf": ["mradermacher/ChatWaifu_22B_v2.0_preview-GGUF", MessagesFormatterType.MISTRAL],
|
73 |
"ChatWaifu_v1.4.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.4-GGUF", MessagesFormatterType.MISTRAL],
|
74 |
"ChatWaifu_v1.3.1.Q4_K_M.gguf": ["mradermacher/ChatWaifu_v1.3.1-GGUF", MessagesFormatterType.MISTRAL],
|
75 |
+
"Magnum_Dark_Madness_12b.Q4_K_S.gguf": ["mradermacher/Magnum_Dark_Madness_12b-GGUF", MessagesFormatterType.MISTRAL],
|
76 |
+
"Magnum_Lyra_Darkness_12b.Q4_K_M.gguf": ["mradermacher/Magnum_Lyra_Darkness_12b-GGUF", MessagesFormatterType.MISTRAL],
|
77 |
+
"Heart_Stolen-8B-task.i1-Q4_K_M.gguf": ["mradermacher/Heart_Stolen-8B-task-i1-GGUF", MessagesFormatterType.LLAMA_3],
|
78 |
+
"Magnum_Backyard_Party_12b.Q4_K_M.gguf": ["mradermacher/Magnum_Backyard_Party_12b-GGUF", MessagesFormatterType.MISTRAL],
|
79 |
+
"Magnum_Madness-12b.Q4_K_M.gguf": ["mradermacher/Magnum_Madness-12b-GGUF", MessagesFormatterType.MISTRAL],
|
80 |
+
"L3.1-Moe-2x8B-v0.2.i1-Q4_K_M.gguf": ["mradermacher/L3.1-Moe-2x8B-v0.2-i1-GGUF", MessagesFormatterType.LLAMA_3],
|
81 |
+
"Qwen2.5-14B-Wernicke-DPO.i1-Q4_K_M.gguf": ["mradermacher/Qwen2.5-14B-Wernicke-DPO-i1-GGUF", MessagesFormatterType.OPEN_CHAT],
|
82 |
+
"Gemma-2-Ataraxy-v4d-9B.i1-Q4_K_M.gguf": ["mradermacher/Gemma-2-Ataraxy-v4d-9B-i1-GGUF", MessagesFormatterType.GEMMA_2],
|
83 |
+
"qwen2.5-14b-megamerge-pt2-q5_k_m.gguf": ["CultriX/Qwen2.5-14B-MegaMerge-pt2-Q5_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
|
84 |
+
"quantqwen2-merged-16bit-q4_k_m.gguf": ["davidbzyk/QuantQwen2-merged-16bit-Q4_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
|
85 |
+
"Mistral-nemo-ja-rp-v0.2-Q4_K_S.gguf": ["ascktgcc/Mistral-nemo-ja-rp-v0.2-GGUF", MessagesFormatterType.MISTRAL],
|
86 |
+
"llama3.1-darkstorm-aspire-8b-q4_k_m.gguf": ["ZeroXClem/Llama3.1-DarkStorm-Aspire-8B-Q4_K_M-GGUF", MessagesFormatterType.LLAMA_3],
|
87 |
+
"llama-3-yggdrasil-astralspice-8b-q4_k_m.gguf": ["ZeroXClem/Llama-3-Yggdrasil-AstralSpice-8B-Q4_K_M-GGUF", MessagesFormatterType.LLAMA_3],
|
88 |
"hermes-llama3-roleplay-1000-v2.Q5_K_M.gguf": ["mradermacher/hermes-llama3-roleplay-1000-v2-GGUF", MessagesFormatterType.LLAMA_3],
|
89 |
"hermes-stheno-8B-v0.1.i1-Q5_K_M.gguf": ["mradermacher/hermes-stheno-8B-v0.1-i1-GGUF", MessagesFormatterType.LLAMA_3],
|
90 |
"qwen-carpmuscle-r-v0.3.Q4_K_M.gguf": ["mradermacher/qwen-carpmuscle-r-v0.3-GGUF", MessagesFormatterType.OPEN_CHAT],
|
|
|
849 |
llm_models_tupled_list = []
|
850 |
default_llm_model_filename = list(llm_models.keys())[0]
|
851 |
override_llm_format = None
|
852 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
853 |
|
854 |
|
855 |
def to_list(s):
|
|
|
862 |
|
863 |
@wrapt_timeout_decorator.timeout(dec_timeout=3.5)
|
864 |
def to_list_ja(s):
|
|
|
865 |
s = re.sub(r'[、。]', ',', s)
|
866 |
return [x.strip() for x in s.split(",") if not s == ""]
|
867 |
|
|
|
876 |
|
877 |
|
878 |
def update_llm_model_tupled_list():
|
|
|
879 |
global llm_models_tupled_list
|
880 |
llm_models_tupled_list = []
|
881 |
for k, v in llm_models.items():
|
|
|
892 |
|
893 |
|
894 |
def download_llm_models():
|
|
|
895 |
global llm_models_tupled_list
|
896 |
llm_models_tupled_list = []
|
897 |
for k, v in llm_models.items():
|
|
|
905 |
|
906 |
|
907 |
def download_llm_model(filename):
|
|
|
908 |
if not filename in llm_models.keys(): return default_llm_model_filename
|
909 |
try:
|
910 |
hf_hub_download(repo_id = llm_models[filename][0], filename = filename, local_dir = llm_models_dir)
|
|
|
965 |
|
966 |
|
967 |
def add_dolphin_models(query, format_name):
|
|
|
|
|
968 |
global llm_models
|
969 |
api = HfApi()
|
970 |
add_models = {}
|
|
|
976 |
if s and "" in s: s.remove("")
|
977 |
if len(s) == 1:
|
978 |
repo = s[0]
|
979 |
+
if not api.repo_exists(repo_id = repo): return gr.update()
|
980 |
files = api.list_repo_files(repo_id = repo)
|
981 |
for file in files:
|
982 |
if str(file).endswith(".gguf"): add_models[filename] = [repo, format]
|
983 |
elif len(s) >= 2:
|
984 |
repo = s[0]
|
985 |
filename = s[1]
|
986 |
+
if not api.repo_exists(repo_id = repo) or not api.file_exists(repo_id = repo, filename = filename): return gr.update()
|
987 |
add_models[filename] = [repo, format]
|
988 |
+
else: return gr.update()
|
989 |
except Exception as e:
|
990 |
print(e)
|
991 |
+
return gr.update()
|
|
|
992 |
llm_models = (llm_models | add_models).copy()
|
993 |
update_llm_model_tupled_list()
|
994 |
choices = get_dolphin_models()
|
|
|
1188 |
|
1189 |
|
1190 |
def get_dolphin_sysprompt():
|
|
|
1191 |
prompt = re.sub('<LANGUAGE>', dolphin_output_language, dolphin_system_prompt.get(dolphin_sysprompt_mode, ""))
|
1192 |
return prompt
|
1193 |
|
|
|
1217 |
|
1218 |
@wrapt_timeout_decorator.timeout(dec_timeout=5.0)
|
1219 |
def get_raw_prompt(msg: str):
|
|
|
1220 |
m = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
|
1221 |
return re.sub(r'[*/:_"#\n]', ' ', ", ".join(m)).lower() if m else ""
|
1222 |
|
1223 |
|
1224 |
+
@torch.inference_mode()
|
1225 |
@spaces.GPU(duration=60)
|
1226 |
def dolphin_respond(
|
1227 |
message: str,
|
|
|
1235 |
repeat_penalty: float = 1.1,
|
1236 |
progress=gr.Progress(track_tqdm=True),
|
1237 |
):
|
1238 |
+
try:
|
1239 |
+
progress(0, desc="Processing...")
|
1240 |
+
|
1241 |
+
if override_llm_format:
|
1242 |
+
chat_template = override_llm_format
|
1243 |
+
else:
|
1244 |
+
chat_template = llm_models[model][1]
|
1245 |
+
|
1246 |
+
llm = Llama(
|
1247 |
+
model_path=str(Path(f"{llm_models_dir}/{model}")),
|
1248 |
+
flash_attn=True,
|
1249 |
+
n_gpu_layers=81, # 81
|
1250 |
+
n_batch=1024,
|
1251 |
+
n_ctx=8192, #8192
|
1252 |
+
)
|
1253 |
+
provider = LlamaCppPythonProvider(llm)
|
1254 |
+
|
1255 |
+
agent = LlamaCppAgent(
|
1256 |
+
provider,
|
1257 |
+
system_prompt=f"{system_message}",
|
1258 |
+
predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
|
1259 |
+
custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
|
1260 |
+
debug_output=False
|
1261 |
+
)
|
1262 |
+
|
1263 |
+
settings = provider.get_provider_default_settings()
|
1264 |
+
settings.temperature = temperature
|
1265 |
+
settings.top_k = top_k
|
1266 |
+
settings.top_p = top_p
|
1267 |
+
settings.max_tokens = max_tokens
|
1268 |
+
settings.repeat_penalty = repeat_penalty
|
1269 |
+
settings.stream = True
|
1270 |
+
|
1271 |
+
messages = BasicChatHistory()
|
1272 |
+
|
1273 |
+
for msn in history:
|
1274 |
+
user = {
|
1275 |
+
'role': Roles.user,
|
1276 |
+
'content': msn[0]
|
1277 |
+
}
|
1278 |
+
assistant = {
|
1279 |
+
'role': Roles.assistant,
|
1280 |
+
'content': msn[1]
|
1281 |
+
}
|
1282 |
+
messages.add_message(user)
|
1283 |
+
messages.add_message(assistant)
|
1284 |
+
|
1285 |
+
stream = agent.get_chat_response(
|
1286 |
+
message,
|
1287 |
+
llm_sampling_settings=settings,
|
1288 |
+
chat_history=messages,
|
1289 |
+
returns_streaming_generator=True,
|
1290 |
+
print_output=False
|
1291 |
+
)
|
1292 |
+
|
1293 |
+
progress(0.5, desc="Processing...")
|
1294 |
|
1295 |
+
outputs = ""
|
1296 |
+
for output in stream:
|
1297 |
+
outputs += output
|
1298 |
+
yield [(outputs, None)]
|
1299 |
+
except Exception as e:
|
1300 |
+
print(e)
|
1301 |
+
yield [("", None)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1302 |
|
1303 |
|
1304 |
def dolphin_parse(
|
1305 |
history: list[tuple[str, str]],
|
1306 |
):
|
|
|
|
|
1307 |
try:
|
1308 |
+
if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1:
|
1309 |
+
return "", gr.update(), gr.update()
|
1310 |
msg = history[-1][0]
|
1311 |
raw_prompt = get_raw_prompt(msg)
|
1312 |
+
prompts = []
|
1313 |
+
if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
|
1314 |
+
prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit"])
|
1315 |
+
else:
|
1316 |
+
prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit"])
|
1317 |
+
return ", ".join(prompts), gr.update(interactive=True), gr.update(interactive=True)
|
1318 |
+
except Exception as e:
|
1319 |
+
print(e)
|
1320 |
+
return "", gr.update(), gr.update()
|
1321 |
|
1322 |
|
1323 |
+
@torch.inference_mode()
|
1324 |
@spaces.GPU(duration=60)
|
1325 |
def dolphin_respond_auto(
|
1326 |
message: str,
|
|
|
1334 |
repeat_penalty: float = 1.1,
|
1335 |
progress=gr.Progress(track_tqdm=True),
|
1336 |
):
|
1337 |
+
try:
|
1338 |
+
#if not is_japanese(message): return [(None, None)]
|
1339 |
+
progress(0, desc="Processing...")
|
1340 |
+
|
1341 |
+
if override_llm_format:
|
1342 |
+
chat_template = override_llm_format
|
1343 |
+
else:
|
1344 |
+
chat_template = llm_models[model][1]
|
1345 |
+
|
1346 |
+
llm = Llama(
|
1347 |
+
model_path=str(Path(f"{llm_models_dir}/{model}")),
|
1348 |
+
flash_attn=True,
|
1349 |
+
n_gpu_layers=81, # 81
|
1350 |
+
n_batch=1024,
|
1351 |
+
n_ctx=8192, #8192
|
1352 |
+
)
|
1353 |
+
provider = LlamaCppPythonProvider(llm)
|
1354 |
+
|
1355 |
+
agent = LlamaCppAgent(
|
1356 |
+
provider,
|
1357 |
+
system_prompt=f"{system_message}",
|
1358 |
+
predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
|
1359 |
+
custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
|
1360 |
+
debug_output=False
|
1361 |
+
)
|
1362 |
+
|
1363 |
+
settings = provider.get_provider_default_settings()
|
1364 |
+
settings.temperature = temperature
|
1365 |
+
settings.top_k = top_k
|
1366 |
+
settings.top_p = top_p
|
1367 |
+
settings.max_tokens = max_tokens
|
1368 |
+
settings.repeat_penalty = repeat_penalty
|
1369 |
+
settings.stream = True
|
1370 |
+
|
1371 |
+
messages = BasicChatHistory()
|
1372 |
+
|
1373 |
+
for msn in history:
|
1374 |
+
user = {
|
1375 |
+
'role': Roles.user,
|
1376 |
+
'content': msn[0]
|
1377 |
+
}
|
1378 |
+
assistant = {
|
1379 |
+
'role': Roles.assistant,
|
1380 |
+
'content': msn[1]
|
1381 |
+
}
|
1382 |
+
messages.add_message(user)
|
1383 |
+
messages.add_message(assistant)
|
1384 |
+
|
1385 |
+
progress(0, desc="Translating...")
|
1386 |
+
stream = agent.get_chat_response(
|
1387 |
+
message,
|
1388 |
+
llm_sampling_settings=settings,
|
1389 |
+
chat_history=messages,
|
1390 |
+
returns_streaming_generator=True,
|
1391 |
+
print_output=False
|
1392 |
+
)
|
1393 |
|
1394 |
+
progress(0.5, desc="Processing...")
|
1395 |
+
|
1396 |
+
outputs = ""
|
1397 |
+
for output in stream:
|
1398 |
+
outputs += output
|
1399 |
+
yield [(outputs, None)], gr.update(), gr.update()
|
1400 |
+
except Exception as e:
|
1401 |
+
print(e)
|
1402 |
+
yield [("", None)], gr.update(), gr.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1403 |
|
1404 |
|
1405 |
def dolphin_parse_simple(
|
1406 |
message: str,
|
1407 |
history: list[tuple[str, str]],
|
1408 |
):
|
|
|
|
|
1409 |
try:
|
1410 |
+
#if not is_japanese(message): return message
|
1411 |
+
if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1: return message
|
1412 |
msg = history[-1][0]
|
1413 |
raw_prompt = get_raw_prompt(msg)
|
1414 |
+
prompts = []
|
1415 |
+
if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
|
1416 |
+
prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit", "rating_explicit"])
|
1417 |
+
else:
|
1418 |
+
prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit", "rating_explicit"])
|
1419 |
+
return ", ".join(prompts)
|
1420 |
+
except Exception as e:
|
1421 |
+
print(e)
|
1422 |
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
1423 |
|
1424 |
|
1425 |
# https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
|
1426 |
import cv2
|
1427 |
cv2.setNumThreads(1)
|
1428 |
|
1429 |
+
|
1430 |
+
@torch.inference_mode()
|
1431 |
@spaces.GPU()
|
1432 |
def respond_playground(
|
1433 |
message,
|
|
|
1440 |
top_k,
|
1441 |
repeat_penalty,
|
1442 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1443 |
try:
|
1444 |
+
if override_llm_format:
|
1445 |
+
chat_template = override_llm_format
|
1446 |
+
else:
|
1447 |
+
chat_template = llm_models[model][1]
|
1448 |
+
|
1449 |
+
llm = Llama(
|
1450 |
+
model_path=str(Path(f"{llm_models_dir}/{model}")),
|
1451 |
+
flash_attn=True,
|
1452 |
+
n_gpu_layers=81, # 81
|
1453 |
+
n_batch=1024,
|
1454 |
+
n_ctx=8192, #8192
|
1455 |
+
)
|
1456 |
+
provider = LlamaCppPythonProvider(llm)
|
1457 |
+
|
1458 |
+
agent = LlamaCppAgent(
|
1459 |
+
provider,
|
1460 |
+
system_prompt=f"{system_message}",
|
1461 |
+
predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
|
1462 |
+
custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
|
1463 |
+
debug_output=False
|
1464 |
+
)
|
1465 |
+
|
1466 |
+
settings = provider.get_provider_default_settings()
|
1467 |
+
settings.temperature = temperature
|
1468 |
+
settings.top_k = top_k
|
1469 |
+
settings.top_p = top_p
|
1470 |
+
settings.max_tokens = max_tokens
|
1471 |
+
settings.repeat_penalty = repeat_penalty
|
1472 |
+
settings.stream = True
|
1473 |
+
|
1474 |
+
messages = BasicChatHistory()
|
1475 |
+
|
1476 |
+
# Add user and assistant messages to the history
|
1477 |
+
for msn in history:
|
1478 |
+
user = {'role': Roles.user, 'content': msn[0]}
|
1479 |
+
assistant = {'role': Roles.assistant, 'content': msn[1]}
|
1480 |
+
messages.add_message(user)
|
1481 |
+
messages.add_message(assistant)
|
1482 |
+
|
1483 |
+
# Stream the response
|
1484 |
stream = agent.get_chat_response(
|
1485 |
message,
|
1486 |
llm_sampling_settings=settings,
|
|
|
1494 |
outputs += output
|
1495 |
yield outputs
|
1496 |
except Exception as e:
|
1497 |
+
print(e)
|
1498 |
+
yield ""
|
requirements.txt
CHANGED
@@ -10,10 +10,9 @@ accelerate
|
|
10 |
transformers
|
11 |
optimum[onnxruntime]
|
12 |
dartrs
|
13 |
-
|
14 |
-
|
15 |
-
googletrans==4.0.0rc1
|
16 |
-
git+https://github.com/huggingface/diffusers
|
17 |
rapidfuzz
|
18 |
wrapt-timeout-decorator
|
19 |
-
opencv-python
|
|
|
|
10 |
transformers
|
11 |
optimum[onnxruntime]
|
12 |
dartrs
|
13 |
+
translatepy
|
14 |
+
diffusers
|
|
|
|
|
15 |
rapidfuzz
|
16 |
wrapt-timeout-decorator
|
17 |
+
opencv-python
|
18 |
+
numpy<2
|
tagger/character_series_dict.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tagger/danbooru_e621.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tagger/tag_group.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tagger/tagger.py
ADDED
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
|
9 |
+
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
10 |
+
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
11 |
+
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
default_device = device
|
14 |
+
|
15 |
+
try:
|
16 |
+
wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
|
17 |
+
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
18 |
+
except Exception as e:
|
19 |
+
print(e)
|
20 |
+
wd_model = wd_processor = None
|
21 |
+
|
22 |
+
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
23 |
+
return (
|
24 |
+
[f"1{noun}"]
|
25 |
+
+ [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
|
26 |
+
+ [f"{maximum+1}+{noun}s"]
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
PEOPLE_TAGS = (
|
31 |
+
_people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
RATING_MAP = {
|
36 |
+
"sfw": "safe",
|
37 |
+
"general": "safe",
|
38 |
+
"sensitive": "sensitive",
|
39 |
+
"questionable": "nsfw",
|
40 |
+
"explicit": "explicit, nsfw",
|
41 |
+
}
|
42 |
+
DANBOORU_TO_E621_RATING_MAP = {
|
43 |
+
"sfw": "rating_safe",
|
44 |
+
"general": "rating_safe",
|
45 |
+
"safe": "rating_safe",
|
46 |
+
"sensitive": "rating_safe",
|
47 |
+
"nsfw": "rating_explicit",
|
48 |
+
"explicit, nsfw": "rating_explicit",
|
49 |
+
"explicit": "rating_explicit",
|
50 |
+
"rating:safe": "rating_safe",
|
51 |
+
"rating:general": "rating_safe",
|
52 |
+
"rating:sensitive": "rating_safe",
|
53 |
+
"rating:questionable, nsfw": "rating_explicit",
|
54 |
+
"rating:explicit, nsfw": "rating_explicit",
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
|
59 |
+
kaomojis = [
|
60 |
+
"0_0",
|
61 |
+
"(o)_(o)",
|
62 |
+
"+_+",
|
63 |
+
"+_-",
|
64 |
+
"._.",
|
65 |
+
"<o>_<o>",
|
66 |
+
"<|>_<|>",
|
67 |
+
"=_=",
|
68 |
+
">_<",
|
69 |
+
"3_3",
|
70 |
+
"6_9",
|
71 |
+
">_o",
|
72 |
+
"@_@",
|
73 |
+
"^_^",
|
74 |
+
"o_o",
|
75 |
+
"u_u",
|
76 |
+
"x_x",
|
77 |
+
"|_|",
|
78 |
+
"||_||",
|
79 |
+
]
|
80 |
+
|
81 |
+
|
82 |
+
def replace_underline(x: str):
|
83 |
+
return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
|
84 |
+
|
85 |
+
|
86 |
+
def to_list(s):
|
87 |
+
return [x.strip() for x in s.split(",") if not s == ""]
|
88 |
+
|
89 |
+
|
90 |
+
def list_sub(a, b):
|
91 |
+
return [e for e in a if e not in b]
|
92 |
+
|
93 |
+
|
94 |
+
def list_uniq(l):
|
95 |
+
return sorted(set(l), key=l.index)
|
96 |
+
|
97 |
+
|
98 |
+
def load_dict_from_csv(filename):
|
99 |
+
dict = {}
|
100 |
+
if not Path(filename).exists():
|
101 |
+
if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
|
102 |
+
else: return dict
|
103 |
+
try:
|
104 |
+
with open(filename, 'r', encoding="utf-8") as f:
|
105 |
+
lines = f.readlines()
|
106 |
+
except Exception:
|
107 |
+
print(f"Failed to open dictionary file: {filename}")
|
108 |
+
return dict
|
109 |
+
for line in lines:
|
110 |
+
parts = line.strip().split(',')
|
111 |
+
dict[parts[0]] = parts[1]
|
112 |
+
return dict
|
113 |
+
|
114 |
+
|
115 |
+
anime_series_dict = load_dict_from_csv('character_series_dict.csv')
|
116 |
+
|
117 |
+
|
118 |
+
def character_list_to_series_list(character_list):
|
119 |
+
output_series_tag = []
|
120 |
+
series_tag = ""
|
121 |
+
series_dict = anime_series_dict
|
122 |
+
for tag in character_list:
|
123 |
+
series_tag = series_dict.get(tag, "")
|
124 |
+
if tag.endswith(")"):
|
125 |
+
tags = tag.split("(")
|
126 |
+
character_tag = "(".join(tags[:-1])
|
127 |
+
if character_tag.endswith(" "):
|
128 |
+
character_tag = character_tag[:-1]
|
129 |
+
series_tag = tags[-1].replace(")", "")
|
130 |
+
|
131 |
+
if series_tag:
|
132 |
+
output_series_tag.append(series_tag)
|
133 |
+
|
134 |
+
return output_series_tag
|
135 |
+
|
136 |
+
|
137 |
+
def select_random_character(series: str, character: str):
|
138 |
+
from random import seed, randrange
|
139 |
+
seed()
|
140 |
+
character_list = list(anime_series_dict.keys())
|
141 |
+
character = character_list[randrange(len(character_list) - 1)]
|
142 |
+
series = anime_series_dict.get(character.split(",")[0].strip(), "")
|
143 |
+
return series, character
|
144 |
+
|
145 |
+
|
146 |
+
def danbooru_to_e621(dtag, e621_dict):
|
147 |
+
def d_to_e(match, e621_dict):
|
148 |
+
dtag = match.group(0)
|
149 |
+
etag = e621_dict.get(replace_underline(dtag), "")
|
150 |
+
if etag:
|
151 |
+
return etag
|
152 |
+
else:
|
153 |
+
return dtag
|
154 |
+
|
155 |
+
import re
|
156 |
+
tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
|
157 |
+
return tag
|
158 |
+
|
159 |
+
|
160 |
+
danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
|
161 |
+
|
162 |
+
|
163 |
+
def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
|
164 |
+
if prompt_type == "danbooru": return input_prompt
|
165 |
+
tags = input_prompt.split(",") if input_prompt else []
|
166 |
+
people_tags: list[str] = []
|
167 |
+
other_tags: list[str] = []
|
168 |
+
rating_tags: list[str] = []
|
169 |
+
|
170 |
+
e621_dict = danbooru_to_e621_dict
|
171 |
+
for tag in tags:
|
172 |
+
tag = replace_underline(tag)
|
173 |
+
tag = danbooru_to_e621(tag, e621_dict)
|
174 |
+
if tag in PEOPLE_TAGS:
|
175 |
+
people_tags.append(tag)
|
176 |
+
elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
|
177 |
+
rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag.replace(" ",""), ""))
|
178 |
+
else:
|
179 |
+
other_tags.append(tag)
|
180 |
+
|
181 |
+
rating_tags = sorted(set(rating_tags), key=rating_tags.index)
|
182 |
+
rating_tags = [rating_tags[0]] if rating_tags else []
|
183 |
+
rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
|
184 |
+
|
185 |
+
output_prompt = ", ".join(people_tags + other_tags + rating_tags)
|
186 |
+
|
187 |
+
return output_prompt
|
188 |
+
|
189 |
+
|
190 |
+
from translatepy import Translator
|
191 |
+
translator = Translator()
|
192 |
+
def translate_prompt_old(prompt: str = ""):
|
193 |
+
def translate_to_english(input: str):
|
194 |
+
try:
|
195 |
+
output = str(translator.translate(input, 'English'))
|
196 |
+
except Exception as e:
|
197 |
+
output = input
|
198 |
+
print(e)
|
199 |
+
return output
|
200 |
+
|
201 |
+
def is_japanese(s):
|
202 |
+
import unicodedata
|
203 |
+
for ch in s:
|
204 |
+
name = unicodedata.name(ch, "")
|
205 |
+
if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
|
206 |
+
return True
|
207 |
+
return False
|
208 |
+
|
209 |
+
def to_list(s):
|
210 |
+
return [x.strip() for x in s.split(",")]
|
211 |
+
|
212 |
+
prompts = to_list(prompt)
|
213 |
+
outputs = []
|
214 |
+
for p in prompts:
|
215 |
+
p = translate_to_english(p) if is_japanese(p) else p
|
216 |
+
outputs.append(p)
|
217 |
+
|
218 |
+
return ", ".join(outputs)
|
219 |
+
|
220 |
+
|
221 |
+
def translate_prompt(input: str):
|
222 |
+
try:
|
223 |
+
output = str(translator.translate(input, 'English'))
|
224 |
+
except Exception as e:
|
225 |
+
output = input
|
226 |
+
print(e)
|
227 |
+
return output
|
228 |
+
|
229 |
+
|
230 |
+
def translate_prompt_to_ja(prompt: str = ""):
|
231 |
+
def translate_to_japanese(input: str):
|
232 |
+
try:
|
233 |
+
output = str(translator.translate(input, 'Japanese'))
|
234 |
+
except Exception as e:
|
235 |
+
output = input
|
236 |
+
print(e)
|
237 |
+
return output
|
238 |
+
|
239 |
+
def is_japanese(s):
|
240 |
+
import unicodedata
|
241 |
+
for ch in s:
|
242 |
+
name = unicodedata.name(ch, "")
|
243 |
+
if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
|
244 |
+
return True
|
245 |
+
return False
|
246 |
+
|
247 |
+
def to_list(s):
|
248 |
+
return [x.strip() for x in s.split(",")]
|
249 |
+
|
250 |
+
prompts = to_list(prompt)
|
251 |
+
outputs = []
|
252 |
+
for p in prompts:
|
253 |
+
p = translate_to_japanese(p) if not is_japanese(p) else p
|
254 |
+
outputs.append(p)
|
255 |
+
|
256 |
+
return ", ".join(outputs)
|
257 |
+
|
258 |
+
|
259 |
+
def tags_to_ja(itag, dict):
|
260 |
+
def t_to_j(match, dict):
|
261 |
+
tag = match.group(0)
|
262 |
+
ja = dict.get(replace_underline(tag), "")
|
263 |
+
if ja:
|
264 |
+
return ja
|
265 |
+
else:
|
266 |
+
return tag
|
267 |
+
|
268 |
+
import re
|
269 |
+
tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dict), itag, 2)
|
270 |
+
|
271 |
+
return tag
|
272 |
+
|
273 |
+
|
274 |
+
def convert_tags_to_ja(input_prompt: str = ""):
|
275 |
+
tags = input_prompt.split(",") if input_prompt else []
|
276 |
+
out_tags = []
|
277 |
+
|
278 |
+
tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
|
279 |
+
dict = tags_to_ja_dict
|
280 |
+
for tag in tags:
|
281 |
+
tag = replace_underline(tag)
|
282 |
+
tag = tags_to_ja(tag, dict)
|
283 |
+
out_tags.append(tag)
|
284 |
+
|
285 |
+
return ", ".join(out_tags)
|
286 |
+
|
287 |
+
|
288 |
+
enable_auto_recom_prompt = True
|
289 |
+
|
290 |
+
|
291 |
+
animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
|
292 |
+
animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
|
293 |
+
pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
|
294 |
+
pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
|
295 |
+
other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
|
296 |
+
other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
|
297 |
+
default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
|
298 |
+
default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
|
299 |
+
def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
|
300 |
+
global enable_auto_recom_prompt
|
301 |
+
prompts = to_list(prompt)
|
302 |
+
neg_prompts = to_list(neg_prompt)
|
303 |
+
|
304 |
+
prompts = list_sub(prompts, animagine_ps + pony_ps)
|
305 |
+
neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
|
306 |
+
|
307 |
+
last_empty_p = [""] if not prompts and type != "None" else []
|
308 |
+
last_empty_np = [""] if not neg_prompts and type != "None" else []
|
309 |
+
|
310 |
+
if type == "Auto":
|
311 |
+
enable_auto_recom_prompt = True
|
312 |
+
else:
|
313 |
+
enable_auto_recom_prompt = False
|
314 |
+
if type == "Animagine":
|
315 |
+
prompts = prompts + animagine_ps
|
316 |
+
neg_prompts = neg_prompts + animagine_nps
|
317 |
+
elif type == "Pony":
|
318 |
+
prompts = prompts + pony_ps
|
319 |
+
neg_prompts = neg_prompts + pony_nps
|
320 |
+
|
321 |
+
prompt = ", ".join(list_uniq(prompts) + last_empty_p)
|
322 |
+
neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
|
323 |
+
|
324 |
+
return prompt, neg_prompt
|
325 |
+
|
326 |
+
|
327 |
+
def load_model_prompt_dict():
|
328 |
+
import json
|
329 |
+
dict = {}
|
330 |
+
path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
|
331 |
+
try:
|
332 |
+
with open('model_dict.json', encoding='utf-8') as f:
|
333 |
+
dict = json.load(f)
|
334 |
+
except Exception:
|
335 |
+
pass
|
336 |
+
return dict
|
337 |
+
|
338 |
+
|
339 |
+
model_prompt_dict = load_model_prompt_dict()
|
340 |
+
|
341 |
+
|
342 |
+
def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
|
343 |
+
if not model_name or not enable_auto_recom_prompt: return prompt, neg_prompt
|
344 |
+
prompts = to_list(prompt)
|
345 |
+
neg_prompts = to_list(neg_prompt)
|
346 |
+
prompts = list_sub(prompts, animagine_ps + pony_ps + other_ps)
|
347 |
+
neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + other_nps)
|
348 |
+
last_empty_p = [""] if not prompts and type != "None" else []
|
349 |
+
last_empty_np = [""] if not neg_prompts and type != "None" else []
|
350 |
+
ps = []
|
351 |
+
nps = []
|
352 |
+
if model_name in model_prompt_dict.keys():
|
353 |
+
ps = to_list(model_prompt_dict[model_name]["prompt"])
|
354 |
+
nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
|
355 |
+
else:
|
356 |
+
ps = default_ps
|
357 |
+
nps = default_nps
|
358 |
+
prompts = prompts + ps
|
359 |
+
neg_prompts = neg_prompts + nps
|
360 |
+
prompt = ", ".join(list_uniq(prompts) + last_empty_p)
|
361 |
+
neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
|
362 |
+
return prompt, neg_prompt
|
363 |
+
|
364 |
+
|
365 |
+
tag_group_dict = load_dict_from_csv('tag_group.csv')
|
366 |
+
|
367 |
+
|
368 |
+
def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
|
369 |
+
def is_dressed(tag):
|
370 |
+
import re
|
371 |
+
p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
|
372 |
+
return p.search(tag)
|
373 |
+
|
374 |
+
def is_background(tag):
|
375 |
+
import re
|
376 |
+
p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
|
377 |
+
return p.search(tag)
|
378 |
+
|
379 |
+
un_tags = ['solo']
|
380 |
+
group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
|
381 |
+
keep_group_dict = {
|
382 |
+
"body": ['groups', 'body_parts'],
|
383 |
+
"dress": ['groups', 'body_parts', 'attire'],
|
384 |
+
"all": group_list,
|
385 |
+
}
|
386 |
+
|
387 |
+
def is_necessary(tag, keep_tags, group_dict):
|
388 |
+
if keep_tags == "all":
|
389 |
+
return True
|
390 |
+
elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
|
391 |
+
return False
|
392 |
+
elif keep_tags == "body" and is_dressed(tag):
|
393 |
+
return False
|
394 |
+
elif is_background(tag):
|
395 |
+
return False
|
396 |
+
else:
|
397 |
+
return True
|
398 |
+
|
399 |
+
if keep_tags == "all": return input_prompt
|
400 |
+
keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
|
401 |
+
explicit_group = list(set(group_list) ^ set(keep_group))
|
402 |
+
|
403 |
+
tags = input_prompt.split(",") if input_prompt else []
|
404 |
+
people_tags: list[str] = []
|
405 |
+
other_tags: list[str] = []
|
406 |
+
|
407 |
+
group_dict = tag_group_dict
|
408 |
+
for tag in tags:
|
409 |
+
tag = replace_underline(tag)
|
410 |
+
if tag in PEOPLE_TAGS:
|
411 |
+
people_tags.append(tag)
|
412 |
+
elif is_necessary(tag, keep_tags, group_dict):
|
413 |
+
other_tags.append(tag)
|
414 |
+
|
415 |
+
output_prompt = ", ".join(people_tags + other_tags)
|
416 |
+
|
417 |
+
return output_prompt
|
418 |
+
|
419 |
+
|
420 |
+
def sort_taglist(tags: list[str]):
|
421 |
+
if not tags: return []
|
422 |
+
character_tags: list[str] = []
|
423 |
+
series_tags: list[str] = []
|
424 |
+
people_tags: list[str] = []
|
425 |
+
group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
|
426 |
+
group_tags = {}
|
427 |
+
other_tags: list[str] = []
|
428 |
+
rating_tags: list[str] = []
|
429 |
+
|
430 |
+
group_dict = tag_group_dict
|
431 |
+
group_set = set(group_dict.keys())
|
432 |
+
character_set = set(anime_series_dict.keys())
|
433 |
+
series_set = set(anime_series_dict.values())
|
434 |
+
rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
|
435 |
+
|
436 |
+
for tag in tags:
|
437 |
+
tag = replace_underline(tag)
|
438 |
+
if tag in PEOPLE_TAGS:
|
439 |
+
people_tags.append(tag)
|
440 |
+
elif tag in rating_set:
|
441 |
+
rating_tags.append(tag)
|
442 |
+
elif tag in group_set:
|
443 |
+
elem = group_dict[tag]
|
444 |
+
group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
|
445 |
+
elif tag in character_set:
|
446 |
+
character_tags.append(tag)
|
447 |
+
elif tag in series_set:
|
448 |
+
series_tags.append(tag)
|
449 |
+
else:
|
450 |
+
other_tags.append(tag)
|
451 |
+
|
452 |
+
output_group_tags: list[str] = []
|
453 |
+
for k in group_list:
|
454 |
+
output_group_tags.extend(group_tags.get(k, []))
|
455 |
+
|
456 |
+
rating_tags = [rating_tags[0]] if rating_tags else []
|
457 |
+
rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
|
458 |
+
|
459 |
+
output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
|
460 |
+
|
461 |
+
return output_tags
|
462 |
+
|
463 |
+
|
464 |
+
def sort_tags(tags: str):
|
465 |
+
if not tags: return ""
|
466 |
+
taglist: list[str] = []
|
467 |
+
for tag in tags.split(","):
|
468 |
+
taglist.append(tag.strip())
|
469 |
+
taglist = list(filter(lambda x: x != "", taglist))
|
470 |
+
return ", ".join(sort_taglist(taglist))
|
471 |
+
|
472 |
+
|
473 |
+
def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
|
474 |
+
results = {
|
475 |
+
k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
|
476 |
+
}
|
477 |
+
|
478 |
+
rating = {}
|
479 |
+
character = {}
|
480 |
+
general = {}
|
481 |
+
|
482 |
+
for k, v in results.items():
|
483 |
+
if k.startswith("rating:"):
|
484 |
+
rating[k.replace("rating:", "")] = v
|
485 |
+
continue
|
486 |
+
elif k.startswith("character:"):
|
487 |
+
character[k.replace("character:", "")] = v
|
488 |
+
continue
|
489 |
+
|
490 |
+
general[k] = v
|
491 |
+
|
492 |
+
character = {k: v for k, v in character.items() if v >= character_threshold}
|
493 |
+
general = {k: v for k, v in general.items() if v >= general_threshold}
|
494 |
+
|
495 |
+
return rating, character, general
|
496 |
+
|
497 |
+
|
498 |
+
def gen_prompt(rating: list[str], character: list[str], general: list[str]):
|
499 |
+
people_tags: list[str] = []
|
500 |
+
other_tags: list[str] = []
|
501 |
+
rating_tag = RATING_MAP[rating[0]]
|
502 |
+
|
503 |
+
for tag in general:
|
504 |
+
if tag in PEOPLE_TAGS:
|
505 |
+
people_tags.append(tag)
|
506 |
+
else:
|
507 |
+
other_tags.append(tag)
|
508 |
+
|
509 |
+
all_tags = people_tags + other_tags
|
510 |
+
|
511 |
+
return ", ".join(all_tags)
|
512 |
+
|
513 |
+
|
514 |
+
@spaces.GPU(duration=30)
|
515 |
+
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
516 |
+
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
517 |
+
|
518 |
+
outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
|
519 |
+
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
520 |
+
|
521 |
+
# get probabilities
|
522 |
+
if device != default_device: wd_model.to(device=device)
|
523 |
+
results = {
|
524 |
+
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
525 |
+
}
|
526 |
+
if device != default_device: wd_model.to(device=default_device)
|
527 |
+
# rating, character, general
|
528 |
+
rating, character, general = postprocess_results(
|
529 |
+
results, general_threshold, character_threshold
|
530 |
+
)
|
531 |
+
prompt = gen_prompt(
|
532 |
+
list(rating.keys()), list(character.keys()), list(general.keys())
|
533 |
+
)
|
534 |
+
output_series_tag = ""
|
535 |
+
output_series_list = character_list_to_series_list(character.keys())
|
536 |
+
if output_series_list:
|
537 |
+
output_series_tag = output_series_list[0]
|
538 |
+
else:
|
539 |
+
output_series_tag = ""
|
540 |
+
return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
|
541 |
+
|
542 |
+
|
543 |
+
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
|
544 |
+
character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
|
545 |
+
if not "Use WD Tagger" in algo and len(algo) != 0:
|
546 |
+
return input_series, input_character, input_tags, gr.update(interactive=True)
|
547 |
+
return predict_tags(image, general_threshold, character_threshold)
|
548 |
+
|
549 |
+
|
550 |
+
def compose_prompt_to_copy(character: str, series: str, general: str):
|
551 |
+
characters = character.split(",") if character else []
|
552 |
+
serieses = series.split(",") if series else []
|
553 |
+
generals = general.split(",") if general else []
|
554 |
+
tags = characters + serieses + generals
|
555 |
+
cprompt = ",".join(tags) if tags else ""
|
556 |
+
return cprompt
|
tagger/utils.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
|
3 |
+
|
4 |
+
|
5 |
+
V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
|
6 |
+
"ultra_wide",
|
7 |
+
"wide",
|
8 |
+
"square",
|
9 |
+
"tall",
|
10 |
+
"ultra_tall",
|
11 |
+
]
|
12 |
+
V2_RATING_OPTIONS: list[RatingTag] = [
|
13 |
+
"sfw",
|
14 |
+
"general",
|
15 |
+
"sensitive",
|
16 |
+
"nsfw",
|
17 |
+
"questionable",
|
18 |
+
"explicit",
|
19 |
+
]
|
20 |
+
V2_LENGTH_OPTIONS: list[LengthTag] = [
|
21 |
+
"very_short",
|
22 |
+
"short",
|
23 |
+
"medium",
|
24 |
+
"long",
|
25 |
+
"very_long",
|
26 |
+
]
|
27 |
+
V2_IDENTITY_OPTIONS: list[IdentityTag] = [
|
28 |
+
"none",
|
29 |
+
"lax",
|
30 |
+
"strict",
|
31 |
+
]
|
32 |
+
|
33 |
+
|
34 |
+
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
|
35 |
+
def gradio_copy_text(_text: None):
|
36 |
+
gr.Info("Copied!")
|
37 |
+
|
38 |
+
|
39 |
+
COPY_ACTION_JS = """\
|
40 |
+
(inputs, _outputs) => {
|
41 |
+
// inputs is the string value of the input_text
|
42 |
+
if (inputs.trim() !== "") {
|
43 |
+
navigator.clipboard.writeText(inputs);
|
44 |
+
}
|
45 |
+
}"""
|
46 |
+
|
47 |
+
|
48 |
+
def gradio_copy_prompt(prompt: str):
|
49 |
+
gr.Info("Copied!")
|
50 |
+
return prompt
|