John6666 committed on
Commit b15c679 · 1 Parent(s): 0c7cfe4

Upload 11 files
app.py CHANGED
@@ -1,7 +1,7 @@
 import spaces
 import gradio as gr
-from utils import gradio_copy_text, COPY_ACTION_JS
-from tagger import convert_danbooru_to_e621_prompt, insert_recom_prompt
+from tagger.utils import gradio_copy_text, COPY_ACTION_JS
+from tagger.tagger import convert_danbooru_to_e621_prompt, insert_recom_prompt
 from genimage import generate_image
 from llmdolphin import (get_llm_formats, get_dolphin_model_format,
 get_dolphin_models, get_dolphin_model_info, select_dolphin_model,
@@ -59,7 +59,8 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
 recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
 generate_image_btn = gr.Button(value="GENERATE IMAGE", size="lg", variant="primary")
 with gr.Row():
-result_image = gr.Gallery(label="Generated images", columns=1, object_fit="contain", container=True, preview=True, show_label=False, show_share_button=False, show_download_button=True, interactive=False, visible=True, format="png")
+result_image = gr.Gallery(label="Generated images", columns=1, object_fit="contain", container=True, preview=True, height=512,
+show_label=False, show_share_button=False, show_download_button=True, interactive=False, visible=True, format="png")
 with gr.Tab("GGUF-Playground"):
 gr.Markdown("""# Chat with lots of Models and LLMs using llama.cpp
 This tab is copy of [CaioXapelaum/GGUF-Playground](https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground).<br>
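
The new import paths assume the `tagger/` package uploaded in this commit sits next to `app.py`. A minimal sketch of the layout implied by the upload (only `tagger/tagger.py`, `tagger/utils.py`, and the CSV files are confirmed by this commit; everything else is shown for orientation):

```python
# Layout implied by this commit (illustrative):
#   app.py
#   genimage.py
#   llmdolphin.py
#   tagger/
#       tagger.py   # convert_danbooru_to_e621_prompt, insert_recom_prompt
#       utils.py    # gradio_copy_text, COPY_ACTION_JS
from tagger.utils import gradio_copy_text, COPY_ACTION_JS
from tagger.tagger import convert_danbooru_to_e621_prompt, insert_recom_prompt
```
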
genimage.py CHANGED
@@ -1,20 +1,49 @@
 import spaces
+import torch


 def load_pipeline():
-    from diffusers import AutoPipelineForText2Image
-    import torch
+    from diffusers import DiffusionPipeline
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    pipe = AutoPipelineForText2Image.from_pretrained(
+    pipe = DiffusionPipeline.from_pretrained(
         "John6666/rae-diffusion-xl-v2-sdxl-spo-pcm",
-        #custom_pipeline="lpw_stable_diffusion_xl",
-        custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
+        custom_pipeline="lpw_stable_diffusion_xl",
+        #custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
         torch_dtype=torch.float16,
     )
     pipe.to(device)
     return pipe


+def token_auto_concat_embeds(pipe, positive, negative):
+    max_length = pipe.tokenizer.model_max_length
+    positive_length = pipe.tokenizer(positive, return_tensors="pt").input_ids.shape[-1]
+    negative_length = pipe.tokenizer(negative, return_tensors="pt").input_ids.shape[-1]
+
+    print(f'Token length is model maximum: {max_length}, positive length: {positive_length}, negative length: {negative_length}.')
+    if max_length < positive_length or max_length < negative_length:
+        print('Concatenated embedding.')
+        if positive_length > negative_length:
+            positive_ids = pipe.tokenizer(positive, return_tensors="pt").input_ids.to("cuda")
+            negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=positive_ids.shape[-1], return_tensors="pt").input_ids.to("cuda")
+        else:
+            negative_ids = pipe.tokenizer(negative, return_tensors="pt").input_ids.to("cuda")
+            positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=negative_ids.shape[-1], return_tensors="pt").input_ids.to("cuda")
+    else:
+        positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to("cuda")
+        negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to("cuda")
+
+    positive_concat_embeds = []
+    negative_concat_embeds = []
+    for i in range(0, positive_ids.shape[-1], max_length):
+        positive_concat_embeds.append(pipe.text_encoder(positive_ids[:, i: i + max_length])[0])
+        negative_concat_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])
+
+    positive_prompt_embeds = torch.cat(positive_concat_embeds, dim=1)
+    negative_prompt_embeds = torch.cat(negative_concat_embeds, dim=1)
+    return positive_prompt_embeds, negative_prompt_embeds
+
+
 def save_image(image, metadata, output_dir):
     import os
     import uuid
@@ -33,26 +62,30 @@ def save_image(image, metadata, output_dir):
 pipe = load_pipeline()


+@torch.inference_mode()
 @spaces.GPU
 def generate_image(prompt, neg_prompt):
+    prompt += ", anime, masterpiece, best quality, very aesthetic, absurdres"
+    neg_prompt += ", bad hands, bad feet, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast, photo, deformed, disfigured, low contrast"
     metadata = {
-        "prompt": prompt + ", anime, masterpiece, best quality, very aesthetic, absurdres",
-        "negative_prompt": neg_prompt + ", bad hands, bad feet, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast, photo, deformed, disfigured, low contrast",
+        "prompt": prompt,
+        "negative_prompt": neg_prompt,
         "resolution": f"{1024} x {1024}",
         "guidance_scale": 7.0,
         "num_inference_steps": 28,
         "sampler": "Euler",
     }
     try:
+        #positive_embeds, negative_embeds = token_auto_concat_embeds(pipe, prompt, neg_prompt)
         images = pipe(
-            prompt=prompt + ", anime, masterpiece, best quality, very aesthetic, absurdres",
-            negative_prompt=neg_prompt + ", bad hands, bad feet, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast, photo, deformed, disfigured, low contrast",
+            prompt=prompt,
+            negative_prompt=neg_prompt,
             width=1024,
             height=1024,
-            guidance_scale=7.0, seg_scale=3.0, seg_applied_layers=["mid"],
+            guidance_scale=7.0,# seg_scale=3.0, seg_applied_layers=["mid"],
             num_inference_steps=28,
             output_type="pil",
-            #clip_skip=1,
+            clip_skip=1,
         ).images
        if images:
            image_paths = [
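
The new `token_auto_concat_embeds` helper splits prompts that exceed the tokenizer's `model_max_length` (77 tokens for SDXL's CLIP encoder) into windows, encodes each window, and concatenates the hidden states. In this commit the call in `generate_image` stays commented out; long prompts are instead handled by the `lpw_stable_diffusion_xl` custom pipeline, which accepts plain strings, and with an SDXL pipeline the raw embeddings would also need pooled embeddings before they could replace the string prompts. A standalone sketch of what the helper returns, assuming a CUDA device since the helper moves the token ids to `"cuda"` (not part of the commit):

```python
# Illustrative check of token_auto_concat_embeds (assumes CUDA is available).
pipe = load_pipeline()
positive = "1girl, solo, cowboy shot, " * 20   # deliberately longer than one 77-token window
negative = "lowres, worst quality"
pos_embeds, neg_embeds = token_auto_concat_embeds(pipe, positive, negative)
# Both prompts are padded to the same length and encoded window by window,
# so the two tensors come back with matching shapes, e.g. [1, N, hidden_dim].
print(pos_embeds.shape, neg_embeds.shape)
```
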
llmdolphin.py CHANGED
@@ -1,5 +1,9 @@
1
  import spaces
2
  import gradio as gr
3
  from llama_cpp import Llama
4
  from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
5
  from llama_cpp_agent.providers import LlamaCppPythonProvider
@@ -7,7 +11,6 @@ from llama_cpp_agent.chat_history import BasicChatHistory
7
  from llama_cpp_agent.chat_history.messages import Roles
8
  from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
9
  import wrapt_timeout_decorator
10
- from pathlib import Path
11
  from llama_cpp_agent.messages_formatter import MessagesFormatter
12
  from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
13
 
@@ -19,6 +22,7 @@ llm_models = {
19
  #"": ["", MessagesFormatterType.OPEN_CHAT],
20
  #"": ["", MessagesFormatterType.CHATML],
21
  #"": ["", MessagesFormatterType.PHI_3],
 
22
  "mn-12b-lyra-v2a1-q5_k_m.gguf": ["HalleyStarbun/MN-12B-Lyra-v2a1-Q5_K_M-GGUF", MessagesFormatterType.CHATML],
23
  "L3-8B-Tamamo-v1.i1-Q5_K_M.gguf": ["mradermacher/L3-8B-Tamamo-v1-i1-GGUF", MessagesFormatterType.LLAMA_3],
24
  "MN-Chinofun-12B-2.i1-Q4_K_M.gguf": ["mradermacher/MN-Chinofun-12B-2-i1-GGUF", MessagesFormatterType.MISTRAL],
@@ -68,6 +72,19 @@ llm_models = {
68
  "ChatWaifu_22B_v2.0_preview.Q4_K_S.gguf": ["mradermacher/ChatWaifu_22B_v2.0_preview-GGUF", MessagesFormatterType.MISTRAL],
69
  "ChatWaifu_v1.4.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.4-GGUF", MessagesFormatterType.MISTRAL],
70
  "ChatWaifu_v1.3.1.Q4_K_M.gguf": ["mradermacher/ChatWaifu_v1.3.1-GGUF", MessagesFormatterType.MISTRAL],
71
  "hermes-llama3-roleplay-1000-v2.Q5_K_M.gguf": ["mradermacher/hermes-llama3-roleplay-1000-v2-GGUF", MessagesFormatterType.LLAMA_3],
72
  "hermes-stheno-8B-v0.1.i1-Q5_K_M.gguf": ["mradermacher/hermes-stheno-8B-v0.1-i1-GGUF", MessagesFormatterType.LLAMA_3],
73
  "qwen-carpmuscle-r-v0.3.Q4_K_M.gguf": ["mradermacher/qwen-carpmuscle-r-v0.3-GGUF", MessagesFormatterType.OPEN_CHAT],
@@ -832,6 +849,7 @@ llm_languages = ["English", "Japanese", "Chinese", "Korean", "Spanish", "Portugu
832
  llm_models_tupled_list = []
833
  default_llm_model_filename = list(llm_models.keys())[0]
834
  override_llm_format = None
 
835
 
836
 
837
  def to_list(s):
@@ -844,7 +862,6 @@ def list_uniq(l):
844
 
845
  @wrapt_timeout_decorator.timeout(dec_timeout=3.5)
846
  def to_list_ja(s):
847
- import re
848
  s = re.sub(r'[、。]', ',', s)
849
  return [x.strip() for x in s.split(",") if not s == ""]
850
 
@@ -859,7 +876,6 @@ def is_japanese(s):
859
 
860
 
861
  def update_llm_model_tupled_list():
862
- from pathlib import Path
863
  global llm_models_tupled_list
864
  llm_models_tupled_list = []
865
  for k, v in llm_models.items():
@@ -876,7 +892,6 @@ def update_llm_model_tupled_list():
876
 
877
 
878
  def download_llm_models():
879
- from huggingface_hub import hf_hub_download
880
  global llm_models_tupled_list
881
  llm_models_tupled_list = []
882
  for k, v in llm_models.items():
@@ -890,7 +905,6 @@ def download_llm_models():
890
 
891
 
892
  def download_llm_model(filename):
893
- from huggingface_hub import hf_hub_download
894
  if not filename in llm_models.keys(): return default_llm_model_filename
895
  try:
896
  hf_hub_download(repo_id = llm_models[filename][0], filename = filename, local_dir = llm_models_dir)
@@ -951,8 +965,6 @@ def get_dolphin_model_format(filename):
951
 
952
 
953
  def add_dolphin_models(query, format_name):
954
- import re
955
- from huggingface_hub import HfApi
956
  global llm_models
957
  api = HfApi()
958
  add_models = {}
@@ -964,20 +976,19 @@ def add_dolphin_models(query, format_name):
964
  if s and "" in s: s.remove("")
965
  if len(s) == 1:
966
  repo = s[0]
967
- if not api.repo_exists(repo_id = repo): return gr.update(visible=True)
968
  files = api.list_repo_files(repo_id = repo)
969
  for file in files:
970
  if str(file).endswith(".gguf"): add_models[filename] = [repo, format]
971
  elif len(s) >= 2:
972
  repo = s[0]
973
  filename = s[1]
974
- if not api.repo_exists(repo_id = repo) or not api.file_exists(repo_id = repo, filename = filename): return gr.update(visible=True)
975
  add_models[filename] = [repo, format]
976
- else: return gr.update(visible=True)
977
  except Exception as e:
978
  print(e)
979
- return gr.update(visible=True)
980
- #print(add_models)
981
  llm_models = (llm_models | add_models).copy()
982
  update_llm_model_tupled_list()
983
  choices = get_dolphin_models()
@@ -1177,7 +1188,6 @@ Output should be enclosed in //GENBEGIN//:// and //://GENEND//. The text to be g
1177
 
1178
 
1179
  def get_dolphin_sysprompt():
1180
- import re
1181
  prompt = re.sub('<LANGUAGE>', dolphin_output_language, dolphin_system_prompt.get(dolphin_sysprompt_mode, ""))
1182
  return prompt
1183
 
@@ -1207,11 +1217,11 @@ def select_dolphin_language(lang: str):
1207
 
1208
  @wrapt_timeout_decorator.timeout(dec_timeout=5.0)
1209
  def get_raw_prompt(msg: str):
1210
- import re
1211
  m = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
1212
  return re.sub(r'[*/:_"#\n]', ' ', ", ".join(m)).lower() if m else ""
1213
 
1214
 
 
1215
  @spaces.GPU(duration=60)
1216
  def dolphin_respond(
1217
  message: str,
@@ -1225,87 +1235,92 @@ def dolphin_respond(
1225
  repeat_penalty: float = 1.1,
1226
  progress=gr.Progress(track_tqdm=True),
1227
  ):
1228
- from pathlib import Path
1229
- progress(0, desc="Processing...")
1230
 
1231
- if override_llm_format:
1232
- chat_template = override_llm_format
1233
- else:
1234
- chat_template = llm_models[model][1]
1235
-
1236
- llm = Llama(
1237
- model_path=str(Path(f"{llm_models_dir}/{model}")),
1238
- flash_attn=True,
1239
- n_gpu_layers=81, # 81
1240
- n_batch=1024,
1241
- n_ctx=8192, #8192
1242
- )
1243
- provider = LlamaCppPythonProvider(llm)
1244
-
1245
- agent = LlamaCppAgent(
1246
- provider,
1247
- system_prompt=f"{system_message}",
1248
- predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1249
- custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1250
- debug_output=False
1251
- )
1252
-
1253
- settings = provider.get_provider_default_settings()
1254
- settings.temperature = temperature
1255
- settings.top_k = top_k
1256
- settings.top_p = top_p
1257
- settings.max_tokens = max_tokens
1258
- settings.repeat_penalty = repeat_penalty
1259
- settings.stream = True
1260
-
1261
- messages = BasicChatHistory()
1262
-
1263
- for msn in history:
1264
- user = {
1265
- 'role': Roles.user,
1266
- 'content': msn[0]
1267
- }
1268
- assistant = {
1269
- 'role': Roles.assistant,
1270
- 'content': msn[1]
1271
- }
1272
- messages.add_message(user)
1273
- messages.add_message(assistant)
1274
-
1275
- stream = agent.get_chat_response(
1276
- message,
1277
- llm_sampling_settings=settings,
1278
- chat_history=messages,
1279
- returns_streaming_generator=True,
1280
- print_output=False
1281
- )
1282
-
1283
- progress(0.5, desc="Processing...")
1284
-
1285
- outputs = ""
1286
- for output in stream:
1287
- outputs += output
1288
- yield [(outputs, None)]
1289
 
1290
 
1291
  def dolphin_parse(
1292
  history: list[tuple[str, str]],
1293
  ):
1294
- if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1:
1295
- return "", gr.update(visible=True), gr.update(visible=True)
1296
  try:
 
 
1297
  msg = history[-1][0]
1298
  raw_prompt = get_raw_prompt(msg)
1299
- except Exception:
1300
- return "", gr.update(visible=True), gr.update(visible=True)
1301
- prompts = []
1302
- if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1303
- prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit"])
1304
- else:
1305
- prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit"])
1306
- return ", ".join(prompts), gr.update(interactive=True), gr.update(interactive=True)
 
1307
 
1308
 
 
1309
  @spaces.GPU(duration=60)
1310
  def dolphin_respond_auto(
1311
  message: str,
@@ -1319,94 +1334,100 @@ def dolphin_respond_auto(
1319
  repeat_penalty: float = 1.1,
1320
  progress=gr.Progress(track_tqdm=True),
1321
  ):
1322
- #if not is_japanese(message): return [(None, None)]
1323
- from pathlib import Path
1324
- progress(0, desc="Processing...")
1325
 
1326
- if override_llm_format:
1327
- chat_template = override_llm_format
1328
- else:
1329
- chat_template = llm_models[model][1]
1330
-
1331
- llm = Llama(
1332
- model_path=str(Path(f"{llm_models_dir}/{model}")),
1333
- flash_attn=True,
1334
- n_gpu_layers=81, # 81
1335
- n_batch=1024,
1336
- n_ctx=8192, #8192
1337
- )
1338
- provider = LlamaCppPythonProvider(llm)
1339
-
1340
- agent = LlamaCppAgent(
1341
- provider,
1342
- system_prompt=f"{system_message}",
1343
- predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1344
- custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1345
- debug_output=False
1346
- )
1347
-
1348
- settings = provider.get_provider_default_settings()
1349
- settings.temperature = temperature
1350
- settings.top_k = top_k
1351
- settings.top_p = top_p
1352
- settings.max_tokens = max_tokens
1353
- settings.repeat_penalty = repeat_penalty
1354
- settings.stream = True
1355
-
1356
- messages = BasicChatHistory()
1357
-
1358
- for msn in history:
1359
- user = {
1360
- 'role': Roles.user,
1361
- 'content': msn[0]
1362
- }
1363
- assistant = {
1364
- 'role': Roles.assistant,
1365
- 'content': msn[1]
1366
- }
1367
- messages.add_message(user)
1368
- messages.add_message(assistant)
1369
-
1370
- progress(0, desc="Translating...")
1371
- stream = agent.get_chat_response(
1372
- message,
1373
- llm_sampling_settings=settings,
1374
- chat_history=messages,
1375
- returns_streaming_generator=True,
1376
- print_output=False
1377
- )
1378
-
1379
- progress(0.5, desc="Processing...")
1380
-
1381
- outputs = ""
1382
- for output in stream:
1383
- outputs += output
1384
- yield [(outputs, None)]
1385
 
1386
 
1387
  def dolphin_parse_simple(
1388
  message: str,
1389
  history: list[tuple[str, str]],
1390
  ):
1391
- #if not is_japanese(message): return message
1392
- if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1: return message
1393
  try:
 
 
1394
  msg = history[-1][0]
1395
  raw_prompt = get_raw_prompt(msg)
1396
- except Exception:
 
 
 
 
 
 
 
1397
  return ""
1398
- prompts = []
1399
- if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1400
- prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit", "rating_explicit"])
1401
- else:
1402
- prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit", "rating_explicit"])
1403
- return ", ".join(prompts)
1404
 
1405
 
1406
  # https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
1407
  import cv2
1408
  cv2.setNumThreads(1)
1409
 
 
 
1410
  @spaces.GPU()
1411
  def respond_playground(
1412
  message,
@@ -1419,47 +1440,47 @@ def respond_playground(
1419
  top_k,
1420
  repeat_penalty,
1421
  ):
1422
- if override_llm_format:
1423
- chat_template = override_llm_format
1424
- else:
1425
- chat_template = llm_models[model][1]
1426
-
1427
- llm = Llama(
1428
- model_path=str(Path(f"{llm_models_dir}/{model}")),
1429
- flash_attn=True,
1430
- n_gpu_layers=81, # 81
1431
- n_batch=1024,
1432
- n_ctx=8192, #8192
1433
- )
1434
- provider = LlamaCppPythonProvider(llm)
1435
-
1436
- agent = LlamaCppAgent(
1437
- provider,
1438
- system_prompt=f"{system_message}",
1439
- predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1440
- custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1441
- debug_output=False
1442
- )
1443
-
1444
- settings = provider.get_provider_default_settings()
1445
- settings.temperature = temperature
1446
- settings.top_k = top_k
1447
- settings.top_p = top_p
1448
- settings.max_tokens = max_tokens
1449
- settings.repeat_penalty = repeat_penalty
1450
- settings.stream = True
1451
-
1452
- messages = BasicChatHistory()
1453
-
1454
- # Add user and assistant messages to the history
1455
- for msn in history:
1456
- user = {'role': Roles.user, 'content': msn[0]}
1457
- assistant = {'role': Roles.assistant, 'content': msn[1]}
1458
- messages.add_message(user)
1459
- messages.add_message(assistant)
1460
-
1461
- # Stream the response
1462
  try:
 
1463
  stream = agent.get_chat_response(
1464
  message,
1465
  llm_sampling_settings=settings,
@@ -1473,4 +1494,5 @@ def respond_playground(
1473
  outputs += output
1474
  yield outputs
1475
  except Exception as e:
1476
- yield f"Error during response generation: {str(e)}"
 
 
1
  import spaces
2
  import gradio as gr
3
+ from pathlib import Path
4
+ import re
5
+ import torch
6
+ from huggingface_hub import hf_hub_download, HfApi
7
  from llama_cpp import Llama
8
  from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
9
  from llama_cpp_agent.providers import LlamaCppPythonProvider
 
11
  from llama_cpp_agent.chat_history.messages import Roles
12
  from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
13
  import wrapt_timeout_decorator
 
14
  from llama_cpp_agent.messages_formatter import MessagesFormatter
15
  from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
16
 
 
22
  #"": ["", MessagesFormatterType.OPEN_CHAT],
23
  #"": ["", MessagesFormatterType.CHATML],
24
  #"": ["", MessagesFormatterType.PHI_3],
25
+ #"": ["", MessagesFormatterType.GEMMA_2],
26
  "mn-12b-lyra-v2a1-q5_k_m.gguf": ["HalleyStarbun/MN-12B-Lyra-v2a1-Q5_K_M-GGUF", MessagesFormatterType.CHATML],
27
  "L3-8B-Tamamo-v1.i1-Q5_K_M.gguf": ["mradermacher/L3-8B-Tamamo-v1-i1-GGUF", MessagesFormatterType.LLAMA_3],
28
  "MN-Chinofun-12B-2.i1-Q4_K_M.gguf": ["mradermacher/MN-Chinofun-12B-2-i1-GGUF", MessagesFormatterType.MISTRAL],
 
72
  "ChatWaifu_22B_v2.0_preview.Q4_K_S.gguf": ["mradermacher/ChatWaifu_22B_v2.0_preview-GGUF", MessagesFormatterType.MISTRAL],
73
  "ChatWaifu_v1.4.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.4-GGUF", MessagesFormatterType.MISTRAL],
74
  "ChatWaifu_v1.3.1.Q4_K_M.gguf": ["mradermacher/ChatWaifu_v1.3.1-GGUF", MessagesFormatterType.MISTRAL],
75
+ "Magnum_Dark_Madness_12b.Q4_K_S.gguf": ["mradermacher/Magnum_Dark_Madness_12b-GGUF", MessagesFormatterType.MISTRAL],
76
+ "Magnum_Lyra_Darkness_12b.Q4_K_M.gguf": ["mradermacher/Magnum_Lyra_Darkness_12b-GGUF", MessagesFormatterType.MISTRAL],
77
+ "Heart_Stolen-8B-task.i1-Q4_K_M.gguf": ["mradermacher/Heart_Stolen-8B-task-i1-GGUF", MessagesFormatterType.LLAMA_3],
78
+ "Magnum_Backyard_Party_12b.Q4_K_M.gguf": ["mradermacher/Magnum_Backyard_Party_12b-GGUF", MessagesFormatterType.MISTRAL],
79
+ "Magnum_Madness-12b.Q4_K_M.gguf": ["mradermacher/Magnum_Madness-12b-GGUF", MessagesFormatterType.MISTRAL],
80
+ "L3.1-Moe-2x8B-v0.2.i1-Q4_K_M.gguf": ["mradermacher/L3.1-Moe-2x8B-v0.2-i1-GGUF", MessagesFormatterType.LLAMA_3],
81
+ "Qwen2.5-14B-Wernicke-DPO.i1-Q4_K_M.gguf": ["mradermacher/Qwen2.5-14B-Wernicke-DPO-i1-GGUF", MessagesFormatterType.OPEN_CHAT],
82
+ "Gemma-2-Ataraxy-v4d-9B.i1-Q4_K_M.gguf": ["mradermacher/Gemma-2-Ataraxy-v4d-9B-i1-GGUF", MessagesFormatterType.GEMMA_2],
83
+ "qwen2.5-14b-megamerge-pt2-q5_k_m.gguf": ["CultriX/Qwen2.5-14B-MegaMerge-pt2-Q5_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
84
+ "quantqwen2-merged-16bit-q4_k_m.gguf": ["davidbzyk/QuantQwen2-merged-16bit-Q4_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
85
+ "Mistral-nemo-ja-rp-v0.2-Q4_K_S.gguf": ["ascktgcc/Mistral-nemo-ja-rp-v0.2-GGUF", MessagesFormatterType.MISTRAL],
86
+ "llama3.1-darkstorm-aspire-8b-q4_k_m.gguf": ["ZeroXClem/Llama3.1-DarkStorm-Aspire-8B-Q4_K_M-GGUF", MessagesFormatterType.LLAMA_3],
87
+ "llama-3-yggdrasil-astralspice-8b-q4_k_m.gguf": ["ZeroXClem/Llama-3-Yggdrasil-AstralSpice-8B-Q4_K_M-GGUF", MessagesFormatterType.LLAMA_3],
88
  "hermes-llama3-roleplay-1000-v2.Q5_K_M.gguf": ["mradermacher/hermes-llama3-roleplay-1000-v2-GGUF", MessagesFormatterType.LLAMA_3],
89
  "hermes-stheno-8B-v0.1.i1-Q5_K_M.gguf": ["mradermacher/hermes-stheno-8B-v0.1-i1-GGUF", MessagesFormatterType.LLAMA_3],
90
  "qwen-carpmuscle-r-v0.3.Q4_K_M.gguf": ["mradermacher/qwen-carpmuscle-r-v0.3-GGUF", MessagesFormatterType.OPEN_CHAT],
 
849
  llm_models_tupled_list = []
850
  default_llm_model_filename = list(llm_models.keys())[0]
851
  override_llm_format = None
852
+ device = "cuda" if torch.cuda.is_available() else "cpu"
853
 
854
 
855
  def to_list(s):
 
862
 
863
  @wrapt_timeout_decorator.timeout(dec_timeout=3.5)
864
  def to_list_ja(s):
 
865
  s = re.sub(r'[、。]', ',', s)
866
  return [x.strip() for x in s.split(",") if not s == ""]
867
 
 
876
 
877
 
878
  def update_llm_model_tupled_list():
 
879
  global llm_models_tupled_list
880
  llm_models_tupled_list = []
881
  for k, v in llm_models.items():
 
892
 
893
 
894
  def download_llm_models():
 
895
  global llm_models_tupled_list
896
  llm_models_tupled_list = []
897
  for k, v in llm_models.items():
 
905
 
906
 
907
  def download_llm_model(filename):
 
908
  if not filename in llm_models.keys(): return default_llm_model_filename
909
  try:
910
  hf_hub_download(repo_id = llm_models[filename][0], filename = filename, local_dir = llm_models_dir)
 
965
 
966
 
967
  def add_dolphin_models(query, format_name):
 
 
968
  global llm_models
969
  api = HfApi()
970
  add_models = {}
 
976
  if s and "" in s: s.remove("")
977
  if len(s) == 1:
978
  repo = s[0]
979
+ if not api.repo_exists(repo_id = repo): return gr.update()
980
  files = api.list_repo_files(repo_id = repo)
981
  for file in files:
982
  if str(file).endswith(".gguf"): add_models[filename] = [repo, format]
983
  elif len(s) >= 2:
984
  repo = s[0]
985
  filename = s[1]
986
+ if not api.repo_exists(repo_id = repo) or not api.file_exists(repo_id = repo, filename = filename): return gr.update()
987
  add_models[filename] = [repo, format]
988
+ else: return gr.update()
989
  except Exception as e:
990
  print(e)
991
+ return gr.update()
 
992
  llm_models = (llm_models | add_models).copy()
993
  update_llm_model_tupled_list()
994
  choices = get_dolphin_models()
 
1188
 
1189
 
1190
  def get_dolphin_sysprompt():
 
1191
  prompt = re.sub('<LANGUAGE>', dolphin_output_language, dolphin_system_prompt.get(dolphin_sysprompt_mode, ""))
1192
  return prompt
1193
 
 
1217
 
1218
  @wrapt_timeout_decorator.timeout(dec_timeout=5.0)
1219
  def get_raw_prompt(msg: str):
 
1220
  m = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
1221
  return re.sub(r'[*/:_"#\n]', ' ', ", ".join(m)).lower() if m else ""
1222
 
1223
 
1224
+ @torch.inference_mode()
1225
  @spaces.GPU(duration=60)
1226
  def dolphin_respond(
1227
  message: str,
 
1235
  repeat_penalty: float = 1.1,
1236
  progress=gr.Progress(track_tqdm=True),
1237
  ):
1238
+ try:
1239
+ progress(0, desc="Processing...")
1240
+
1241
+ if override_llm_format:
1242
+ chat_template = override_llm_format
1243
+ else:
1244
+ chat_template = llm_models[model][1]
1245
+
1246
+ llm = Llama(
1247
+ model_path=str(Path(f"{llm_models_dir}/{model}")),
1248
+ flash_attn=True,
1249
+ n_gpu_layers=81, # 81
1250
+ n_batch=1024,
1251
+ n_ctx=8192, #8192
1252
+ )
1253
+ provider = LlamaCppPythonProvider(llm)
1254
+
1255
+ agent = LlamaCppAgent(
1256
+ provider,
1257
+ system_prompt=f"{system_message}",
1258
+ predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1259
+ custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1260
+ debug_output=False
1261
+ )
1262
+
1263
+ settings = provider.get_provider_default_settings()
1264
+ settings.temperature = temperature
1265
+ settings.top_k = top_k
1266
+ settings.top_p = top_p
1267
+ settings.max_tokens = max_tokens
1268
+ settings.repeat_penalty = repeat_penalty
1269
+ settings.stream = True
1270
+
1271
+ messages = BasicChatHistory()
1272
+
1273
+ for msn in history:
1274
+ user = {
1275
+ 'role': Roles.user,
1276
+ 'content': msn[0]
1277
+ }
1278
+ assistant = {
1279
+ 'role': Roles.assistant,
1280
+ 'content': msn[1]
1281
+ }
1282
+ messages.add_message(user)
1283
+ messages.add_message(assistant)
1284
+
1285
+ stream = agent.get_chat_response(
1286
+ message,
1287
+ llm_sampling_settings=settings,
1288
+ chat_history=messages,
1289
+ returns_streaming_generator=True,
1290
+ print_output=False
1291
+ )
1292
+
1293
+ progress(0.5, desc="Processing...")
1294
 
1295
+ outputs = ""
1296
+ for output in stream:
1297
+ outputs += output
1298
+ yield [(outputs, None)]
1299
+ except Exception as e:
1300
+ print(e)
1301
+ yield [("", None)]
1302
 
1303
 
1304
  def dolphin_parse(
1305
  history: list[tuple[str, str]],
1306
  ):
 
 
1307
  try:
1308
+ if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1:
1309
+ return "", gr.update(), gr.update()
1310
  msg = history[-1][0]
1311
  raw_prompt = get_raw_prompt(msg)
1312
+ prompts = []
1313
+ if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1314
+ prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit"])
1315
+ else:
1316
+ prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit"])
1317
+ return ", ".join(prompts), gr.update(interactive=True), gr.update(interactive=True)
1318
+ except Exception as e:
1319
+ print(e)
1320
+ return "", gr.update(), gr.update()
1321
 
1322
 
1323
+ @torch.inference_mode()
1324
  @spaces.GPU(duration=60)
1325
  def dolphin_respond_auto(
1326
  message: str,
 
1334
  repeat_penalty: float = 1.1,
1335
  progress=gr.Progress(track_tqdm=True),
1336
  ):
1337
+ try:
1338
+ #if not is_japanese(message): return [(None, None)]
1339
+ progress(0, desc="Processing...")
1340
+
1341
+ if override_llm_format:
1342
+ chat_template = override_llm_format
1343
+ else:
1344
+ chat_template = llm_models[model][1]
1345
+
1346
+ llm = Llama(
1347
+ model_path=str(Path(f"{llm_models_dir}/{model}")),
1348
+ flash_attn=True,
1349
+ n_gpu_layers=81, # 81
1350
+ n_batch=1024,
1351
+ n_ctx=8192, #8192
1352
+ )
1353
+ provider = LlamaCppPythonProvider(llm)
1354
+
1355
+ agent = LlamaCppAgent(
1356
+ provider,
1357
+ system_prompt=f"{system_message}",
1358
+ predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1359
+ custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1360
+ debug_output=False
1361
+ )
1362
+
1363
+ settings = provider.get_provider_default_settings()
1364
+ settings.temperature = temperature
1365
+ settings.top_k = top_k
1366
+ settings.top_p = top_p
1367
+ settings.max_tokens = max_tokens
1368
+ settings.repeat_penalty = repeat_penalty
1369
+ settings.stream = True
1370
+
1371
+ messages = BasicChatHistory()
1372
+
1373
+ for msn in history:
1374
+ user = {
1375
+ 'role': Roles.user,
1376
+ 'content': msn[0]
1377
+ }
1378
+ assistant = {
1379
+ 'role': Roles.assistant,
1380
+ 'content': msn[1]
1381
+ }
1382
+ messages.add_message(user)
1383
+ messages.add_message(assistant)
1384
+
1385
+ progress(0, desc="Translating...")
1386
+ stream = agent.get_chat_response(
1387
+ message,
1388
+ llm_sampling_settings=settings,
1389
+ chat_history=messages,
1390
+ returns_streaming_generator=True,
1391
+ print_output=False
1392
+ )
1393
 
1394
+ progress(0.5, desc="Processing...")
1395
+
1396
+ outputs = ""
1397
+ for output in stream:
1398
+ outputs += output
1399
+ yield [(outputs, None)], gr.update(), gr.update()
1400
+ except Exception as e:
1401
+ print(e)
1402
+ yield [("", None)], gr.update(), gr.update()
1403
 
1404
 
1405
  def dolphin_parse_simple(
1406
  message: str,
1407
  history: list[tuple[str, str]],
1408
  ):
 
 
1409
  try:
1410
+ #if not is_japanese(message): return message
1411
+ if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1: return message
1412
  msg = history[-1][0]
1413
  raw_prompt = get_raw_prompt(msg)
1414
+ prompts = []
1415
+ if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1416
+ prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit", "rating_explicit"])
1417
+ else:
1418
+ prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit", "rating_explicit"])
1419
+ return ", ".join(prompts)
1420
+ except Exception as e:
1421
+ print(e)
1422
  return ""
 
 
 
 
 
 
1423
 
1424
 
1425
  # https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
1426
  import cv2
1427
  cv2.setNumThreads(1)
1428
 
1429
+
1430
+ @torch.inference_mode()
1431
  @spaces.GPU()
1432
  def respond_playground(
1433
  message,
 
1440
  top_k,
1441
  repeat_penalty,
1442
  ):
1443
  try:
1444
+ if override_llm_format:
1445
+ chat_template = override_llm_format
1446
+ else:
1447
+ chat_template = llm_models[model][1]
1448
+
1449
+ llm = Llama(
1450
+ model_path=str(Path(f"{llm_models_dir}/{model}")),
1451
+ flash_attn=True,
1452
+ n_gpu_layers=81, # 81
1453
+ n_batch=1024,
1454
+ n_ctx=8192, #8192
1455
+ )
1456
+ provider = LlamaCppPythonProvider(llm)
1457
+
1458
+ agent = LlamaCppAgent(
1459
+ provider,
1460
+ system_prompt=f"{system_message}",
1461
+ predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1462
+ custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1463
+ debug_output=False
1464
+ )
1465
+
1466
+ settings = provider.get_provider_default_settings()
1467
+ settings.temperature = temperature
1468
+ settings.top_k = top_k
1469
+ settings.top_p = top_p
1470
+ settings.max_tokens = max_tokens
1471
+ settings.repeat_penalty = repeat_penalty
1472
+ settings.stream = True
1473
+
1474
+ messages = BasicChatHistory()
1475
+
1476
+ # Add user and assistant messages to the history
1477
+ for msn in history:
1478
+ user = {'role': Roles.user, 'content': msn[0]}
1479
+ assistant = {'role': Roles.assistant, 'content': msn[1]}
1480
+ messages.add_message(user)
1481
+ messages.add_message(assistant)
1482
+
1483
+ # Stream the response
1484
  stream = agent.get_chat_response(
1485
  message,
1486
  llm_sampling_settings=settings,
 
1494
  outputs += output
1495
  yield outputs
1496
  except Exception as e:
1497
+ print(e)
1498
+ yield ""
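
For context, the registry helpers above are driven from `app.py`. A rough usage sketch of `add_dolphin_models`, based on the two branches visible in the hunk (a single token is treated as a repo id whose `.gguf` files are scanned, two tokens as repo id plus GGUF filename; the exact `format_name` labels accepted are an assumption, they map onto `MessagesFormatterType` entries):

```python
# Hypothetical calls mirroring the len(s) == 1 / len(s) >= 2 branches above;
# the "Mistral" label and the select_dolphin_model follow-up are assumptions.
add_dolphin_models("mradermacher/ChatWaifu_v1.4-GGUF", "Mistral")
add_dolphin_models("mradermacher/ChatWaifu_v1.4-GGUF ChatWaifu_v1.4.Q5_K_M.gguf", "Mistral")
select_dolphin_model("ChatWaifu_v1.4.Q5_K_M.gguf")
```
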
requirements.txt CHANGED
@@ -10,10 +10,9 @@ accelerate
 transformers
 optimum[onnxruntime]
 dartrs
-httpx==0.13.3
-httpcore
-googletrans==4.0.0rc1
-git+https://github.com/huggingface/diffusers
+translatepy
+diffusers
 rapidfuzz
 wrapt-timeout-decorator
-opencv-python
+opencv-python
+numpy<2
tagger/character_series_dict.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/danbooru_e621.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/tag_group.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/tagger.py ADDED
@@ -0,0 +1,556 @@
1
+ import spaces
2
+ from PIL import Image
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
+ from pathlib import Path
7
+
8
+
9
+ WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
10
+ WD_MODEL_NAME = WD_MODEL_NAMES[0]
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ default_device = device
14
+
15
+ try:
16
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
17
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
18
+ except Exception as e:
19
+ print(e)
20
+ wd_model = wd_processor = None
21
+
22
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
23
+ return (
24
+ [f"1{noun}"]
25
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
26
+ + [f"{maximum+1}+{noun}s"]
27
+ )
28
+
29
+
30
+ PEOPLE_TAGS = (
31
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
32
+ )
33
+
34
+
35
+ RATING_MAP = {
36
+ "sfw": "safe",
37
+ "general": "safe",
38
+ "sensitive": "sensitive",
39
+ "questionable": "nsfw",
40
+ "explicit": "explicit, nsfw",
41
+ }
42
+ DANBOORU_TO_E621_RATING_MAP = {
43
+ "sfw": "rating_safe",
44
+ "general": "rating_safe",
45
+ "safe": "rating_safe",
46
+ "sensitive": "rating_safe",
47
+ "nsfw": "rating_explicit",
48
+ "explicit, nsfw": "rating_explicit",
49
+ "explicit": "rating_explicit",
50
+ "rating:safe": "rating_safe",
51
+ "rating:general": "rating_safe",
52
+ "rating:sensitive": "rating_safe",
53
+ "rating:questionable, nsfw": "rating_explicit",
54
+ "rating:explicit, nsfw": "rating_explicit",
55
+ }
56
+
57
+
58
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
59
+ kaomojis = [
60
+ "0_0",
61
+ "(o)_(o)",
62
+ "+_+",
63
+ "+_-",
64
+ "._.",
65
+ "<o>_<o>",
66
+ "<|>_<|>",
67
+ "=_=",
68
+ ">_<",
69
+ "3_3",
70
+ "6_9",
71
+ ">_o",
72
+ "@_@",
73
+ "^_^",
74
+ "o_o",
75
+ "u_u",
76
+ "x_x",
77
+ "|_|",
78
+ "||_||",
79
+ ]
80
+
81
+
82
+ def replace_underline(x: str):
83
+ return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
84
+
85
+
86
+ def to_list(s):
87
+ return [x.strip() for x in s.split(",") if not s == ""]
88
+
89
+
90
+ def list_sub(a, b):
91
+ return [e for e in a if e not in b]
92
+
93
+
94
+ def list_uniq(l):
95
+ return sorted(set(l), key=l.index)
96
+
97
+
98
+ def load_dict_from_csv(filename):
99
+ dict = {}
100
+ if not Path(filename).exists():
101
+ if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
102
+ else: return dict
103
+ try:
104
+ with open(filename, 'r', encoding="utf-8") as f:
105
+ lines = f.readlines()
106
+ except Exception:
107
+ print(f"Failed to open dictionary file: {filename}")
108
+ return dict
109
+ for line in lines:
110
+ parts = line.strip().split(',')
111
+ dict[parts[0]] = parts[1]
112
+ return dict
113
+
114
+
115
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
116
+
117
+
118
+ def character_list_to_series_list(character_list):
119
+ output_series_tag = []
120
+ series_tag = ""
121
+ series_dict = anime_series_dict
122
+ for tag in character_list:
123
+ series_tag = series_dict.get(tag, "")
124
+ if tag.endswith(")"):
125
+ tags = tag.split("(")
126
+ character_tag = "(".join(tags[:-1])
127
+ if character_tag.endswith(" "):
128
+ character_tag = character_tag[:-1]
129
+ series_tag = tags[-1].replace(")", "")
130
+
131
+ if series_tag:
132
+ output_series_tag.append(series_tag)
133
+
134
+ return output_series_tag
135
+
136
+
137
+ def select_random_character(series: str, character: str):
138
+ from random import seed, randrange
139
+ seed()
140
+ character_list = list(anime_series_dict.keys())
141
+ character = character_list[randrange(len(character_list) - 1)]
142
+ series = anime_series_dict.get(character.split(",")[0].strip(), "")
143
+ return series, character
144
+
145
+
146
+ def danbooru_to_e621(dtag, e621_dict):
147
+ def d_to_e(match, e621_dict):
148
+ dtag = match.group(0)
149
+ etag = e621_dict.get(replace_underline(dtag), "")
150
+ if etag:
151
+ return etag
152
+ else:
153
+ return dtag
154
+
155
+ import re
156
+ tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
157
+ return tag
158
+
159
+
160
+ danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
161
+
162
+
163
+ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
164
+ if prompt_type == "danbooru": return input_prompt
165
+ tags = input_prompt.split(",") if input_prompt else []
166
+ people_tags: list[str] = []
167
+ other_tags: list[str] = []
168
+ rating_tags: list[str] = []
169
+
170
+ e621_dict = danbooru_to_e621_dict
171
+ for tag in tags:
172
+ tag = replace_underline(tag)
173
+ tag = danbooru_to_e621(tag, e621_dict)
174
+ if tag in PEOPLE_TAGS:
175
+ people_tags.append(tag)
176
+ elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
177
+ rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag.replace(" ",""), ""))
178
+ else:
179
+ other_tags.append(tag)
180
+
181
+ rating_tags = sorted(set(rating_tags), key=rating_tags.index)
182
+ rating_tags = [rating_tags[0]] if rating_tags else []
183
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
184
+
185
+ output_prompt = ", ".join(people_tags + other_tags + rating_tags)
186
+
187
+ return output_prompt
188
+
189
+
190
+ from translatepy import Translator
191
+ translator = Translator()
192
+ def translate_prompt_old(prompt: str = ""):
193
+ def translate_to_english(input: str):
194
+ try:
195
+ output = str(translator.translate(input, 'English'))
196
+ except Exception as e:
197
+ output = input
198
+ print(e)
199
+ return output
200
+
201
+ def is_japanese(s):
202
+ import unicodedata
203
+ for ch in s:
204
+ name = unicodedata.name(ch, "")
205
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
206
+ return True
207
+ return False
208
+
209
+ def to_list(s):
210
+ return [x.strip() for x in s.split(",")]
211
+
212
+ prompts = to_list(prompt)
213
+ outputs = []
214
+ for p in prompts:
215
+ p = translate_to_english(p) if is_japanese(p) else p
216
+ outputs.append(p)
217
+
218
+ return ", ".join(outputs)
219
+
220
+
221
+ def translate_prompt(input: str):
222
+ try:
223
+ output = str(translator.translate(input, 'English'))
224
+ except Exception as e:
225
+ output = input
226
+ print(e)
227
+ return output
228
+
229
+
230
+ def translate_prompt_to_ja(prompt: str = ""):
231
+ def translate_to_japanese(input: str):
232
+ try:
233
+ output = str(translator.translate(input, 'Japanese'))
234
+ except Exception as e:
235
+ output = input
236
+ print(e)
237
+ return output
238
+
239
+ def is_japanese(s):
240
+ import unicodedata
241
+ for ch in s:
242
+ name = unicodedata.name(ch, "")
243
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
244
+ return True
245
+ return False
246
+
247
+ def to_list(s):
248
+ return [x.strip() for x in s.split(",")]
249
+
250
+ prompts = to_list(prompt)
251
+ outputs = []
252
+ for p in prompts:
253
+ p = translate_to_japanese(p) if not is_japanese(p) else p
254
+ outputs.append(p)
255
+
256
+ return ", ".join(outputs)
257
+
258
+
259
+ def tags_to_ja(itag, dict):
260
+ def t_to_j(match, dict):
261
+ tag = match.group(0)
262
+ ja = dict.get(replace_underline(tag), "")
263
+ if ja:
264
+ return ja
265
+ else:
266
+ return tag
267
+
268
+ import re
269
+ tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dict), itag, 2)
270
+
271
+ return tag
272
+
273
+
274
+ def convert_tags_to_ja(input_prompt: str = ""):
275
+ tags = input_prompt.split(",") if input_prompt else []
276
+ out_tags = []
277
+
278
+ tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
279
+ dict = tags_to_ja_dict
280
+ for tag in tags:
281
+ tag = replace_underline(tag)
282
+ tag = tags_to_ja(tag, dict)
283
+ out_tags.append(tag)
284
+
285
+ return ", ".join(out_tags)
286
+
287
+
288
+ enable_auto_recom_prompt = True
289
+
290
+
291
+ animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
292
+ animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
293
+ pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
294
+ pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
295
+ other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
296
+ other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
297
+ default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
298
+ default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
299
+ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
300
+ global enable_auto_recom_prompt
301
+ prompts = to_list(prompt)
302
+ neg_prompts = to_list(neg_prompt)
303
+
304
+ prompts = list_sub(prompts, animagine_ps + pony_ps)
305
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
306
+
307
+ last_empty_p = [""] if not prompts and type != "None" else []
308
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
309
+
310
+ if type == "Auto":
311
+ enable_auto_recom_prompt = True
312
+ else:
313
+ enable_auto_recom_prompt = False
314
+ if type == "Animagine":
315
+ prompts = prompts + animagine_ps
316
+ neg_prompts = neg_prompts + animagine_nps
317
+ elif type == "Pony":
318
+ prompts = prompts + pony_ps
319
+ neg_prompts = neg_prompts + pony_nps
320
+
321
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
322
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
323
+
324
+ return prompt, neg_prompt
325
+
326
+
327
+ def load_model_prompt_dict():
328
+ import json
329
+ dict = {}
330
+ path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
331
+ try:
332
+ with open('model_dict.json', encoding='utf-8') as f:
333
+ dict = json.load(f)
334
+ except Exception:
335
+ pass
336
+ return dict
337
+
338
+
339
+ model_prompt_dict = load_model_prompt_dict()
340
+
341
+
342
+ def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
343
+ if not model_name or not enable_auto_recom_prompt: return prompt, neg_prompt
344
+ prompts = to_list(prompt)
345
+ neg_prompts = to_list(neg_prompt)
346
+ prompts = list_sub(prompts, animagine_ps + pony_ps + other_ps)
347
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + other_nps)
348
+ last_empty_p = [""] if not prompts and type != "None" else []
349
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
350
+ ps = []
351
+ nps = []
352
+ if model_name in model_prompt_dict.keys():
353
+ ps = to_list(model_prompt_dict[model_name]["prompt"])
354
+ nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
355
+ else:
356
+ ps = default_ps
357
+ nps = default_nps
358
+ prompts = prompts + ps
359
+ neg_prompts = neg_prompts + nps
360
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
361
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
362
+ return prompt, neg_prompt
363
+
364
+
365
+ tag_group_dict = load_dict_from_csv('tag_group.csv')
366
+
367
+
368
+ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
369
+ def is_dressed(tag):
370
+ import re
371
+ p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
372
+ return p.search(tag)
373
+
374
+ def is_background(tag):
375
+ import re
376
+ p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
377
+ return p.search(tag)
378
+
379
+ un_tags = ['solo']
380
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
381
+ keep_group_dict = {
382
+ "body": ['groups', 'body_parts'],
383
+ "dress": ['groups', 'body_parts', 'attire'],
384
+ "all": group_list,
385
+ }
386
+
387
+ def is_necessary(tag, keep_tags, group_dict):
388
+ if keep_tags == "all":
389
+ return True
390
+ elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
391
+ return False
392
+ elif keep_tags == "body" and is_dressed(tag):
393
+ return False
394
+ elif is_background(tag):
395
+ return False
396
+ else:
397
+ return True
398
+
399
+ if keep_tags == "all": return input_prompt
400
+ keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
401
+ explicit_group = list(set(group_list) ^ set(keep_group))
402
+
403
+ tags = input_prompt.split(",") if input_prompt else []
404
+ people_tags: list[str] = []
405
+ other_tags: list[str] = []
406
+
407
+ group_dict = tag_group_dict
408
+ for tag in tags:
409
+ tag = replace_underline(tag)
410
+ if tag in PEOPLE_TAGS:
411
+ people_tags.append(tag)
412
+ elif is_necessary(tag, keep_tags, group_dict):
413
+ other_tags.append(tag)
414
+
415
+ output_prompt = ", ".join(people_tags + other_tags)
416
+
417
+ return output_prompt
418
+
419
+
420
+ def sort_taglist(tags: list[str]):
421
+ if not tags: return []
422
+ character_tags: list[str] = []
423
+ series_tags: list[str] = []
424
+ people_tags: list[str] = []
425
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
426
+ group_tags = {}
427
+ other_tags: list[str] = []
428
+ rating_tags: list[str] = []
429
+
430
+ group_dict = tag_group_dict
431
+ group_set = set(group_dict.keys())
432
+ character_set = set(anime_series_dict.keys())
433
+ series_set = set(anime_series_dict.values())
434
+ rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
435
+
436
+ for tag in tags:
437
+ tag = replace_underline(tag)
438
+ if tag in PEOPLE_TAGS:
439
+ people_tags.append(tag)
440
+ elif tag in rating_set:
441
+ rating_tags.append(tag)
442
+ elif tag in group_set:
443
+ elem = group_dict[tag]
444
+ group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
445
+ elif tag in character_set:
446
+ character_tags.append(tag)
447
+ elif tag in series_set:
448
+ series_tags.append(tag)
449
+ else:
450
+ other_tags.append(tag)
451
+
452
+ output_group_tags: list[str] = []
453
+ for k in group_list:
454
+ output_group_tags.extend(group_tags.get(k, []))
455
+
456
+ rating_tags = [rating_tags[0]] if rating_tags else []
457
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
458
+
459
+ output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
460
+
461
+ return output_tags
462
+
463
+
464
+ def sort_tags(tags: str):
465
+ if not tags: return ""
466
+ taglist: list[str] = []
467
+ for tag in tags.split(","):
468
+ taglist.append(tag.strip())
469
+ taglist = list(filter(lambda x: x != "", taglist))
470
+ return ", ".join(sort_taglist(taglist))
471
+
472
+
473
+ def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
474
+ results = {
475
+ k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
476
+ }
477
+
478
+ rating = {}
479
+ character = {}
480
+ general = {}
481
+
482
+ for k, v in results.items():
483
+ if k.startswith("rating:"):
484
+ rating[k.replace("rating:", "")] = v
485
+ continue
486
+ elif k.startswith("character:"):
487
+ character[k.replace("character:", "")] = v
488
+ continue
489
+
490
+ general[k] = v
491
+
492
+ character = {k: v for k, v in character.items() if v >= character_threshold}
493
+ general = {k: v for k, v in general.items() if v >= general_threshold}
494
+
495
+ return rating, character, general
496
+
497
+
498
+ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
499
+ people_tags: list[str] = []
500
+ other_tags: list[str] = []
501
+ rating_tag = RATING_MAP[rating[0]]
502
+
503
+ for tag in general:
504
+ if tag in PEOPLE_TAGS:
505
+ people_tags.append(tag)
506
+ else:
507
+ other_tags.append(tag)
508
+
509
+ all_tags = people_tags + other_tags
510
+
511
+ return ", ".join(all_tags)
512
+
513
+
514
+ @spaces.GPU(duration=30)
515
+ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
516
+ inputs = wd_processor.preprocess(image, return_tensors="pt")
517
+
518
+ outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
519
+ logits = torch.sigmoid(outputs.logits[0]) # take the first logits
520
+
521
+ # get probabilities
522
+ if device != default_device: wd_model.to(device=device)
523
+ results = {
524
+ wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
525
+ }
526
+ if device != default_device: wd_model.to(device=default_device)
527
+ # rating, character, general
528
+ rating, character, general = postprocess_results(
529
+ results, general_threshold, character_threshold
530
+ )
531
+ prompt = gen_prompt(
532
+ list(rating.keys()), list(character.keys()), list(general.keys())
533
+ )
534
+ output_series_tag = ""
535
+ output_series_list = character_list_to_series_list(character.keys())
536
+ if output_series_list:
537
+ output_series_tag = output_series_list[0]
538
+ else:
539
+ output_series_tag = ""
540
+ return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
541
+
542
+
543
+ def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
544
+ character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
545
+ if not "Use WD Tagger" in algo and len(algo) != 0:
546
+ return input_series, input_character, input_tags, gr.update(interactive=True)
547
+ return predict_tags(image, general_threshold, character_threshold)
548
+
549
+
550
+ def compose_prompt_to_copy(character: str, series: str, general: str):
551
+ characters = character.split(",") if character else []
552
+ serieses = series.split(",") if series else []
553
+ generals = general.split(",") if general else []
554
+ tags = characters + serieses + generals
555
+ cprompt = ",".join(tags) if tags else ""
556
+ return cprompt
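
The CSV files added alongside this module are plain two-column dictionaries (`key,value` per line) consumed by `load_dict_from_csv`. A short sketch of the two helpers that `app.py` imports from here; the tag strings are made up, the `"e621"` label is assumed (any value other than `"danbooru"` enables the conversion), and real output depends on the bundled CSV contents:

```python
from tagger.tagger import convert_danbooru_to_e621_prompt, insert_recom_prompt

# Danbooru -> e621 remapping via danbooru_e621.csv; the rating tag is folded
# into "rating_explicit" / "rating_safe".
print(convert_danbooru_to_e621_prompt("1girl, solo, explicit", prompt_type="e621"))

# "Animagine" appends that model family's recommended quality tags to both prompts.
prompt, neg_prompt = insert_recom_prompt("1girl, solo", "", type="Animagine")
print(prompt)
print(neg_prompt)
```
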
tagger/utils.py ADDED
@@ -0,0 +1,50 @@
+import gradio as gr
+from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
+
+
+V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
+    "ultra_wide",
+    "wide",
+    "square",
+    "tall",
+    "ultra_tall",
+]
+V2_RATING_OPTIONS: list[RatingTag] = [
+    "sfw",
+    "general",
+    "sensitive",
+    "nsfw",
+    "questionable",
+    "explicit",
+]
+V2_LENGTH_OPTIONS: list[LengthTag] = [
+    "very_short",
+    "short",
+    "medium",
+    "long",
+    "very_long",
+]
+V2_IDENTITY_OPTIONS: list[IdentityTag] = [
+    "none",
+    "lax",
+    "strict",
+]
+
+
+# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
+def gradio_copy_text(_text: None):
+    gr.Info("Copied!")
+
+
+COPY_ACTION_JS = """\
+(inputs, _outputs) => {
+  // inputs is the string value of the input_text
+  if (inputs.trim() !== "") {
+    navigator.clipboard.writeText(inputs);
+  }
+}"""
+
+
+def gradio_copy_prompt(prompt: str):
+    gr.Info("Copied!")
+    return prompt
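
Finally, a sketch of how these copy helpers are typically wired to a button; the component names are illustrative, `app.py` does the real wiring, and the `js=` keyword assumes Gradio 4.x (older releases used `_js=`):

```python
import gradio as gr
from tagger.utils import gradio_copy_text, COPY_ACTION_JS

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    copy_btn = gr.Button("Copy")
    # The Python callback only shows the "Copied!" toast; the clipboard write
    # itself happens client-side via the COPY_ACTION_JS snippet.
    copy_btn.click(gradio_copy_text, inputs=[prompt_box], outputs=None, js=COPY_ACTION_JS)

demo.launch()
```
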