LittleApple_fp16 commited on
Commit
69a6cef
·
1 Parent(s): 699de72
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .bashrc +1 -0
  2. .gitattributes +0 -1
  3. .gitignore +1 -0
  4. Dockerfile +36 -0
  5. README.md +6 -5
  6. cyberharem/__init__.py +0 -0
  7. cyberharem/__pycache__/__init__.cpython-310.pyc +0 -0
  8. cyberharem/config/__init__.py +0 -0
  9. cyberharem/config/meta.py +19 -0
  10. cyberharem/dataset/__init__.py +4 -0
  11. cyberharem/dataset/__main__.py +38 -0
  12. cyberharem/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  13. cyberharem/dataset/__pycache__/crawler.cpython-310.pyc +0 -0
  14. cyberharem/dataset/__pycache__/load.cpython-310.pyc +0 -0
  15. cyberharem/dataset/__pycache__/tags.cpython-310.pyc +0 -0
  16. cyberharem/dataset/crawler.py +314 -0
  17. cyberharem/dataset/load.py +63 -0
  18. cyberharem/dataset/tags.py +250 -0
  19. cyberharem/dataset/video/__init__.py +2 -0
  20. cyberharem/dataset/video/__main__.py +58 -0
  21. cyberharem/dataset/video/__pycache__/__init__.cpython-310.pyc +0 -0
  22. cyberharem/dataset/video/__pycache__/crawler.cpython-310.pyc +0 -0
  23. cyberharem/dataset/video/__pycache__/extract.cpython-310.pyc +0 -0
  24. cyberharem/dataset/video/bangumibase.py +149 -0
  25. cyberharem/dataset/video/crawler.py +70 -0
  26. cyberharem/dataset/video/extract.py +334 -0
  27. cyberharem/infer/__init__.py +3 -0
  28. cyberharem/infer/__pycache__/__init__.cpython-310.pyc +0 -0
  29. cyberharem/infer/__pycache__/civitai.cpython-310.pyc +0 -0
  30. cyberharem/infer/__pycache__/draw.cpython-310.pyc +0 -0
  31. cyberharem/infer/__pycache__/export.cpython-310.pyc +0 -0
  32. cyberharem/infer/civitai.py +384 -0
  33. cyberharem/infer/draw.py +256 -0
  34. cyberharem/infer/export.py +101 -0
  35. cyberharem/list.py +43 -0
  36. cyberharem/publish/__init__.py +6 -0
  37. cyberharem/publish/__main__.py +158 -0
  38. cyberharem/publish/__pycache__/__init__.cpython-310.pyc +0 -0
  39. cyberharem/publish/__pycache__/__main__.cpython-310.pyc +0 -0
  40. cyberharem/publish/__pycache__/civitai.cpython-310.pyc +0 -0
  41. cyberharem/publish/__pycache__/convert.cpython-310.pyc +0 -0
  42. cyberharem/publish/__pycache__/export.cpython-310.pyc +0 -0
  43. cyberharem/publish/__pycache__/huggingface.cpython-310.pyc +0 -0
  44. cyberharem/publish/__pycache__/steps.cpython-310.pyc +0 -0
  45. cyberharem/publish/civitai.py +915 -0
  46. cyberharem/publish/convert.py +19 -0
  47. cyberharem/publish/cyberharem_publish_huggingface.py +120 -0
  48. cyberharem/publish/export.py +284 -0
  49. cyberharem/publish/huggingface.py +120 -0
  50. cyberharem/publish/steps.py +32 -0
.bashrc ADDED
@@ -0,0 +1 @@
 
 
1
+ export PATH=$HOME/.local/bin:$PATH
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8.1
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN apt-get update && \
8
+ apt-get install -y sudo tmux wget curl htop make tree && \
9
+ apt-get install -y iputils-ping telnet && \
10
+ apt-get install -y git git-lfs && \
11
+ apt-get install -y libgl1-mesa-glx
12
+
13
+ RUN --mount=type=secret,id=PASSWORD,mode=0444,required=true \
14
+ useradd -m -u 1000 user && \
15
+ echo "user:$(cat /run/secrets/PASSWORD)" | chpasswd && \
16
+ adduser user sudo
17
+
18
+ RUN pip install -U pip pysocks
19
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
20
+
21
+ USER user
22
+ ENV HOME=/home/user
23
+ ENV PATH=$HOME/.local/bin:$PATH
24
+ ENV SHELL=/bin/bash
25
+
26
+ WORKDIR $HOME
27
+
28
+ COPY --chown=user . $HOME/app
29
+
30
+ COPY .bashrc $HOME/.bashrc_append
31
+ RUN cat $HOME/.bashrc_append >> $HOME/.bashrc && \
32
+ rm $HOME/.bashrc_append
33
+
34
+ EXPOSE 7860
35
+ ENTRYPOINT []
36
+ CMD ["/bin/bash", "./app/run.sh"]
README.md CHANGED
@@ -1,11 +1,12 @@
1
  ---
2
- title: AppleJupyter
3
- emoji: 🐠
4
- colorFrom: pink
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
- license: apache-2.0
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: JupyterLab
3
+ emoji: 💹
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
+ app_port: 7860
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
cyberharem/__init__.py ADDED
File without changes
cyberharem/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (135 Bytes). View file
 
cyberharem/config/__init__.py ADDED
File without changes
cyberharem/config/meta.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Overview:
3
+ Meta information for gchar package.
4
+ """
5
+
6
+ #: Title of this project (should be `gchar`).
7
+ __TITLE__ = 'cyberharem'
8
+
9
+ #: Version of this project.
10
+ __VERSION__ = '0.0.1'
11
+
12
+ #: Short description of the project, will be included in ``setup.py``.
13
+ __DESCRIPTION__ = 'Cyber Harem of All the Waifus in Games, Mua~'
14
+
15
+ #: Author of this project.
16
+ __AUTHOR__ = 'narugo1992'
17
+
18
+ #: Email of the authors'.
19
+ __AUTHOR_EMAIL__ = 'narugo992@gmail.com'
cyberharem/dataset/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .crawler import crawl_dataset_to_huggingface, remake_dataset_to_huggingface
2
+ from .load import load_dataset_for_character
3
+ from .tags import save_recommended_tags, sort_draw_names
4
+ from .video import crawl_base_to_huggingface
cyberharem/dataset/__main__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional
3
+
4
+ import click
5
+ from ditk import logging
6
+ from gchar.utils import GLOBAL_CONTEXT_SETTINGS
7
+ from gchar.utils import print_version as _origin_print_version
8
+
9
+ from .tags import save_recommended_tags
10
+
11
+ print_version = partial(_origin_print_version, 'cyberharem.dataset')
12
+
13
+
14
+ @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish trained models')
15
+ @click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True)
16
+ def cli():
17
+ pass # pragma: no cover
18
+
19
+
20
+ @cli.command('retag', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Regenerate tags for given work directory.')
21
+ @click.option('-w', '--workdir', 'workdir', type=click.Path(file_okay=False, exists=True), required=True,
22
+ help='Work directory for experiment.', show_default=True)
23
+ @click.option('-n', '--name', 'name', type=str, default=None,
24
+ help='Name of the character.', show_default=True)
25
+ def retag(workdir, name: Optional[str] = None):
26
+ logging.try_init_root(logging.INFO)
27
+
28
+ from ..publish.steps import find_steps_in_workdir
29
+ pt_name, _ = find_steps_in_workdir(workdir)
30
+ name = name or '_'.join(pt_name.split('_')[:-1])
31
+
32
+ logging.info(f'Regenerate tags for {name!r}, on {workdir!r}.')
33
+ save_recommended_tags(name, workdir=workdir)
34
+ logging.info('Success!')
35
+
36
+
37
+ if __name__ == '__main__':
38
+ cli()
cyberharem/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (429 Bytes). View file
 
cyberharem/dataset/__pycache__/crawler.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
cyberharem/dataset/__pycache__/load.cpython-310.pyc ADDED
Binary file (2.27 kB). View file
 
cyberharem/dataset/__pycache__/tags.cpython-310.pyc ADDED
Binary file (8.25 kB). View file
 
cyberharem/dataset/crawler.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import glob
3
+ import json
4
+ import os.path
5
+ import zipfile
6
+ from typing import Union, Tuple, List, Optional
7
+
8
+ import pandas as pd
9
+ from ditk import logging
10
+ from gchar.games import get_character
11
+ from gchar.games.base import Character
12
+ from hbutils.string import plural_word
13
+ from hbutils.system import TemporaryDirectory
14
+ from huggingface_hub import CommitOperationAdd, hf_hub_url
15
+ from waifuc.action import NoMonochromeAction, FilterSimilarAction, \
16
+ TaggingAction, PersonSplitAction, FaceCountAction, CCIPAction, ModeConvertAction, ClassFilterAction, \
17
+ FileOrderAction, RatingFilterAction, BaseAction, RandomFilenameAction, PaddingAlignAction, ThreeStageSplitAction, \
18
+ AlignMinSizeAction, MinSizeFilterAction, FilterAction
19
+ from waifuc.action.filter import MinAreaFilterAction
20
+ from waifuc.export import SaveExporter, TextualInversionExporter
21
+ from waifuc.model import ImageItem
22
+ from waifuc.source import GcharAutoSource, BaseDataSource, LocalSource
23
+ from waifuc.utils import task_ctx
24
+
25
+ from ..utils import number_to_tag, get_ch_name, get_alphabet_name, get_hf_client, download_file, get_hf_fs
26
+
27
+
28
+ def get_source(source) -> BaseDataSource:
29
+ if isinstance(source, (str, Character)):
30
+ source = GcharAutoSource(source, main_sources_count=5)
31
+ elif isinstance(source, BaseDataSource):
32
+ pass
33
+ else:
34
+ raise TypeError(f'Unknown source type - {source!r}.')
35
+
36
+ return source
37
+
38
+
39
+ def get_main_source(source, no_r18: bool = False, bg_color: str = 'white',
40
+ no_monochrome_check: bool = False,
41
+ drop_multi: bool = True, skip: bool = False) -> BaseDataSource:
42
+ source: BaseDataSource = get_source(source)
43
+ if not skip:
44
+ actions = [ModeConvertAction('RGB', bg_color)]
45
+ if not no_monochrome_check:
46
+ actions.append(NoMonochromeAction()) # no monochrome, greyscale or sketch
47
+ actions.append(ClassFilterAction(['illustration', 'bangumi'])) # no comic or 3d
48
+ if no_r18:
49
+ actions.append(RatingFilterAction(['safe', 'r15']))
50
+
51
+ actions.append(FilterSimilarAction('all')) # filter duplicated images
52
+ if drop_multi:
53
+ actions.append(FaceCountAction(count=1, level='n')) # drop images with 0 or >1 faces
54
+ actions.extend([
55
+ PersonSplitAction(level='n'), # crop for each person
56
+ FaceCountAction(count=1, level='n'),
57
+ FileOrderAction(), # Rename files in order
58
+ CCIPAction(min_val_count=15), # CCIP, filter the character you may not want to see in dataset
59
+ FilterSimilarAction('all'), # filter duplicated images
60
+ MinSizeFilterAction(320),
61
+ TaggingAction(force=True, character_threshold=1.01),
62
+ ])
63
+ actions.append(RandomFilenameAction(ext='.png'))
64
+ else:
65
+ actions = []
66
+
67
+ return source.attach(*actions)
68
+
69
+
70
+ def actions_parse(actions: Union[int, Tuple[int, int], List[BaseAction]], bg_color: str = 'white'):
71
+ if isinstance(actions, list):
72
+ return actions
73
+ elif isinstance(actions, tuple):
74
+ width, height = actions
75
+ return [PaddingAlignAction((width, height), bg_color)]
76
+ elif isinstance(actions, int):
77
+ return [AlignMinSizeAction(actions)]
78
+ else:
79
+ raise TypeError(f'Unknown post action type - {actions!r}.')
80
+
81
+
82
+ class CustomMinSizeAction(FilterAction):
83
+ def __init__(self, main_size: int = 280, min_eye_size: int = 180):
84
+ self.main_size = main_size
85
+ self.min_eye_size = min_eye_size
86
+
87
+ def check(self, item: ImageItem) -> bool:
88
+ min_size = min(item.image.width, item.image.height)
89
+ if 'crop' in item.meta and item.meta['crop']['type'] == 'eye':
90
+ return min_size >= self.min_eye_size
91
+ else:
92
+ return min_size >= self.main_size
93
+
94
+
95
+ _SOURCES = {
96
+ 'native': [
97
+ TaggingAction(force=False, character_threshold=1.01),
98
+ ],
99
+ 'stage3': [
100
+ ThreeStageSplitAction(split_person=False),
101
+ FilterSimilarAction(),
102
+ MinSizeFilterAction(280),
103
+ TaggingAction(force=False, character_threshold=1.01),
104
+ ],
105
+ 'stage3-eyes': [
106
+ ThreeStageSplitAction(split_person=False, split_eyes=True),
107
+ FilterSimilarAction(),
108
+ CustomMinSizeAction(280, 180),
109
+ TaggingAction(force=False, character_threshold=1.01),
110
+ ]
111
+ }
112
+
113
+ _DEFAULT_RESOLUTIONS = {
114
+ 'raw': ('native', [], 'Raw data with meta information.'),
115
+ 'raw-stage3': ('stage3', [], '3-stage cropped raw data with meta information.'),
116
+ 'raw-stage3-eyes': ('stage3-eyes', [], '3-stage cropped (with eye-focus) raw data with meta information.'),
117
+ '384x512': ('native', (384, 512), '384x512 aligned dataset.'),
118
+ # '512x512': ('native', (512, 512), '512x512 aligned dataset.'),
119
+ '512x704': ('native', (512, 704), '512x704 aligned dataset.'),
120
+ # '640x640': ('native', (640, 640), '640x640 aligned dataset.'),
121
+ '640x880': ('native', (640, 880), '640x880 aligned dataset.'),
122
+ 'stage3-640': ('stage3', 640, '3-stage cropped dataset with the shorter side not exceeding 640 pixels.'),
123
+ 'stage3-800': ('stage3', 800, '3-stage cropped dataset with the shorter side not exceeding 800 pixels.'),
124
+ 'stage3-p512-640': ('stage3', [MinAreaFilterAction(512), AlignMinSizeAction(640)],
125
+ '3-stage cropped dataset with the area not less than 512x512 pixels.'),
126
+ # 'stage3-1200': ('stage3', 1200, '3-stage cropped dataset with the shorter side not exceeding 1200 pixels.'),
127
+ 'stage3-eyes-640': ('stage3-eyes', 640, '3-stage cropped (with eye-focus) dataset '
128
+ 'with the shorter side not exceeding 640 pixels.'),
129
+ 'stage3-eyes-800': ('stage3-eyes', 800, '3-stage cropped (with eye-focus) dataset '
130
+ 'with the shorter side not exceeding 800 pixels.'),
131
+ }
132
+
133
+ DATASET_PVERSION = 'v1.4'
134
+
135
+
136
+ def crawl_dataset_to_huggingface(
137
+ source: Union[str, Character, BaseDataSource], repository: Optional[str] = None,
138
+ name: Optional[str] = None, limit: Optional[int] = 1000, min_images: int = 10,
139
+ no_r18: bool = False, bg_color: str = 'white', drop_multi: bool = True, skip_preprocess: bool = False,
140
+ no_monochrome_check: bool = False,
141
+ repo_type: str = 'dataset', revision: str = 'main', path_in_repo: str = '.', private: bool = False,
142
+ ):
143
+ if isinstance(source, (str, Character)):
144
+ if isinstance(source, str):
145
+ source = get_character(source)
146
+ name = f'{source.enname} ({source.__official_name__})'
147
+
148
+ if not repository:
149
+ repository = f'AppleHarem/{get_ch_name(source)}'
150
+
151
+ else:
152
+ if name is None:
153
+ raise ValueError('Name must be specified when source is not str or character.')
154
+
155
+ if not repository:
156
+ repository = f'AppleHarem/{get_alphabet_name(name)}'
157
+
158
+ origin_source = get_main_source(source, no_r18, bg_color, no_monochrome_check, drop_multi, skip_preprocess)
159
+ with TemporaryDirectory() as td:
160
+ # save origin directory
161
+ origin_dir = os.path.join(td, 'origin')
162
+ os.makedirs(origin_dir, exist_ok=True)
163
+ if limit is not None:
164
+ origin_source = origin_source[:limit]
165
+ with task_ctx('origin'):
166
+ origin_source.export(SaveExporter(origin_dir))
167
+
168
+ img_count = len(glob.glob(os.path.join(origin_dir, '*.png')))
169
+ if img_count < min_images:
170
+ logging.warn(f'Only {plural_word(img_count, "image")} found for {name} which is too few, '
171
+ f'skip post-processing and uploading.')
172
+ return
173
+
174
+ source_dir = os.path.join(td, 'source')
175
+ os.makedirs(source_dir, exist_ok=True)
176
+ for sname, actions in _SOURCES.items():
177
+ with task_ctx(f'source/{sname}'):
178
+ LocalSource(origin_dir).attach(*actions).export(SaveExporter(os.path.join(source_dir, sname)))
179
+
180
+ processed_dir = os.path.join(td, 'processed')
181
+ os.makedirs(processed_dir, exist_ok=True)
182
+ archive_dir = os.path.join(td, 'archives')
183
+ os.makedirs(archive_dir, exist_ok=True)
184
+
185
+ files_to_upload: List[Tuple[str, str]] = []
186
+ resolutions = _DEFAULT_RESOLUTIONS
187
+
188
+ columns = ['Name', 'Images', 'Download', 'Description']
189
+ rows = []
190
+ for rname, (sname, actions, description) in resolutions.items():
191
+ actions = actions_parse(actions, bg_color)
192
+
193
+ ox = LocalSource(os.path.join(source_dir, sname))
194
+ current_processed_dir = os.path.join(processed_dir, rname)
195
+ with task_ctx(f'archive/{rname}'):
196
+ if not rname.startswith('raw'): # raw is preserved for exporting json data
197
+ ox.attach(*actions).export(TextualInversionExporter(current_processed_dir))
198
+ else:
199
+ ox.attach(*actions).export(SaveExporter(current_processed_dir))
200
+ current_img_cnt = len(glob.glob(os.path.join(current_processed_dir, '*.png')))
201
+
202
+ zip_file = os.path.join(archive_dir, f'dataset-{rname}.zip')
203
+ with zipfile.ZipFile(zip_file, mode='w') as zf:
204
+ for directory, _, files in os.walk(current_processed_dir):
205
+ for file in files:
206
+ file_path = os.path.join(directory, file)
207
+ rel_file_path = os.path.relpath(file_path, current_processed_dir)
208
+ zf.write(
209
+ file_path,
210
+ '/'.join(rel_file_path.split(os.sep))
211
+ )
212
+
213
+ rows.append((
214
+ rname,
215
+ current_img_cnt,
216
+ f'[Download]({os.path.basename(zip_file)})',
217
+ description,
218
+ ))
219
+
220
+ files_to_upload.append((zip_file, os.path.basename(zip_file)))
221
+
222
+ meta_file = os.path.join(td, 'meta.json')
223
+ with open(meta_file, 'w', encoding='utf-8') as mf:
224
+ json.dump({
225
+ 'name': name,
226
+ 'version': DATASET_PVERSION,
227
+ }, mf, indent=4, sort_keys=True, ensure_ascii=False)
228
+ files_to_upload.append((meta_file, 'meta.json'))
229
+
230
+ readme_file = os.path.join(td, 'README.md')
231
+ with open(readme_file, 'w', encoding='utf-8') as rf:
232
+ print(f'---', file=rf)
233
+ print(f'license: mit', file=rf)
234
+ print(f'task_categories:', file=rf)
235
+ print(f'- text-to-image', file=rf)
236
+ print(f'tags:', file=rf)
237
+ print(f'- art', file=rf)
238
+ print(f'- not-for-all-audiences', file=rf)
239
+ print(f'size_categories:', file=rf)
240
+ print(f'- {number_to_tag(img_count)}', file=rf)
241
+ print(f'---', file=rf)
242
+ print(f'', file=rf)
243
+
244
+ print(f'# Dataset of {name}', file=rf)
245
+ print(f'', file=rf)
246
+
247
+ print(f'This is the dataset of {name}, '
248
+ f'containing {plural_word(img_count, "images")} and their tags.', file=rf)
249
+ print(f'', file=rf)
250
+
251
+ print(f'Images are crawled from many sites (e.g. danbooru, pixiv, zerochan ...), '
252
+ f'the auto-crawling system is powered by [DeepGHS Team](https://github.com/deepghs)'
253
+ f'([huggingface organization](https://huggingface.co/deepghs)). ', file=rf)
254
+ print(f'This is a WebUI contains crawlers and other thing: '
255
+ f'([LittleAppleWebUI](https://github.com/LittleApple-fp16/LittleAppleWebUI))', file=rf)
256
+ print(f'', file=rf)
257
+
258
+ df = pd.DataFrame(columns=columns, data=rows)
259
+ print(df.to_markdown(index=False), file=rf)
260
+ print('', file=rf)
261
+
262
+ files_to_upload.append((readme_file, 'README.md'))
263
+
264
+ hf_client = get_hf_client()
265
+ hf_fs = get_hf_fs()
266
+ logging.info(f'Initialize repository {repository!r}')
267
+ if not hf_fs.exists(f'datasets/{repository}/.gitattributes'):
268
+ hf_client.create_repo(repo_id=repository, repo_type=repo_type, exist_ok=True, private=private)
269
+
270
+ current_time = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S %Z')
271
+ commit_message = f"Publish character {name}, on {current_time}"
272
+ logging.info(f'Publishing character {name!r} to repository {repository!r} ...')
273
+ hf_client.create_commit(
274
+ repository,
275
+ [
276
+ CommitOperationAdd(
277
+ path_in_repo=f'{path_in_repo}/{filename}',
278
+ path_or_fileobj=local_file,
279
+ ) for local_file, filename in files_to_upload
280
+ ],
281
+ commit_message=commit_message,
282
+ repo_type=repo_type,
283
+ revision=revision,
284
+ run_as_future=False,
285
+ )
286
+
287
+
288
+ def remake_dataset_to_huggingface(
289
+ repository: Optional[str] = None, limit: Optional[int] = 200, min_images: int = 10,
290
+ no_r18: bool = False, bg_color: str = 'white', drop_multi: bool = True,
291
+ repo_type: str = 'dataset', revision: str = 'main', path_in_repo: str = '.',
292
+ ):
293
+ hf_fs = get_hf_fs()
294
+ with TemporaryDirectory() as td:
295
+ zip_file = os.path.join(td, 'dataset-raw.zip')
296
+ download_file(hf_hub_url(repository, 'dataset-raw.zip', repo_type='dataset'), zip_file)
297
+
298
+ source_dir = os.path.join(td, 'source')
299
+ os.makedirs(source_dir, exist_ok=True)
300
+ with zipfile.ZipFile(zip_file, 'r') as zf:
301
+ zf.extractall(source_dir)
302
+
303
+ source = LocalSource(source_dir)
304
+ name = None
305
+ if hf_fs.exists(f'datasets/{repository}/meta.json'):
306
+ meta_json = json.loads(hf_fs.read_text(f'datasets/{repository}/meta.json'))
307
+ if 'name' in meta_json:
308
+ name = meta_json['name']
309
+ name = name or repository.split('/')[-1]
310
+ return crawl_dataset_to_huggingface(
311
+ source, repository, name,
312
+ limit, min_images, no_r18, bg_color, drop_multi, True,
313
+ repo_type, revision, path_in_repo
314
+ )
cyberharem/dataset/load.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os.path
3
+ import zipfile
4
+ from contextlib import contextmanager
5
+ from typing import ContextManager, Tuple, Optional, Union
6
+
7
+ from gchar.games import get_character
8
+ from gchar.games.base import Character
9
+ from hbutils.system import TemporaryDirectory, urlsplit
10
+ from huggingface_hub import hf_hub_url
11
+ from waifuc.utils import download_file
12
+
13
+ from ..utils import get_hf_fs, get_ch_name
14
+
15
+
16
+ @contextmanager
17
+ def load_dataset_for_character(source, size: Union[Tuple[int, int], str] = (512, 704)) \
18
+ -> ContextManager[Tuple[Optional[Character], str]]:
19
+ if isinstance(source, str) and os.path.exists(source):
20
+ if os.path.isdir(source):
21
+ logging.info(f'Dataset directory {source!r} loaded.')
22
+ yield None, source
23
+ elif os.path.isfile(source):
24
+ with zipfile.ZipFile(source, 'r') as zf, TemporaryDirectory() as td:
25
+ zf.extractall(td)
26
+ logging.info(f'Archive dataset {source!r} unzipped to {td!r} and loaded.')
27
+ yield None, td
28
+ else:
29
+ raise OSError(f'Unknown local source - {source!r}.')
30
+
31
+ else:
32
+ if isinstance(source, Character):
33
+ repo = f'AppleHarem/{get_ch_name(source)}'
34
+ else:
35
+ try_ch = get_character(source)
36
+ if try_ch is None:
37
+ repo = source
38
+ else:
39
+ source = try_ch
40
+ repo = f'AppleHarem/{get_ch_name(source)}'
41
+
42
+ hf_fs = get_hf_fs()
43
+ if isinstance(size, tuple):
44
+ width, height = size
45
+ ds_name = f'{width}x{height}'
46
+ elif isinstance(size, str):
47
+ ds_name = size
48
+ else:
49
+ raise TypeError(f'Unknown dataset type - {size!r}.')
50
+ if hf_fs.exists(f'datasets/{repo}/dataset-{ds_name}.zip'):
51
+ logging.info(f'Online dataset {repo!r} founded.')
52
+ zip_url = hf_hub_url(repo_id=repo, repo_type='dataset', filename=f'dataset-{ds_name}.zip')
53
+ with TemporaryDirectory() as dltmp:
54
+ zip_file = os.path.join(dltmp, 'dataset.zip')
55
+ download_file(zip_url, zip_file, desc=f'{repo}/{urlsplit(zip_url).filename}')
56
+
57
+ with zipfile.ZipFile(zip_file, 'r') as zf, TemporaryDirectory() as td:
58
+ zf.extractall(td)
59
+ logging.info(f'Online dataset {repo!r} loaded at {td!r}.')
60
+ yield source, td
61
+
62
+ else:
63
+ raise ValueError(f'Remote dataset {repo!r} not found for {source!r}.')
cyberharem/dataset/tags.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os.path
3
+ import random
4
+ from typing import List
5
+
6
+ from gchar.games.base import Character
7
+
8
+ from .load import load_dataset_for_character
9
+ from ..utils import load_tags_from_directory, get_ch_name, repr_tags
10
+
11
+ basic_words = [
12
+ 'best quality',
13
+ 'masterpiece',
14
+ 'highres',
15
+ ]
16
+
17
+ generic_neg_words = [
18
+ ('worst quality, low quality', 1.4), ('zombie, sketch, interlocked fingers, comic', 1.1),
19
+ ('full body', 1.1), 'lowres', 'bad anatomy', 'bad hands', 'text', 'error', 'missing fingers', 'extra digit',
20
+ 'fewer digits', 'cropped', 'worst quality', 'low quality', 'normal quality', 'jpeg artifacts', 'signature',
21
+ 'watermark', 'username', 'blurry', 'white border', ('english text, chinese text', 1.05),
22
+ ]
23
+
24
+
25
+ def _free_pos_words(generic_words, name, core_tags):
26
+ return [
27
+ *generic_words,
28
+ (name, 1.15),
29
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
30
+ ], generic_neg_words, None, True
31
+
32
+
33
+ def _bikini_pos_words(generic_words, name, core_tags):
34
+ return [
35
+ *generic_words,
36
+ ('night', 1.1),
37
+ ('starry sky', 1.1),
38
+ 'beach',
39
+ 'beautiful detailed sky',
40
+ ('extremely detailed background', 1.2),
41
+ (name, 1.15),
42
+ ('standing', 1.1),
43
+ 'looking at viewer',
44
+ ('bikini', 1.3),
45
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
46
+ 'light smile',
47
+ ], generic_neg_words, 758691538, True
48
+
49
+
50
+ def _nude_pos_words(generic_words, name, core_tags):
51
+ return [
52
+ 'nsfw',
53
+ *generic_words,
54
+ ('lying on bed', 1.1),
55
+ ('extremely detailed background', 1.2),
56
+ ('nude', 1.4),
57
+ ('spread legs', 1.1),
58
+ ('arms up', 1.1),
59
+ 'mature',
60
+ (name, 1.15),
61
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
62
+ 'nipples',
63
+ ('pussy', 1.15),
64
+ ('pussy juice', 1.3),
65
+ 'looking at viewer',
66
+ ('embarrassed', 1.1),
67
+ 'endured face',
68
+ 'feet out of frame',
69
+ ], generic_neg_words, 465191133, False
70
+
71
+
72
+ def _nude_bondage_words(generic_words, name, core_tags):
73
+ return [
74
+ 'nsfw',
75
+ *generic_words,
76
+ ('simple background', 1.1),
77
+ ('standing', 1.15),
78
+ ('nude', 1.4),
79
+ ('bondage', 1.3),
80
+ 'completely nude',
81
+ 'mature',
82
+ (name, 1.15),
83
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
84
+ 'nipples',
85
+ ('pussy', 1.15),
86
+ ('pussy juice', 1.3),
87
+ 'looking at viewer',
88
+ ('embarrassed', 1.1),
89
+ ], generic_neg_words, 758691538, False
90
+
91
+
92
+ def _nude_stand_words(generic_words, name, core_tags):
93
+ return [
94
+ 'nsfw',
95
+ *generic_words,
96
+ ('simple background', 1.1),
97
+ ('standing', 1.15),
98
+ ('nude', 1.4),
99
+ ('completely nude', 1.2),
100
+ 'mature',
101
+ (name, 1.15),
102
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
103
+ 'nipples',
104
+ ('pussy', 1.15),
105
+ ('pussy juice', 1.3),
106
+ 'looking at viewer',
107
+ ('embarrassed', 1.1),
108
+ ], generic_neg_words, 758691538, False
109
+
110
+
111
+ def _safe_maid_words(generic_words, name, core_tags):
112
+ return [
113
+ *generic_words,
114
+ ('maid', 1.4),
115
+ ('long maid dress', 1.15),
116
+ (name, 1.15),
117
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
118
+ ], [
119
+ 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet',
120
+ 'skin of legs', 'bare legs', 'bare skin', 'navel',
121
+ *generic_neg_words,
122
+ ], None, True
123
+
124
+
125
+ def _safe_yukata_words(generic_words, name, core_tags):
126
+ return [
127
+ *generic_words,
128
+ ('yukata', 1.4),
129
+ ('kimono', 1.2),
130
+ (name, 1.15),
131
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
132
+ ], [
133
+ 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet',
134
+ 'skin of legs', 'bare legs', 'bare skin', 'navel',
135
+ *generic_neg_words,
136
+ ], None, True
137
+
138
+
139
+ def _safe_miko_words(generic_words, name, core_tags):
140
+ return [
141
+ *generic_words,
142
+ ('white kimono', 1.35),
143
+ ('red hakama', 1.35),
144
+ ('wide sleeves', 1.2),
145
+ (name, 1.15),
146
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
147
+ ], [
148
+ 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet',
149
+ 'skin of legs', 'bare legs', 'bare skin', 'navel',
150
+ *generic_neg_words,
151
+ ], None, True
152
+
153
+
154
+ def _safe_suit_words(generic_words, name, core_tags):
155
+ return [
156
+ *generic_words,
157
+ ('black business suit', 1.4),
158
+ ('tie', 1.2),
159
+ ('sunglasses', 1.25),
160
+ ('white gloves', 1.15),
161
+ ('white shirt', 1.1),
162
+ ('black skirt', 1.15),
163
+ ('smoking', 1.2),
164
+ 'handsome',
165
+ (name, 1.15),
166
+ *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])],
167
+ ], [
168
+ 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet',
169
+ 'skin of legs', 'bare legs', 'bare skin', 'navel',
170
+ *generic_neg_words,
171
+ ], None, True
172
+
173
+
174
+ EXTRAS = [
175
+ ('free', _free_pos_words),
176
+ ('bikini', _bikini_pos_words),
177
+ ('maid', _safe_maid_words),
178
+ ('miko', _safe_miko_words),
179
+ ('yukata', _safe_yukata_words),
180
+ ('nude', _nude_pos_words),
181
+ ('nude2', _nude_stand_words),
182
+ ('bondage', _nude_bondage_words),
183
+ ('suit', _safe_suit_words),
184
+ ]
185
+
186
+
187
+ def save_recommended_tags(source, name: str = None, workdir: str = None, ds_size: str = '512x704'):
188
+ with load_dataset_for_character(source, ds_size) as (ch, ds_dir):
189
+ if ch is None:
190
+ if name is None:
191
+ raise ValueError(f'Name should be specified when using custom source - {source!r}.')
192
+ else:
193
+ name = name or get_ch_name(ch)
194
+
195
+ workdir = workdir or os.path.join('runs', name)
196
+ tags_dir = os.path.join(workdir, 'rtags')
197
+ os.makedirs(tags_dir, exist_ok=True)
198
+
199
+ generic_words = []
200
+ generic_words.extend(basic_words)
201
+ if isinstance(ch, Character):
202
+ if ch.gender == 'male':
203
+ generic_words.extend(['1boy', 'solo'])
204
+ elif ch.gender == 'female':
205
+ generic_words.extend(['1girl', 'solo'])
206
+ else:
207
+ generic_words.append('solo')
208
+ else:
209
+ generic_words.append('solo')
210
+
211
+ core_tags, feats = load_tags_from_directory(ds_dir)
212
+ for i, f in enumerate(feats, start=1):
213
+ pos_words = [*generic_words, (name, 1.15), *f.keys()]
214
+ pos_prompt = repr_tags(pos_words)
215
+ neg_prompt = repr_tags(generic_neg_words)
216
+
217
+ tags_name = f'pattern_{i}'
218
+ with open(os.path.join(tags_dir, f'{tags_name}.json'), 'w', encoding='utf-8') as f:
219
+ json.dump({
220
+ 'name': tags_name,
221
+ 'prompt': pos_prompt,
222
+ 'neg_prompt': neg_prompt,
223
+ 'seed': random.randint(0, 1 << 31),
224
+ 'sfw': True,
225
+ }, f, indent=4, ensure_ascii=False)
226
+
227
+ for tags_name, _func in EXTRAS:
228
+ pos_words, neg_words, seed, is_sfw = _func(generic_words, name, core_tags)
229
+ pos_prompt = repr_tags(pos_words)
230
+ neg_prompt = repr_tags(neg_words)
231
+
232
+ with open(os.path.join(tags_dir, f'{tags_name}.json'), 'w', encoding='utf-8') as f:
233
+ json.dump({
234
+ 'name': tags_name,
235
+ 'prompt': pos_prompt,
236
+ 'neg_prompt': neg_prompt,
237
+ 'seed': seed if seed is not None else random.randint(0, 1 << 31),
238
+ 'sfw': is_sfw,
239
+ }, f, indent=4, ensure_ascii=False)
240
+
241
+
242
+ def sort_draw_names(names: List[str]) -> List[str]:
243
+ vs = []
244
+ for name in names:
245
+ if name.startswith('pattern_'):
246
+ vs.append((0, int(name.split('_')[1]), name))
247
+ else:
248
+ vs.append((1, name, name))
249
+
250
+ return [item[2] for item in sorted(vs)]
cyberharem/dataset/video/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .crawler import crawl_base_to_huggingface
2
+ from .extract import extract_to_huggingface
cyberharem/dataset/video/__main__.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from functools import partial
3
+
4
+ import click
5
+ from ditk import logging
6
+ from gchar.generic import import_generic
7
+ from gchar.utils import GLOBAL_CONTEXT_SETTINGS
8
+ from gchar.utils import print_version as _origin_print_version
9
+ from unidecode import unidecode
10
+
11
+ from .bangumibase import sync_bangumi_base
12
+ from .extract import extract_to_huggingface
13
+
14
+ import_generic()
15
+
16
+ print_version = partial(_origin_print_version, 'cyberharem.dataset.video')
17
+
18
+
19
+ @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish video data')
20
+ @click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True)
21
+ def cli():
22
+ pass # pragma: no cover
23
+
24
+
25
+ @cli.command('huggingface', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish to huggingface')
26
+ @click.option('--repository', '-r', 'repository', type=str, default=None,
27
+ help='Repository to publish to.', show_default=True)
28
+ @click.option('--revision', '-R', 'revision', type=str, default='main',
29
+ help='Revision for pushing the model.', show_default=True)
30
+ @click.option('--input', '-i', 'video_or_directory', type=str, required=True,
31
+ help='Input videos.', show_default=True)
32
+ @click.option('--name', '-n', 'bangumi_name', type=str, required=True,
33
+ help='Bangumi name', show_default=True)
34
+ @click.option('--min_size', '-s', 'min_size', type=int, default=320,
35
+ help='Min size of image.', show_default=True)
36
+ @click.option('--no_extract', '-E', 'no_extract', is_flag=True, type=bool, default=False,
37
+ help='No extraction from videos.', show_default=True)
38
+ def huggingface(video_or_directory: str, bangumi_name: str,
39
+ repository: str, revision: str = 'main', min_size: int = 320, no_extract: bool = False):
40
+ logging.try_init_root(logging.INFO)
41
+ rname = re.sub(r'[\W_]+', '', unidecode(bangumi_name.lower()))
42
+ repository = repository or f"BangumiBase/{rname}"
43
+ extract_to_huggingface(
44
+ video_or_directory, bangumi_name, repository, revision,
45
+ no_extract=no_extract, min_size=min_size
46
+ )
47
+
48
+
49
+ @cli.command('bgsync', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Sync index on BangumiBase')
50
+ @click.option('--repository', '-r', 'repository', type=str, default='BangumiBase/README',
51
+ help='Repository to publish to.', show_default=True)
52
+ def bgsync(repository: str):
53
+ logging.try_init_root(logging.INFO)
54
+ sync_bangumi_base(repository)
55
+
56
+
57
+ if __name__ == '__main__':
58
+ cli()
cyberharem/dataset/video/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (260 Bytes). View file
 
cyberharem/dataset/video/__pycache__/crawler.cpython-310.pyc ADDED
Binary file (2.51 kB). View file
 
cyberharem/dataset/video/__pycache__/extract.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
cyberharem/dataset/video/bangumibase.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import fnmatch
3
+ import json
4
+ import logging
5
+ import os.path
6
+ import textwrap
7
+ from typing import Tuple, Optional
8
+
9
+ import dateparser
10
+ import pandas as pd
11
+ from hbutils.string import plural_word
12
+ from hbutils.system import TemporaryDirectory
13
+ from huggingface_hub import CommitOperationAdd
14
+ from pyquery import PyQuery as pq
15
+ from tqdm.auto import tqdm
16
+
17
+ from ...utils import get_hf_client, get_hf_fs, get_requests_session, srequest, download_file
18
+
19
+ hf_client = get_hf_client()
20
+ hf_fs = get_hf_fs()
21
+
22
+
23
+ def get_animelist_info(bangumi_name) -> Tuple[Optional[str], Optional[str]]:
24
+ session = get_requests_session()
25
+ resp = srequest(
26
+ session, 'GET', 'https://myanimelist.net/anime.php',
27
+ params={
28
+ 'cat': 'anime',
29
+ 'q': bangumi_name,
30
+ }
31
+ )
32
+ table = pq(resp.text)('.js-block-list.list table')
33
+ for row in table('tr').items():
34
+ bangumi_url = row('td:nth-child(1) a').attr('href')
35
+ if not bangumi_url:
36
+ continue
37
+
38
+ r = srequest(session, 'GET', bangumi_url)
39
+ p = pq(r.text)
40
+ post_url = p("img[itemprop=image]").attr('data-src')
41
+ if bangumi_url and post_url:
42
+ return bangumi_url, post_url
43
+ else:
44
+ return None, None
45
+
46
+
47
+ def sync_bangumi_base(repository: str = 'BangumiBase/README'):
48
+ cb_models = [item.modelId for item in hf_client.list_models(author='CyberHarem')]
49
+ cb_datasets = [item.id for item in hf_client.list_datasets(author='CyberHarem')]
50
+
51
+ with TemporaryDirectory() as td:
52
+ readme_file = os.path.join(td, 'README.md')
53
+ with open(readme_file, 'w') as f:
54
+ rows, total_images, total_clusters, total_animes = [], 0, 0, 0
55
+ for item in tqdm(list(hf_client.list_datasets(author='BangumiBase'))):
56
+ if not hf_fs.exists(f'datasets/{item.id}/meta.json'):
57
+ logging.info(f'No meta information found for {item.id!r}, skipped')
58
+ continue
59
+
60
+ meta = json.loads(hf_fs.read_text(f'datasets/{item.id}/meta.json'))
61
+ bangumi_name = meta['name']
62
+ safe_bangumi_name = bangumi_name.replace('`', ' ').replace('[', '(').replace(']', ')')
63
+ suffix = item.id.split('/')[-1]
64
+ datasets_cnt = len([x for x in cb_datasets if fnmatch.fnmatch(x, f'CyberHarem/*_{suffix}')])
65
+ models_cnt = len([x for x in cb_models if fnmatch.fnmatch(x, f'CyberHarem/*_{suffix}')])
66
+
67
+ page_url, post_url = get_animelist_info(bangumi_name)
68
+ if post_url:
69
+ post_file = os.path.join(td, 'posts', f'{suffix}.jpg')
70
+ os.makedirs(os.path.dirname(post_file), exist_ok=True)
71
+ download_file(post_url, post_file)
72
+ else:
73
+ post_file = None
74
+
75
+ dataset_url = f'https://huggingface.co/datasets/{item.id}'
76
+ post_md = f'![{suffix}]({os.path.relpath(post_file, td)})' if post_file else '(no post)'
77
+ if page_url:
78
+ post_md = f'[{post_md}]({page_url})'
79
+ last_modified = dateparser.parse(item.lastModified) \
80
+ if isinstance(item.lastModified, str) else item.lastModified
81
+ rows.append({
82
+ 'Post': post_md,
83
+ 'Bangumi': f'[{safe_bangumi_name}]({dataset_url})',
84
+ 'Last Modified': last_modified.strftime('%Y-%m-%d %H:%M'),
85
+ 'Images': meta['total'],
86
+ 'Clusters': len([x for x in meta['ids'] if x != -1]),
87
+ 'Datasets': f'[{datasets_cnt}](https://huggingface.co/CyberHarem?'
88
+ f'search_models=_{suffix}&search_datasets=_{suffix})',
89
+ 'Models': f'[{models_cnt}](https://huggingface.co/CyberHarem?'
90
+ f'search_models=_{suffix}&search_datasets=_{suffix})',
91
+ })
92
+ total_images += meta['total']
93
+ total_clusters += len([x for x in meta['ids'] if x != -1])
94
+ total_animes += 1
95
+
96
+ print(textwrap.dedent(f"""
97
+ ---
98
+ title: README
99
+ emoji: 🌖
100
+ colorFrom: green
101
+ colorTo: red
102
+ sdk: static
103
+ pinned: false
104
+ ---
105
+
106
+ ## What is this?
107
+
108
+ This is a data hub utilized by the [DeepGHS team](https://huggingface.co/deepghs) for processing
109
+ anime series (in video format, including TV, OVA, movies, etc.).
110
+
111
+ After downloading anime videos to our GPU cluster, we employ various computer vision algorithms to
112
+ extract frames, crop, and **cluster them based on character features**. These processed frames are
113
+ then uploaded here to reduce the manual sorting effort required for character images.
114
+
115
+ The data in this repository will undergo automated secondary processing to remove noise,
116
+ after which it will be packaged and uploaded to [CyberHarem](https://huggingface.co/CyberHarem).
117
+ It will then be integrated into an automated pipeline for training character LoRA.
118
+
119
+ ## Current Anime Database (constantly updated)
120
+
121
+ Last updated on: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M")},
122
+ contains {plural_word(total_animes, "anime")}, {plural_word(total_images, "image")}
123
+ and {plural_word(total_clusters, "cluster")} in total.
124
+ """).strip(), file=f)
125
+
126
+ rows = sorted(rows, key=lambda x: dateparser.parse(x['Last Modified']), reverse=True)
127
+ df = pd.DataFrame(rows)
128
+ print(df.to_markdown(index=False), file=f)
129
+
130
+ operations = []
131
+ for directory, _, files in os.walk(td):
132
+ for file in files:
133
+ filename = os.path.abspath(os.path.join(directory, file))
134
+ relpath = os.path.relpath(filename, td)
135
+ operations.append(CommitOperationAdd(
136
+ path_in_repo=relpath,
137
+ path_or_fileobj=filename,
138
+ ))
139
+
140
+ current_time = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S %Z')
141
+ commit_message = f'Update lfs images, on {current_time}'
142
+ logging.info(f'Updating lfs images to repository {repository!r} ...')
143
+ hf_client.create_commit(
144
+ repository,
145
+ operations,
146
+ commit_message=commit_message,
147
+ repo_type='space',
148
+ revision='main',
149
+ )
cyberharem/dataset/video/crawler.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import logging
3
+ import os.path
4
+ import re
5
+ import zipfile
6
+ from typing import Optional, Union, List
7
+
8
+ from hbutils.system import TemporaryDirectory
9
+ from huggingface_hub import hf_hub_url
10
+ from unidecode import unidecode
11
+ from waifuc.action import CCIPAction, FilterSimilarAction, RandomFilenameAction
12
+ from waifuc.source import EmptySource, LocalSource
13
+
14
+ from ..crawler import crawl_dataset_to_huggingface
15
+ from ...utils import download_file
16
+
17
+
18
+ def crawl_base_to_huggingface(
19
+ source_repository: str, ch_id: Union[int, List[int]],
20
+ name: str, repository: Optional[str] = None,
21
+ limit: Optional[int] = 200, min_images: int = 10,
22
+ no_r18: bool = False, bg_color: str = 'white', drop_multi: bool = True,
23
+ repo_type: str = 'dataset', revision: str = 'main', path_in_repo: str = '.',
24
+ skip_preprocess: bool = True, parallel: bool = True, standalone_ccip: bool = True,
25
+ keep_cnt_ratio: bool = True,
26
+ ):
27
+ ch_ids = [ch_id] if isinstance(ch_id, int) else ch_id
28
+ source = EmptySource()
29
+ if not repository:
30
+ repository = 'CyberHarem/' + re.sub(r'[\W_]+', '_', unidecode(name.lower())).strip('_').lower() + \
31
+ '_' + source_repository.split('/')[-1]
32
+ logging.info(f'Target repository name {repository!r} will be used.')
33
+ with TemporaryDirectory() as td:
34
+ img_cnts = []
35
+ for cid in ch_ids:
36
+ url = hf_hub_url(source_repository, filename=f'{cid}/dataset.zip', repo_type='dataset')
37
+ os.makedirs(os.path.join(td, str(cid)), exist_ok=True)
38
+ zip_file = os.path.join(td, str(cid), 'dataset.zip')
39
+ download_file(url, zip_file)
40
+
41
+ source_dir = os.path.join(td, str(cid), 'source')
42
+ os.makedirs(source_dir, exist_ok=True)
43
+ with zipfile.ZipFile(zip_file, 'r') as zf:
44
+ zf.extractall(source_dir)
45
+ img_cnts.append(len(glob.glob(os.path.join(source_dir, '*.png'))))
46
+
47
+ total = sum(img_cnts)
48
+ for cid, c_cnt in zip(ch_ids, img_cnts):
49
+ source_dir = os.path.join(td, str(cid), 'source')
50
+ new_source = LocalSource(source_dir, shuffle=True)
51
+ if standalone_ccip:
52
+ new_source = new_source.attach(CCIPAction())
53
+ if keep_cnt_ratio:
54
+ new_source = new_source[:int(round(c_cnt * 1.0 / total * limit))]
55
+
56
+ if parallel:
57
+ source = source | new_source
58
+ else:
59
+ source = source + new_source
60
+ if skip_preprocess:
61
+ source = source.attach(
62
+ FilterSimilarAction('all'),
63
+ RandomFilenameAction(ext='.png'),
64
+ )
65
+
66
+ return crawl_dataset_to_huggingface(
67
+ source, repository, name,
68
+ limit, min_images, no_r18, bg_color, drop_multi, skip_preprocess,
69
+ repo_type, revision, path_in_repo
70
+ )
cyberharem/dataset/video/extract.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os.path
6
+ import random
7
+ import re
8
+ import shutil
9
+ import zipfile
10
+ from contextlib import contextmanager
11
+ from textwrap import dedent
12
+ from typing import Iterator
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ from hbutils.string import plural_word
17
+ from hbutils.system import TemporaryDirectory
18
+ from huggingface_hub import CommitOperationAdd, CommitOperationDelete
19
+ from imgutils.data import load_image
20
+ from imgutils.metrics import ccip_extract_feature, ccip_batch_differences, ccip_default_threshold
21
+ from natsort import natsorted
22
+ from sklearn.cluster import OPTICS
23
+ from tqdm.auto import tqdm
24
+ from waifuc.action import PaddingAlignAction, PersonSplitAction, FaceCountAction, MinSizeFilterAction, \
25
+ NoMonochromeAction, FilterSimilarAction, HeadCountAction, FileOrderAction, TaggingAction, RandomFilenameAction, \
26
+ BackgroundRemovalAction, ModeConvertAction, FileExtAction
27
+ from waifuc.action.filter import MinAreaFilterAction
28
+ from waifuc.export import SaveExporter, TextualInversionExporter
29
+ from waifuc.model import ImageItem
30
+ from waifuc.source import VideoSource, BaseDataSource, LocalSource, EmptySource
31
+
32
+ from ...utils import number_to_tag, get_hf_client, get_hf_fs
33
+
34
+
35
+ class ListFeatImageSource(BaseDataSource):
36
+ def __init__(self, image_files, feats):
37
+ self.image_files = image_files
38
+ self.feats = feats
39
+
40
+ def _iter(self) -> Iterator[ImageItem]:
41
+ for file, feat in zip(self.image_files, self.feats):
42
+ yield ImageItem(load_image(file), {'ccip_feature': feat, 'filename': os.path.basename(file)})
43
+
44
+
45
+ def cluster_from_directory(src_dir, dst_dir, merge_threshold: float = 0.85, clu_min_samples: int = 5,
46
+ extract_from_noise: bool = True):
47
+ image_files = np.array(natsorted(glob.glob(os.path.join(src_dir, '*.png'))))
48
+
49
+ logging.info(f'Extracting feature of {plural_word(len(image_files), "images")} ...')
50
+ images = [ccip_extract_feature(img) for img in tqdm(image_files, desc='Extract features')]
51
+ batch_diff = ccip_batch_differences(images)
52
+ batch_same = batch_diff <= ccip_default_threshold()
53
+
54
+ # clustering
55
+ def _metric(x, y):
56
+ return batch_diff[int(x), int(y)].item()
57
+
58
+ logging.info('Clustering ...')
59
+ samples = np.arange(len(images)).reshape(-1, 1)
60
+ # max_eps, _ = ccip_default_clustering_params(method='optics_best')
61
+ clustering = OPTICS(min_samples=clu_min_samples, metric=_metric).fit(samples)
62
+ labels = clustering.labels_
63
+
64
+ max_clu_id = labels.max().item()
65
+ all_label_ids = np.array([-1, *range(0, max_clu_id + 1)])
66
+ logging.info(f'Cluster complete, with {plural_word(max_clu_id, "cluster")}.')
67
+ label_cnt = {i: (labels == i).sum() for i in all_label_ids if (labels == i).sum() > 0}
68
+ logging.info(f'Current label count: {label_cnt}')
69
+
70
+ if extract_from_noise:
71
+ mask_labels = labels.copy()
72
+ for nid in tqdm(np.where(labels == -1)[0], desc='Matching for noises'):
73
+ avg_dists = np.array([
74
+ batch_diff[nid][labels == i].mean()
75
+ for i in range(0, max_clu_id + 1)
76
+ ])
77
+ r_sames = np.array([
78
+ batch_same[nid][labels == i].mean()
79
+ for i in range(0, max_clu_id + 1)
80
+ ])
81
+ best_id = np.argmin(avg_dists)
82
+ if r_sames[best_id] >= 0.90:
83
+ mask_labels[nid] = best_id
84
+ labels = mask_labels
85
+ logging.info('Noise extracting complete.')
86
+ label_cnt = {i: (labels == i).sum() for i in all_label_ids if (labels == i).sum() > 0}
87
+ logging.info(f'Current label count: {label_cnt}')
88
+
89
+ # trying to merge clusters
90
+ _exist_ids = set(range(0, max_clu_id + 1))
91
+ while True:
92
+ _round_merged = False
93
+ for xi in range(0, max_clu_id + 1):
94
+ if xi not in _exist_ids:
95
+ continue
96
+ for yi in range(xi + 1, max_clu_id + 1):
97
+ if yi not in _exist_ids:
98
+ continue
99
+
100
+ score = (batch_same[labels == xi][:, labels == yi]).mean()
101
+ logging.info(f'Label {xi} and {yi}\'s similarity score: {score}')
102
+ if score >= merge_threshold:
103
+ labels[labels == yi] = xi
104
+ logging.info(f'Merging label {yi} into {xi} ...')
105
+ _exist_ids.remove(yi)
106
+ _round_merged = True
107
+
108
+ if not _round_merged:
109
+ break
110
+
111
+ logging.info(f'Merge complete, remained cluster ids: {sorted(_exist_ids)}.')
112
+ label_cnt = {i: (labels == i).sum() for i in all_label_ids if (labels == i).sum() > 0}
113
+ logging.info(f'Current label count: {label_cnt}')
114
+ ids = []
115
+ for i, clu_id in enumerate(tqdm(sorted(_exist_ids))):
116
+ total = (labels == clu_id).sum()
117
+ logging.info(f'Cluster {clu_id} will be renamed as #{i}, {plural_word(total, "image")} in total.')
118
+ os.makedirs(os.path.join(dst_dir, str(i)), exist_ok=True)
119
+ for imgfile in image_files[labels == clu_id]:
120
+ shutil.copyfile(imgfile, os.path.join(dst_dir, str(i), os.path.basename(imgfile)))
121
+ ids.append(i)
122
+
123
+ n_total = (labels == -1).sum()
124
+ if n_total > 0:
125
+ logging.info(f'Save noise images, {plural_word(n_total, "image")} in total.')
126
+ os.makedirs(os.path.join(dst_dir, str(-1)), exist_ok=True)
127
+ for imgfile in image_files[labels == -1]:
128
+ shutil.copyfile(imgfile, os.path.join(dst_dir, str(-1), os.path.basename(imgfile)))
129
+ ids.append(-1)
130
+
131
+ return ids
132
+
133
+
134
+ def create_project_by_result(bangumi_name: str, ids, clu_dir, dst_dir, preview_count: int = 8, regsize: int = 1000):
135
+ total_image_cnt = 0
136
+ columns = ['#', 'Images', 'Download', *(f'Preview {i}' for i in range(1, preview_count + 1))]
137
+ rows = []
138
+ reg_source = EmptySource()
139
+ for id_ in ids:
140
+ logging.info(f'Packing for #{id_} ...')
141
+ person_dir = os.path.join(dst_dir, str(id_))
142
+ new_reg_source = LocalSource(os.path.join(clu_dir, str(id_)), shuffle=True).attach(
143
+ MinAreaFilterAction(400)
144
+ )
145
+ reg_source = reg_source | new_reg_source
146
+ os.makedirs(person_dir, exist_ok=True)
147
+ with zipfile.ZipFile(os.path.join(person_dir, 'dataset.zip'), 'w') as zf:
148
+ all_person_images = glob.glob(os.path.join(clu_dir, str(id_), '*.png'))
149
+ total_image_cnt += len(all_person_images)
150
+ for file in all_person_images:
151
+ zf.write(file, os.path.basename(file))
152
+
153
+ for i, file in enumerate(random.sample(all_person_images, k=min(len(all_person_images), preview_count)),
154
+ start=1):
155
+ PaddingAlignAction((512, 704))(ImageItem(load_image(file))) \
156
+ .image.save(os.path.join(person_dir, f'preview_{i}.png'))
157
+
158
+ rel_zip_path = os.path.relpath(os.path.join(person_dir, 'dataset.zip'), dst_dir)
159
+ row = [id_ if id_ != -1 else 'noise', len(all_person_images), f'[Download]({rel_zip_path})']
160
+ for i in range(1, preview_count + 1):
161
+ if os.path.exists(os.path.join(person_dir, f'preview_{i}.png')):
162
+ relpath = os.path.relpath(os.path.join(person_dir, f'preview_{i}.png'), dst_dir)
163
+ row.append(f'![preview {i}]({relpath})')
164
+ else:
165
+ row.append('N/A')
166
+ rows.append(row)
167
+
168
+ with TemporaryDirectory() as td:
169
+ logging.info('Creating regular normal dataset ...')
170
+ reg_source.attach(
171
+ TaggingAction(force=False, character_threshold=1.01),
172
+ RandomFilenameAction(),
173
+ )[:regsize].export(TextualInversionExporter(td))
174
+
175
+ logging.info('Packing regular normal dataset ...')
176
+ reg_zip = os.path.join(dst_dir, 'regular', 'normal.zip')
177
+ os.makedirs(os.path.dirname(reg_zip), exist_ok=True)
178
+ with zipfile.ZipFile(reg_zip, 'w') as zf:
179
+ for file in glob.glob(os.path.join(td, '*')):
180
+ zf.write(file, os.path.relpath(file, td))
181
+
182
+ with TemporaryDirectory() as td_nobg:
183
+ logging.info('Creating regular no-background dataset ...')
184
+ LocalSource(td).attach(
185
+ BackgroundRemovalAction(),
186
+ ModeConvertAction('RGB', 'white'),
187
+ TaggingAction(force=True, character_threshold=1.01),
188
+ FileExtAction('.png'),
189
+ ).export(TextualInversionExporter(td_nobg))
190
+
191
+ logging.info('Packing regular no-background dataset ...')
192
+ reg_nobg_zip = os.path.join(dst_dir, 'regular', 'nobg.zip')
193
+ os.makedirs(os.path.dirname(reg_nobg_zip), exist_ok=True)
194
+ with zipfile.ZipFile(reg_nobg_zip, 'w') as zf:
195
+ for file in glob.glob(os.path.join(td_nobg, '*')):
196
+ zf.write(file, os.path.relpath(file, td_nobg))
197
+
198
+ logging.info('Packing all images ...')
199
+ all_zip = os.path.join(dst_dir, 'all.zip')
200
+ with zipfile.ZipFile(all_zip, 'w') as zf:
201
+ for file in glob.glob(os.path.join(clu_dir, '*', '*.png')):
202
+ zf.write(file, os.path.relpath(file, clu_dir))
203
+
204
+ logging.info('Packing raw package ...')
205
+ raw_zip = os.path.join(dst_dir, 'raw.zip')
206
+ with zipfile.ZipFile(raw_zip, 'w') as zf:
207
+ for file in glob.glob(os.path.join(clu_dir, '*', '*.png')):
208
+ zf.write(file, os.path.basename(file))
209
+
210
+ with open(os.path.join(dst_dir, 'meta.json'), 'w', encoding='utf-8') as f:
211
+ json.dump({
212
+ 'name': bangumi_name,
213
+ 'ids': ids,
214
+ 'total': total_image_cnt,
215
+ }, f, indent=4, sort_keys=True, ensure_ascii=False)
216
+
217
+ with open(os.path.join(dst_dir, 'README.md'), 'w', encoding='utf-8') as f:
218
+ print(dedent(f"""
219
+ ---
220
+ license: mit
221
+ tags:
222
+ - art
223
+ size_categories:
224
+ - {number_to_tag(total_image_cnt)}
225
+ ---
226
+ """).strip(), file=f)
227
+ print('', file=f)
228
+
229
+ c_name = ' '.join(map(str.capitalize, re.split(r'\s+', bangumi_name)))
230
+ print(f'# Bangumi Image Base of {c_name}', file=f)
231
+ print('', file=f)
232
+
233
+ print(f'This is the image base of bangumi {bangumi_name}, '
234
+ f'we detected {plural_word(len(ids), "character")}, '
235
+ f'{plural_word(total_image_cnt, "images")} in total. '
236
+ f'The full dataset is [here]({os.path.relpath(all_zip, dst_dir)}).', file=f)
237
+ print('', file=f)
238
+
239
+ print(f'**Please note that these image bases are not guaranteed to be 100% cleaned, '
240
+ f'they may be noisy actual.** If you intend to manually train models using this dataset, '
241
+ f'we recommend performing necessary preprocessing on the downloaded dataset to eliminate '
242
+ f'potential noisy samples (approximately 1% probability).', file=f)
243
+ print('', file=f)
244
+
245
+ print(f'Here is the characters\' preview:', file=f)
246
+ print('', file=f)
247
+
248
+ df = pd.DataFrame(columns=columns, data=rows)
249
+ print(df.to_markdown(index=False), file=f)
250
+ print('', file=f)
251
+
252
+
253
+ @contextmanager
254
+ def extract_from_videos(video_or_directory: str, bangumi_name: str, no_extract: bool = False,
255
+ min_size: int = 320, merge_threshold: float = 0.85, preview_count: int = 8):
256
+ if no_extract:
257
+ source = LocalSource(video_or_directory)
258
+ else:
259
+ if os.path.isfile(video_or_directory):
260
+ source = VideoSource(video_or_directory)
261
+ elif os.path.isdir(video_or_directory):
262
+ source = VideoSource.from_directory(video_or_directory)
263
+ else:
264
+ raise TypeError(f'Unknown video - {video_or_directory!r}.')
265
+
266
+ source = source.attach(
267
+ NoMonochromeAction(),
268
+ PersonSplitAction(keep_original=False, level='n'),
269
+ FaceCountAction(1, level='n'),
270
+ HeadCountAction(1, level='n'),
271
+ MinSizeFilterAction(min_size),
272
+ FilterSimilarAction('all'),
273
+ FileOrderAction(ext='.png'),
274
+ )
275
+
276
+ with TemporaryDirectory() as src_dir:
277
+ logging.info('Extract figures from videos ...')
278
+ source.export(SaveExporter(src_dir, no_meta=True))
279
+
280
+ with TemporaryDirectory() as clu_dir:
281
+ logging.info(f'Clustering from {src_dir!r} to {clu_dir!r} ...')
282
+ ids = cluster_from_directory(src_dir, clu_dir, merge_threshold)
283
+
284
+ with TemporaryDirectory() as dst_dir:
285
+ create_project_by_result(bangumi_name, ids, clu_dir, dst_dir, preview_count)
286
+
287
+ yield dst_dir
288
+
289
+
290
+ def extract_to_huggingface(video_or_directory: str, bangumi_name: str,
291
+ repository: str, revision: str = 'main', no_extract: bool = False,
292
+ min_size: int = 320, merge_threshold: float = 0.85, preview_count: int = 8):
293
+ logging.info(f'Initializing repository {repository!r} ...')
294
+ hf_client = get_hf_client()
295
+ hf_fs = get_hf_fs()
296
+ if not hf_fs.exists(f'datasets/{repository}/.gitattributes'):
297
+ hf_client.create_repo(repo_id=repository, repo_type='dataset', exist_ok=True)
298
+
299
+ _exist_files = [os.path.relpath(file, repository) for file in hf_fs.glob(f'{repository}/**')]
300
+ _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1])
301
+ pre_exist_files = set()
302
+ for i, (file, segments) in enumerate(_exist_ps):
303
+ if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]:
304
+ continue
305
+ if file != '.':
306
+ pre_exist_files.add(file)
307
+
308
+ with extract_from_videos(video_or_directory, bangumi_name, no_extract,
309
+ min_size, merge_threshold, preview_count) as dst_dir:
310
+ operations = []
311
+ for directory, _, files in os.walk(dst_dir):
312
+ for file in files:
313
+ filename = os.path.abspath(os.path.join(dst_dir, directory, file))
314
+ file_in_repo = os.path.relpath(filename, dst_dir)
315
+ operations.append(CommitOperationAdd(
316
+ path_in_repo=file_in_repo,
317
+ path_or_fileobj=filename,
318
+ ))
319
+ if file_in_repo in pre_exist_files:
320
+ pre_exist_files.remove(file_in_repo)
321
+ logging.info(f'Useless files: {sorted(pre_exist_files)} ...')
322
+ for file in sorted(pre_exist_files):
323
+ operations.append(CommitOperationDelete(path_in_repo=file))
324
+
325
+ current_time = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S %Z')
326
+ commit_message = f'Publish {bangumi_name}\'s data, on {current_time}'
327
+ logging.info(f'Publishing {bangumi_name}\'s data to repository {repository!r} ...')
328
+ hf_client.create_commit(
329
+ repository,
330
+ operations,
331
+ commit_message=commit_message,
332
+ repo_type='dataset',
333
+ revision=revision,
334
+ )
cyberharem/infer/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .civitai import publish_samples_to_civitai, civitai_review, civitai_auto_review
2
+ from .draw import draw_images, draw_with_workdir
3
+ from .export import draw_to_directory, draw_with_repo
cyberharem/infer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (395 Bytes). View file
 
cyberharem/infer/__pycache__/civitai.cpython-310.pyc ADDED
Binary file (9.99 kB). View file
 
cyberharem/infer/__pycache__/draw.cpython-310.pyc ADDED
Binary file (7.19 kB). View file
 
cyberharem/infer/__pycache__/export.cpython-310.pyc ADDED
Binary file (3.23 kB). View file
 
cyberharem/infer/civitai.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import textwrap
8
+ from typing import Union, Optional, List
9
+
10
+ import markdown2
11
+ import numpy as np
12
+ from PIL import Image
13
+ from hbutils.string import plural_word
14
+ from hbutils.system import TemporaryDirectory
15
+ from imgutils.data import load_image
16
+ from imgutils.detect import detect_faces
17
+ from imgutils.metrics import ccip_extract_feature, ccip_batch_differences, ccip_default_threshold
18
+ from imgutils.validate import anime_rating_score
19
+ from pycivitai import civitai_find_online
20
+ from pycivitai.client import find_version_id_by_hash
21
+ from tqdm.auto import tqdm
22
+ from waifuc.source import LocalSource
23
+
24
+ from .export import draw_with_repo
25
+ from ..dataset import load_dataset_for_character
26
+ from ..publish.civitai import _tag_decode, try_find_title, try_get_title_from_repo
27
+ from ..utils import srequest, get_hf_fs, load_tags_from_directory
28
+
29
+
30
+ def publish_samples_to_civitai(images_dir, model: Union[int, str], model_version: Optional[str] = None,
31
+ model_creator='narugo1992', safe_only: bool = False,
32
+ extra_tags: Optional[List[str]] = None, post_title: str = None,
33
+ session_repo: str = 'narugo/civitai_session_p1'):
34
+ resource = civitai_find_online(model, model_version, creator=model_creator)
35
+ model_version_id = resource.version_id
36
+ post_title = post_title or f"{resource.model_name} - {resource.version_name} Review"
37
+
38
+ images = []
39
+ for img_file in glob.glob(os.path.join(images_dir, '*.png')):
40
+ img_filename = os.path.basename(img_file)
41
+ img_name = os.path.splitext(img_filename)[0]
42
+ img_info_filename = f'{img_name}_info.txt'
43
+
44
+ local_img_file = os.path.join(images_dir, img_filename)
45
+ local_info_file = os.path.join(images_dir, img_info_filename)
46
+
47
+ info = {}
48
+ with open(local_info_file, 'r', encoding='utf-8') as iif:
49
+ for line in iif:
50
+ line = line.strip()
51
+ if line:
52
+ info_name, info_text = line.split(':', maxsplit=1)
53
+ info[info_name.strip()] = info_text.strip()
54
+
55
+ meta = {
56
+ 'cfgScale': int(round(float(info.get('Guidance Scale')))),
57
+ 'negativePrompt': info.get('Neg Prompt'),
58
+ 'prompt': info.get('Prompt'),
59
+ 'sampler': info.get('Sample Method', "Euler a"),
60
+ 'seed': int(info.get('Seed')),
61
+ 'steps': int(info.get('Infer Steps')),
62
+ 'Size': f"{info['Width']}x{info['Height']}",
63
+ }
64
+ if info.get('Clip Skip'):
65
+ meta['clipSkip'] = int(info['Clip Skip'])
66
+ if info.get('Model'):
67
+ meta['Model'] = info['Model']
68
+ pil_img_file = Image.open(local_img_file)
69
+ if pil_img_file.info.get('parameters'):
70
+ png_info_text = pil_img_file.info['parameters']
71
+ find_hash = re.findall(r'Model hash:\s*([a-zA-Z\d]+)', png_info_text, re.IGNORECASE)
72
+ if find_hash:
73
+ model_hash = find_hash[0].lower()
74
+ meta['hashes'] = {"model": model_hash}
75
+ meta["resources"] = [
76
+ {
77
+ "hash": model_hash,
78
+ "name": info['Model'],
79
+ "type": "model"
80
+ }
81
+ ]
82
+ meta["Model hash"] = model_hash
83
+
84
+ nsfw = (info.get('Safe For Word', info.get('Safe For Work')) or '').lower() != 'yes'
85
+
86
+ rating_score = anime_rating_score(local_img_file)
87
+ safe_v = int(round(rating_score['safe'] * 10))
88
+ safe_r15 = int(round(rating_score['r15'] * 10))
89
+ safe_r18 = int(round(rating_score['r18'] * 10))
90
+ faces = detect_faces(local_img_file)
91
+ if faces:
92
+ (x0, y0, x1, y1), _, _ = faces[0]
93
+ width, height = load_image(local_img_file).size
94
+ face_area = abs((x1 - x0) * (y1 - y0))
95
+ face_ratio = face_area * 1.0 / (width * height)
96
+ face_ratio = int(round(face_ratio * 50))
97
+ else:
98
+ continue
99
+
100
+ images.append((
101
+ (-safe_v, -safe_r15, -safe_r18) if safe_only else (0,),
102
+ -face_ratio,
103
+ 1 if nsfw else 0,
104
+ 0 if img_name.startswith('pattern_') else 1,
105
+ img_name,
106
+ (local_img_file, img_filename, meta)
107
+ ))
108
+
109
+ images = [item[-1] for item in sorted(images)]
110
+
111
+ from ..publish.civitai import civitai_upload_images, get_civitai_session, parse_publish_at
112
+
113
+ def _custom_pc_func(mvid):
114
+ return {
115
+ "json": {
116
+ "modelVersionId": mvid,
117
+ "title": post_title,
118
+ "tag": None,
119
+ "authed": True,
120
+ },
121
+ "meta": {
122
+ "values": {
123
+ "tag": ["undefined"]
124
+ }
125
+ }
126
+ }
127
+
128
+ session = get_civitai_session(session_repo)
129
+ post_id = civitai_upload_images(
130
+ model_version_id, images,
131
+ tags=[*resource.tags, *extra_tags],
132
+ model_id=resource.model_id,
133
+ pc_func=_custom_pc_func,
134
+ session=session,
135
+ )
136
+
137
+ logging.info(f'Publishing post {post_id!r} ...')
138
+ resp = srequest(
139
+ session, 'POST', 'https://civitai.com/api/trpc/post.update',
140
+ json={
141
+ "json": {
142
+ "id": post_id,
143
+ "publishedAt": parse_publish_at('now'),
144
+ "authed": True,
145
+ },
146
+ "meta": {
147
+ "values": {
148
+ "publishedAt": ["Date"]
149
+ }
150
+ }
151
+ },
152
+ headers={'Referer': f'https://civitai.com/models/{resource.model_id}/wizard?step=4'},
153
+ )
154
+ resp.raise_for_status()
155
+
156
+ return images
157
+
158
+
159
+ def civitai_review(model: Union[int, str], model_version: Optional[str] = None,
160
+ model_creator='narugo1992', rating: int = 5, description_md: Optional[str] = None,
161
+ session_repo: str = 'narugo/civitai_session_p1'):
162
+ resource = civitai_find_online(model, model_version, creator=model_creator)
163
+
164
+ from ..publish.civitai import get_civitai_session
165
+ session = get_civitai_session(session_repo)
166
+
167
+ logging.info(f'Try find exist review of model version #{resource.version_id} ...')
168
+ _err = None
169
+ try: # Add this shit for the 500 of this API (2023-09-14)
170
+ resp = srequest(
171
+ session, 'GET', 'https://civitai.com/api/trpc/resourceReview.getUserResourceReview',
172
+ params={'input': json.dumps({"json": {"modelVersionId": resource.version_id, "authed": True}})},
173
+ headers={
174
+ 'Referer': f'https://civitai.com/posts/create?modelId={resource.model_id}&'
175
+ f'modelVersionId={resource.version_id}&'
176
+ f'returnUrl=/models/{resource.model_id}?'
177
+ f'modelVersionId={resource.version_id}reviewing=true'
178
+ },
179
+ raise_for_status=False
180
+ )
181
+ except AssertionError:
182
+ _err = True
183
+ resp = None
184
+
185
+ if _err or resp.status_code == 404:
186
+ logging.info(f'Creating review for #{resource.version_id} ...')
187
+ resp = srequest(
188
+ session, 'POST', 'https://civitai.com/api/trpc/resourceReview.create',
189
+ json={
190
+ "json": {
191
+ "modelVersionId": resource.version_id,
192
+ "modelId": resource.model_id,
193
+ "rating": rating,
194
+ "authed": True,
195
+ }
196
+ },
197
+ headers={'Referer': f'https://civitai.com/models/{resource.model_id}/wizard?step=4'}
198
+ )
199
+ resp.raise_for_status()
200
+ else:
201
+ if resp is not None:
202
+ resp.raise_for_status()
203
+ review_id = resp.json()['result']['data']['json']['id']
204
+
205
+ logging.info(f'Updating review #{review_id}\'s rating ...')
206
+ resp = srequest(
207
+ session, 'POST', 'https://civitai.com/api/trpc/resourceReview.update',
208
+ json={
209
+ "json": {
210
+ "id": review_id,
211
+ "rating": rating,
212
+ "details": None,
213
+ "authed": True,
214
+ },
215
+ "meta": {"values": {"details": ["undefined"]}}
216
+ },
217
+ headers={'Referer': f'https://civitai.com/models/{resource.model_id}/wizard?step=4'}
218
+ )
219
+ resp.raise_for_status()
220
+
221
+ if description_md:
222
+ logging.info(f'Updating review #{review_id}\'s description ...')
223
+ resp = srequest(
224
+ session, 'POST', 'https://civitai.com/api/trpc/resourceReview.update',
225
+ json={
226
+ "json": {
227
+ "id": review_id,
228
+ "details": markdown2.markdown(textwrap.dedent(description_md)),
229
+ 'rating': None,
230
+ "authed": True,
231
+ },
232
+ "meta": {"values": {"rating": ["undefined"]}}
233
+ },
234
+ headers={'Referer': f'https://civitai.com/models/{resource.model_id}/wizard?step=4'}
235
+ )
236
+ resp.raise_for_status()
237
+
238
+
239
+ _BASE_MODEL_LIST = [
240
+ 'AIARTCHAN/anidosmixV2',
241
+ # 'stablediffusionapi/anything-v5',
242
+ # 'Lykon/DreamShaper',
243
+ 'Meina/Unreal_V4.1',
244
+ 'digiplay/majicMIX_realistic_v6',
245
+ 'jzli/XXMix_9realistic-v4',
246
+ 'stablediffusionapi/abyssorangemix2nsfw',
247
+ 'AIARTCHAN/expmixLine_v2',
248
+ # 'Yntec/CuteYuki2',
249
+ 'stablediffusionapi/counterfeit-v30',
250
+ 'stablediffusionapi/flat-2d-animerge',
251
+ 'redstonehero/cetusmix_v4',
252
+ # 'KBlueLeaf/kohaku-v4-rev1.2',
253
+ # 'stablediffusionapi/night-sky-yozora-sty',
254
+ 'Meina/MeinaHentai_V4',
255
+ # 'Meina/MeinaPastel_V6',
256
+ ]
257
+
258
+
259
+ def civitai_auto_review(repository: str, model: Optional[Union[int, str]] = None,
260
+ model_version: Optional[str] = None,
261
+ model_creator='narugo1992', step: Optional[int] = None,
262
+ base_models: Optional[List[str]] = None,
263
+ rating: Optional[int] = 5, description_md: Optional[str] = None,
264
+ session_repo: str = 'narugo/civitai_session_p1'):
265
+ game_name = repository.split('/')[-1].split('_')[-1]
266
+ char_name = ' '.join(repository.split('/')[-1].split('_')[:-1])
267
+ model = model or try_find_title(char_name, game_name) or \
268
+ try_get_title_from_repo(repository) or repository.split('/')[-1]
269
+ logging.info(f'Model name on civitai: {model!r}')
270
+
271
+ from ..publish.export import KNOWN_MODEL_HASHES
272
+
273
+ hf_fs = get_hf_fs()
274
+ model_info = json.loads(hf_fs.read_text(f'{repository}/meta.json'))
275
+ dataset_info = model_info['dataset']
276
+
277
+ # load dataset
278
+ ds_size = (384, 512) if not dataset_info or not dataset_info['type'] else dataset_info['type']
279
+ with load_dataset_for_character(repository, size=ds_size) as (_, ds_dir):
280
+ core_tags, _ = load_tags_from_directory(ds_dir)
281
+
282
+ all_tags = [
283
+ game_name, f"{game_name} {char_name}", char_name,
284
+ 'female', 'girl', 'character', 'fully-automated', 'random prompt', 'random seed',
285
+ *map(_tag_decode, core_tags.keys()),
286
+ ]
287
+ ds_source = LocalSource(ds_dir)
288
+ ds_feats = []
289
+ for item in tqdm(list(ds_source), desc='Extract Dataset Feature'):
290
+ ds_feats.append(ccip_extract_feature(item.image))
291
+
292
+ all_feats = []
293
+ model_results = []
294
+ for base_model in (base_models or _BASE_MODEL_LIST):
295
+ logging.info(f'Reviewing with {base_model!r} ...')
296
+ with TemporaryDirectory() as td:
297
+ if KNOWN_MODEL_HASHES.get(base_model):
298
+ bm_id, bm_version_id, _ = find_version_id_by_hash(KNOWN_MODEL_HASHES[base_model])
299
+ resource = civitai_find_online(bm_id, bm_version_id)
300
+ m_name = f'{resource.model_name} - {resource.version_name}'
301
+ m_url = f'https://civitai.com/models/{resource.model_id}?modelVersionId={resource.version_id}'
302
+ else:
303
+ m_name = base_model
304
+ m_url = None
305
+
306
+ draw_with_repo(repository, td, step=step, pretrained_model=base_model)
307
+ images = publish_samples_to_civitai(
308
+ td, model, model_version,
309
+ model_creator=model_creator,
310
+ extra_tags=all_tags,
311
+ post_title=f"AI Review (Base Model: {m_name})",
312
+ session_repo=session_repo
313
+ )
314
+
315
+ images_count = len(images)
316
+ gp_feats = []
317
+ for local_imgfile, _, _ in tqdm(images, desc='Extract Images Feature'):
318
+ gp_feats.append(ccip_extract_feature(local_imgfile))
319
+ all_feats.extend(gp_feats)
320
+
321
+ gp_diffs = ccip_batch_differences([*gp_feats, *ds_feats])[:len(gp_feats), len(gp_feats):]
322
+ gp_batch = gp_diffs <= ccip_default_threshold()
323
+ scores = gp_batch.mean(axis=1)
324
+ losses = gp_diffs.mean(axis=1)
325
+
326
+ ret = {
327
+ 'model_name': m_name,
328
+ 'model_homepage': m_url,
329
+ 'images': images_count,
330
+ 'mean_score': scores.mean().item(),
331
+ 'median_score': np.median(scores).item(),
332
+ 'mean_loss': losses.mean().item(),
333
+ 'median_loss': np.median(losses).item(),
334
+ }
335
+ logging.info(f'Result of model: {ret!r}')
336
+ model_results.append(ret)
337
+
338
+ all_diffs = ccip_batch_differences([*all_feats, *ds_feats])[:len(all_feats), len(all_feats):]
339
+ all_batch = all_diffs <= ccip_default_threshold()
340
+ all_scores = all_batch.mean(axis=1)
341
+ all_losses = all_diffs.mean(axis=1)
342
+ all_mean_score = all_scores.mean().item()
343
+ all_median_score = np.median(all_scores).item()
344
+ all_mean_loss = all_losses.mean().item()
345
+ all_median_loss = np.median(all_losses).item()
346
+
347
+ if rating is not None:
348
+ logging.info('Making review ...')
349
+ with io.StringIO() as ds:
350
+ print('Tested on the following models:', file=ds)
351
+ print('', file=ds)
352
+
353
+ all_total_images = 0
354
+ for mr in model_results:
355
+ if mr['model_homepage']:
356
+ mx = f'[{mr["model_name"]}]({mr["model_homepage"]})'
357
+ else:
358
+ mx = mr['model_name']
359
+
360
+ all_total_images += mr['images']
361
+ print(
362
+ f'When using model {mx}, {plural_word(mr["images"], "image")} in total, '
363
+ f'recognition score (mean/median): {mr["mean_score"]:.3f}/{mr["median_score"]:.3f}, '
364
+ f'character image loss (mean/median): {mr["mean_loss"]:.4f}/{mr["median_loss"]:.4f}.',
365
+ file=ds
366
+ )
367
+ print('', file=ds)
368
+
369
+ print(
370
+ f'Overall, {plural_word(all_total_images, "image")} in total, '
371
+ f'recognition score (mean/median): {all_mean_score:.3f}/{all_median_score:.3f}, '
372
+ f'character image loss (mean/median): {all_mean_loss:.4f}/{all_median_loss:.4f}.',
373
+ file=ds
374
+ )
375
+ print('', file=ds)
376
+
377
+ description_md = description_md or ds.getvalue()
378
+
379
+ try:
380
+ civitai_review(model, model_version, model_creator, rating, description_md, session_repo)
381
+ except:
382
+ print('This is the description md:')
383
+ print(description_md)
384
+ raise
cyberharem/infer/draw.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ import shutil
7
+ from dataclasses import dataclass
8
+ from textwrap import dedent
9
+ from typing import List, Union, Optional
10
+
11
+ import yaml
12
+ from PIL.PngImagePlugin import PngInfo
13
+ from imgutils.detect import detect_censors
14
+
15
+ try:
16
+ from yaml import CLoader as Loader, CDumper as Dumper
17
+ except ImportError:
18
+ from yaml import Loader, Dumper
19
+ from PIL import Image
20
+ from hbutils.system import TemporaryDirectory
21
+ from hcpdiff import Visualizer
22
+ from hcpdiff.utils import load_config_with_cli
23
+
24
+ from ..utils import data_to_cli_args
25
+
26
+ _DEFAULT_INFER_CFG_FILE = 'cfgs/infer/text2img_anime_lora.yaml'
27
+ _DEFAULT_INFER_MODEL = 'LittleApple-fp16/SpiritForeseerMix'
28
+
29
+
30
+ def sample_method_to_config(method):
31
+ if method == 'DPM++ SDE Karras':
32
+ return {
33
+ '_target_': 'diffusers.DPMSolverSDEScheduler',
34
+ 'beta_start': 0.00085,
35
+ 'beta_end': 0.012,
36
+ 'beta_schedule': 'scaled_linear',
37
+ 'use_karras_sigmas': True,
38
+ }
39
+ elif method == 'DPM++ 2M Karras':
40
+ return {
41
+ '_target_': 'diffusers.DPMSolverMultistepScheduler',
42
+ 'beta_start': 0.00085,
43
+ 'beta_end': 0.012,
44
+ 'algorithm_type': 'dpmsolver++',
45
+ 'beta_schedule': 'scaled_linear',
46
+ 'use_karras_sigmas': True
47
+ }
48
+ elif method == 'Euler a':
49
+ return {
50
+ '_target_': 'diffusers.EulerAncestralDiscreteScheduler',
51
+ 'beta_start': 0.00085,
52
+ 'beta_end': 0.012,
53
+ 'beta_schedule': 'scaled_linear',
54
+ }
55
+ else:
56
+ raise ValueError(f'Unknown sample method - {method!r}.')
57
+
58
+
59
+ def draw_images(
60
+ workdir: str, prompts: Union[str, List[str]], neg_prompts: Union[str, List[str]] = None,
61
+ seeds: Union[int, List[str]] = None, emb_name: str = None, save_cfg: bool = True,
62
+ model_steps: int = 1000, n_repeats: int = 2, pretrained_model: str = _DEFAULT_INFER_MODEL,
63
+ width: int = 512, height: int = 768, gscale: float = 8, infer_steps: int = 30,
64
+ lora_alpha: float = 0.85, output_dir: str = 'output', cfg_file: str = _DEFAULT_INFER_CFG_FILE,
65
+ clip_skip: int = 2, sample_method: str = 'DPM++ 2M Karras',
66
+ ):
67
+ emb_name = emb_name or os.path.basename(workdir)
68
+ with TemporaryDirectory() as emb_dir:
69
+ src_pt_files = glob.glob(os.path.join(workdir, 'ckpts', f'*-{model_steps}.pt'))
70
+ if not src_pt_files:
71
+ raise FileNotFoundError(f'Embedding not found for step {model_steps}.')
72
+
73
+ src_pt_file = src_pt_files[0]
74
+ shutil.copyfile(src_pt_file, os.path.join(emb_dir, f'{emb_name}.pt'))
75
+
76
+ cli_args = data_to_cli_args({
77
+ 'pretrained_model': pretrained_model,
78
+ 'N_repeats': n_repeats,
79
+
80
+ 'vae_optimize': {
81
+ 'tiling': False,
82
+ },
83
+
84
+ 'clip_skip': clip_skip - 1,
85
+
86
+ 'bs': 1,
87
+ 'num': 1,
88
+
89
+ 'infer_args': {
90
+ 'width': width,
91
+ 'height': height,
92
+ 'guidance_scale': gscale,
93
+ 'num_inference_steps': infer_steps,
94
+ },
95
+
96
+ 'exp_dir': workdir,
97
+ 'model_steps': model_steps,
98
+ 'emb_dir': emb_dir,
99
+ 'output_dir': output_dir,
100
+
101
+ 'merge': {
102
+ 'alpha': lora_alpha,
103
+ },
104
+
105
+ 'new_components': {
106
+ 'scheduler': sample_method_to_config(sample_method),
107
+ 'vae': {
108
+ '_target_': 'diffusers.AutoencoderKL.from_pretrained',
109
+ 'pretrained_model_name_or_path': 'deepghs/animefull-latest', # path to vae model
110
+ 'subfolder': 'vae',
111
+ }
112
+ }
113
+ })
114
+ logging.info(f'Infer based on {cfg_file!r}, with {cli_args!r}')
115
+ cfgs = load_config_with_cli(cfg_file, args_list=cli_args) # skip --cfg
116
+
117
+ N = None
118
+ if isinstance(prompts, list):
119
+ N = len(prompts)
120
+ if isinstance(neg_prompts, list):
121
+ if N is not None and len(neg_prompts) != N:
122
+ raise ValueError(f'Number of prompts ({len(prompts)}) and neg_prompts ({len(neg_prompts)}) not match.')
123
+ N = len(neg_prompts)
124
+ if isinstance(seeds, list):
125
+ if N is not None and len(seeds) != N:
126
+ raise ValueError(f'Number of both prompts ({N}) and seed ({len(seeds)}) not match.')
127
+ N = len(seeds)
128
+
129
+ if N is None:
130
+ N = 1
131
+ if not isinstance(prompts, list):
132
+ prompts = [prompts] * N
133
+ if not isinstance(neg_prompts, list):
134
+ neg_prompts = [neg_prompts] * N
135
+ if not isinstance(seeds, list):
136
+ seeds = [seeds] * N
137
+
138
+ viser = Visualizer(cfgs)
139
+ viser.vis_to_dir(prompt=prompts, negative_prompt=neg_prompts, seeds=seeds,
140
+ save_cfg=save_cfg, **cfgs.infer_args)
141
+
142
+
143
+ @dataclass
144
+ class Drawing:
145
+ name: str
146
+ prompt: str
147
+ neg_prompt: str
148
+ seed: int
149
+ sfw: bool
150
+ width: int
151
+ height: int
152
+ gscale: float
153
+ steps: int
154
+ image: Image.Image
155
+ sample_method: str
156
+ clip_skip: int
157
+ model: str
158
+ model_hash: Optional[str] = None
159
+
160
+ @property
161
+ def preview_info(self):
162
+ return dedent(f"""
163
+ Prompt: {self.prompt}
164
+ Neg Prompt: {self.neg_prompt}
165
+ Width: {self.width}
166
+ Height: {self.height}
167
+ Guidance Scale: {self.gscale}
168
+ Sample Method: {self.sample_method}
169
+ Infer Steps: {self.steps}
170
+ Clip Skip: {self.clip_skip}
171
+ Seed: {self.seed}
172
+ Model: {self.model}
173
+ Safe For Work: {"yes" if self.sfw else "no"}
174
+ """).lstrip()
175
+
176
+ @property
177
+ def pnginfo_text(self) -> str:
178
+ with io.StringIO() as sf:
179
+ print(self.prompt, file=sf)
180
+ print(f'Negative prompt: {self.neg_prompt}', file=sf)
181
+
182
+ if self.model_hash:
183
+ print(f'Steps: {self.steps}, Sampler: {self.sample_method}, '
184
+ f'CFG scale: {self.gscale}, Seed: {self.seed}, Size: {self.width}x{self.height}, '
185
+ f'Model hash: {self.model_hash.lower()}, Model: {self.model}, '
186
+ f'Clip skip: {self.clip_skip}', file=sf)
187
+ else:
188
+ print(f'Steps: {self.steps}, Sampler: {self.sample_method}, '
189
+ f'CFG scale: {self.gscale}, Seed: {self.seed}, Size: {self.width}x{self.height}, '
190
+ f'Model: {self.model}, '
191
+ f'Clip skip: {self.clip_skip}', file=sf)
192
+
193
+ return sf.getvalue()
194
+
195
+ @property
196
+ def pnginfo(self) -> PngInfo:
197
+ info = PngInfo()
198
+ info.add_text('parameters', self.pnginfo_text)
199
+ return info
200
+
201
+
202
+ _N_MAX_DRAW = 20
203
+
204
+
205
+ def draw_with_workdir(
206
+ workdir: str, emb_name: str = None, save_cfg: bool = True,
207
+ model_steps: int = 1000, n_repeats: int = 2, pretrained_model: str = _DEFAULT_INFER_MODEL,
208
+ width: int = 512, height: int = 768, gscale: float = 8, infer_steps: int = 30,
209
+ lora_alpha: float = 0.85, output_dir: str = None, cfg_file: str = _DEFAULT_INFER_CFG_FILE,
210
+ clip_skip: int = 2, sample_method: str = 'DPM++ 2M Karras', model_hash: Optional[str] = None,
211
+ ):
212
+ n_pnames, n_prompts, n_neg_prompts, n_seeds, n_sfws = [], [], [], [], []
213
+ for jfile in glob.glob(os.path.join(workdir, 'rtags', '*.json')):
214
+ with open(jfile, 'r', encoding='utf-8') as f:
215
+ data = json.load(f)
216
+ n_pnames.append(data['name'])
217
+ n_prompts.append(data['prompt'])
218
+ n_neg_prompts.append(data['neg_prompt'])
219
+ n_seeds.append(data['seed'])
220
+ n_sfws.append(data['sfw'])
221
+
222
+ n_total = len(n_pnames)
223
+ retval = []
224
+ for x in range(0, n_total, _N_MAX_DRAW):
225
+ pnames, prompts, neg_prompts, seeds, sfws = \
226
+ n_pnames[x:x + _N_MAX_DRAW], n_prompts[x:x + _N_MAX_DRAW], n_neg_prompts[x:x + _N_MAX_DRAW], \
227
+ n_seeds[x:x + _N_MAX_DRAW], n_sfws[x:x + _N_MAX_DRAW]
228
+
229
+ with TemporaryDirectory() as td:
230
+ _tmp_output_dir = output_dir or td
231
+ draw_images(
232
+ workdir, prompts, neg_prompts, seeds,
233
+ emb_name, save_cfg, model_steps, n_repeats, pretrained_model,
234
+ width, height, gscale, infer_steps, lora_alpha, _tmp_output_dir, cfg_file,
235
+ clip_skip, sample_method,
236
+ )
237
+
238
+ for i, (pname, prompt, neg_prompt, seed, sfw) in \
239
+ enumerate(zip(pnames, prompts, neg_prompts, seeds, sfws), start=1):
240
+ img_file = glob.glob(os.path.join(_tmp_output_dir, f'{i}-*.png'))[0]
241
+ yaml_file = glob.glob(os.path.join(_tmp_output_dir, f'{i}-*.yaml'))[0]
242
+ with open(yaml_file, 'r', encoding='utf-8') as f:
243
+ seed = yaml.load(f, Loader)['seed']
244
+
245
+ img = Image.open(img_file)
246
+ img.load()
247
+
248
+ retval.append(Drawing(
249
+ pname, prompt, neg_prompt, seed,
250
+ sfw=sfw and len(detect_censors(img, conf_threshold=0.45)) == 0,
251
+ width=width, height=height, gscale=gscale, steps=infer_steps,
252
+ image=img, sample_method=sample_method, clip_skip=clip_skip,
253
+ model=pretrained_model, model_hash=model_hash,
254
+ ))
255
+
256
+ return retval
cyberharem/infer/export.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from typing import Optional
5
+
6
+ from hbutils.system import TemporaryDirectory
7
+ from huggingface_hub import hf_hub_url
8
+ from tqdm.auto import tqdm
9
+
10
+ from .draw import _DEFAULT_INFER_MODEL, draw_with_workdir
11
+ from ..dataset import save_recommended_tags
12
+ from ..utils import get_hf_fs, download_file
13
+
14
+
15
+ def draw_to_directory(workdir: str, export_dir: str, step: int, n_repeats: int = 2,
16
+ pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
17
+ image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
18
+ lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
19
+ model_hash: Optional[str] = None):
20
+ from ..publish.export import KNOWN_MODEL_HASHES
21
+ model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model)
22
+ os.makedirs(export_dir, exist_ok=True)
23
+
24
+ while True:
25
+ try:
26
+ drawings = draw_with_workdir(
27
+ workdir, model_steps=step, n_repeats=n_repeats,
28
+ pretrained_model=pretrained_model,
29
+ width=image_width, height=image_height, infer_steps=infer_steps,
30
+ lora_alpha=lora_alpha, clip_skip=clip_skip, sample_method=sample_method,
31
+ model_hash=model_hash,
32
+ )
33
+ except RuntimeError:
34
+ n_repeats += 1
35
+ else:
36
+ break
37
+
38
+ all_image_files = []
39
+ for draw in drawings:
40
+ img_file = os.path.join(export_dir, f'{draw.name}.png')
41
+ draw.image.save(img_file, pnginfo=draw.pnginfo)
42
+ all_image_files.append(img_file)
43
+
44
+ with open(os.path.join(export_dir, f'{draw.name}_info.txt'), 'w', encoding='utf-8') as f:
45
+ print(draw.preview_info, file=f)
46
+
47
+
48
+ def draw_with_repo(repository: str, export_dir: str, step: Optional[int] = None, n_repeats: int = 2,
49
+ pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
50
+ image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
51
+ lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
52
+ model_hash: Optional[str] = None):
53
+ from ..publish import find_steps_in_workdir
54
+
55
+ hf_fs = get_hf_fs()
56
+ if not hf_fs.exists(f'{repository}/meta.json'):
57
+ raise ValueError(f'Invalid repository or no model found - {repository!r}.')
58
+
59
+ logging.info(f'Model repository {repository!r} found.')
60
+ meta = json.loads(hf_fs.read_text(f'{repository}/meta.json'))
61
+ step = step or meta['best_step']
62
+ logging.info(f'Using step {step} ...')
63
+
64
+ with TemporaryDirectory() as workdir:
65
+ logging.info('Downloading models ...')
66
+ for f in tqdm(hf_fs.glob(f'{repository}/{step}/raw/*')):
67
+ rel_file = os.path.relpath(f, repository)
68
+ local_file = os.path.join(workdir, 'ckpts', os.path.basename(rel_file))
69
+ if os.path.dirname(local_file):
70
+ os.makedirs(os.path.dirname(local_file), exist_ok=True)
71
+ download_file(
72
+ hf_hub_url(repository, filename=rel_file),
73
+ local_file
74
+ )
75
+
76
+ logging.info(f'Regenerating tags for {workdir!r} ...')
77
+ pt_name, _ = find_steps_in_workdir(workdir)
78
+ game_name = pt_name.split('_')[-1]
79
+ name = '_'.join(pt_name.split('_')[:-1])
80
+
81
+ from gchar.games.dispatch.access import GAME_CHARS
82
+ if game_name in GAME_CHARS:
83
+ ch_cls = GAME_CHARS[game_name]
84
+ ch = ch_cls.get(name)
85
+ else:
86
+ ch = None
87
+
88
+ if ch is None:
89
+ source = repository
90
+ else:
91
+ source = ch
92
+
93
+ logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
94
+ save_recommended_tags(source, name=pt_name, workdir=workdir, ds_size=meta["dataset"]['type'])
95
+
96
+ logging.info('Drawing ...')
97
+ draw_to_directory(
98
+ workdir, export_dir, step,
99
+ n_repeats, pretrained_model, clip_skip, image_width, image_height, infer_steps,
100
+ lora_alpha, sample_method, model_hash
101
+ )
cyberharem/list.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fnmatch
2
+ from functools import partial
3
+
4
+ import click
5
+ from gchar.generic import import_generic
6
+ from gchar.utils import GLOBAL_CONTEXT_SETTINGS
7
+ from gchar.utils import print_version as _origin_print_version
8
+
9
+ from cyberharem.utils import get_hf_client
10
+
11
+ print_version = partial(_origin_print_version, 'cyberharem.train')
12
+
13
+ import_generic()
14
+
15
+
16
+ @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish trained models')
17
+ @click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True)
18
+ def cli():
19
+ pass # pragma: no cover
20
+
21
+
22
+ @cli.command('models', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='List models')
23
+ @click.option('-p', '--pattern', 'pattern', type=str, default='*',
24
+ help='Pattern of models.', show_default=True)
25
+ def models(pattern):
26
+ hf_client = get_hf_client()
27
+ for model in hf_client.list_models(author='CyberHarem'):
28
+ if fnmatch.fnmatch(model.modelId, pattern):
29
+ print(model.modelId)
30
+
31
+
32
+ @cli.command('datasets', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='List datasets')
33
+ @click.option('-p', '--pattern', 'pattern', type=str, default='*',
34
+ help='Pattern of models.', show_default=True)
35
+ def datasets(pattern):
36
+ hf_client = get_hf_client()
37
+ for ds in hf_client.list_datasets(author='CyberHarem'):
38
+ if fnmatch.fnmatch(ds.id, pattern):
39
+ print(ds.id)
40
+
41
+
42
+ if __name__ == '__main__':
43
+ cli()
cyberharem/publish/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .civitai import civitai_query_model_tags, civitai_upsert_model, civitai_query_vae_models, civitai_create_version, \
2
+ civitai_upload_models, civitai_get_model_info, civitai_upload_images, civiti_publish, civitai_publish_from_hf
3
+ from .convert import convert_to_webui_lora
4
+ from .export import export_workdir
5
+ from .huggingface import deploy_to_huggingface
6
+ from .steps import find_steps_in_workdir
cyberharem/publish/__main__.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import partial
3
+
4
+ import click
5
+ from ditk import logging
6
+ from gchar.generic import import_generic
7
+ from gchar.utils import GLOBAL_CONTEXT_SETTINGS
8
+ from gchar.utils import print_version as _origin_print_version
9
+ from hbutils.system import TemporaryDirectory
10
+ from huggingface_hub import hf_hub_url
11
+ from tqdm.auto import tqdm
12
+
13
+ from cyberharem.dataset import save_recommended_tags
14
+ from cyberharem.publish import find_steps_in_workdir
15
+ from cyberharem.utils import get_hf_fs, download_file
16
+ from .civitai import civitai_publish_from_hf
17
+ from .huggingface import deploy_to_huggingface
18
+ from ..infer.draw import _DEFAULT_INFER_MODEL
19
+
20
+ import_generic()
21
+
22
+ print_version = partial(_origin_print_version, 'cyberharem')
23
+
24
+
25
+ @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish trained models')
26
+ @click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True)
27
+ def cli():
28
+ pass # pragma: no cover
29
+
30
+
31
+ @cli.command('huggingface', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish to huggingface')
32
+ @click.option('-w', '--workdir', 'workdir', type=click.Path(file_okay=False, exists=True), required=True,
33
+ help='Work directory for experiment.', show_default=True)
34
+ @click.option('--repository', '-r', 'repository', type=str, default=None,
35
+ help='Repository to publish to.', show_default=True)
36
+ @click.option('--revision', '-R', 'revision', type=str, default='main',
37
+ help='Revision for pushing the model.', show_default=True)
38
+ @click.option('-n', '--n_repeats', 'n_repeats', type=int, default=3,
39
+ help='N Repeats for text encoder', show_default=True)
40
+ @click.option('-m', '--pretrained_model', 'pretrained_model', type=str, default=_DEFAULT_INFER_MODEL,
41
+ help='Pretrained model for preview drawing.', show_default=True)
42
+ @click.option('--width', 'width', type=int, default=512,
43
+ help='Width of images.', show_default=True)
44
+ @click.option('--height', 'height', type=int, default=768,
45
+ help='Height of images.', show_default=True)
46
+ @click.option('-C', '--clip_skip', 'clip_skip', type=int, default=2,
47
+ help='Clip skip.', show_default=True)
48
+ @click.option('-S', '--infer_steps', 'infer_steps', type=int, default=30,
49
+ help='Steps of inference.', show_default=True)
50
+ def huggingface(workdir: str, repository, revision, n_repeats, pretrained_model,
51
+ width, height, clip_skip, infer_steps):
52
+ logging.try_init_root(logging.INFO)
53
+ deploy_to_huggingface(
54
+ workdir, repository, revision, n_repeats, pretrained_model,
55
+ clip_skip, width, height, infer_steps,
56
+ )
57
+
58
+
59
+ @cli.command('rehf', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Re-Publish to huggingface')
60
+ @click.option('--repository', '-r', 'repository', type=str, default=None,
61
+ help='Repository to publish to.', show_default=True)
62
+ @click.option('--revision', '-R', 'revision', type=str, default='main',
63
+ help='Revision for pushing the model.', show_default=True)
64
+ @click.option('-n', '--n_repeats', 'n_repeats', type=int, default=3,
65
+ help='N Repeats for text encoder', show_default=True)
66
+ @click.option('-m', '--pretrained_model', 'pretrained_model', type=str, default=_DEFAULT_INFER_MODEL,
67
+ help='Pretrained model for preview drawing.', show_default=True)
68
+ @click.option('--width', 'width', type=int, default=512,
69
+ help='Width of images.', show_default=True)
70
+ @click.option('--height', 'height', type=int, default=768,
71
+ help='Height of images.', show_default=True)
72
+ @click.option('-C', '--clip_skip', 'clip_skip', type=int, default=2,
73
+ help='Clip skip.', show_default=True)
74
+ @click.option('-S', '--infer_steps', 'infer_steps', type=int, default=30,
75
+ help='Steps of inference.', show_default=True)
76
+ def rehf(repository, revision, n_repeats, pretrained_model,
77
+ width, height, clip_skip, infer_steps):
78
+ logging.try_init_root(logging.INFO)
79
+ with TemporaryDirectory() as workdir:
80
+ logging.info(f'Downloading models for {workdir!r} ...')
81
+ hf_fs = get_hf_fs()
82
+ for f in tqdm(hf_fs.glob(f'{repository}/*/raw/*')):
83
+ rel_file = os.path.relpath(f, repository)
84
+ local_file = os.path.join(workdir, 'ckpts', os.path.basename(rel_file))
85
+ if os.path.dirname(local_file):
86
+ os.makedirs(os.path.dirname(local_file), exist_ok=True)
87
+ download_file(
88
+ hf_hub_url(repository, filename=rel_file),
89
+ local_file
90
+ )
91
+
92
+ logging.info(f'Regenerating tags for {workdir!r} ...')
93
+ pt_name, _ = find_steps_in_workdir(workdir)
94
+ game_name = pt_name.split('_')[-1]
95
+ name = '_'.join(pt_name.split('_')[:-1])
96
+
97
+ from gchar.games.dispatch.access import GAME_CHARS
98
+ if game_name in GAME_CHARS:
99
+ ch_cls = GAME_CHARS[game_name]
100
+ ch = ch_cls.get(name)
101
+ else:
102
+ ch = None
103
+
104
+ if ch is None:
105
+ source = repository
106
+ else:
107
+ source = ch
108
+
109
+ logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
110
+ save_recommended_tags(source, name=pt_name, workdir=workdir)
111
+ logging.info('Success!')
112
+
113
+ deploy_to_huggingface(
114
+ workdir, repository, revision, n_repeats, pretrained_model,
115
+ clip_skip, width, height, infer_steps,
116
+ )
117
+
118
+
119
+ @cli.command('civitai', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish to huggingface')
120
+ @click.option('--repository', '-r', 'repository', type=str, required=True,
121
+ help='Repository to publish from.', show_default=True)
122
+ @click.option('--title', '-t', 'title', type=str, default=None,
123
+ help='Title of the civitai model.', show_default=True)
124
+ @click.option('--steps', '-s', 'steps', type=int, default=None,
125
+ help='Steps to deploy.', show_default=True)
126
+ @click.option('--epochs', '-e', 'epochs', type=int, default=None,
127
+ help='Epochs to deploy.', show_default=True)
128
+ @click.option('--draft', '-d', 'draft', is_flag=True, type=bool, default=False,
129
+ help='Only create draft without publishing.', show_default=True)
130
+ @click.option('--time', '-T', 'publish_time', type=str, default=None,
131
+ help='Publish time, publish immediately when not given.', show_default=True)
132
+ @click.option('--allow_nsfw', '-N', 'allow_nsfw', is_flag=True, type=bool, default=False,
133
+ help='Allow uploading nsfw images.', show_default=True)
134
+ @click.option('--version_name', '-v', 'version_name', type=str, default=None,
135
+ help='Name of the version.', show_default=True)
136
+ @click.option('--force_create', '-F', 'force_create', is_flag=True, type=bool, default=False,
137
+ help='Force create new model.', show_default=True)
138
+ @click.option('--no_ccip', 'no_ccip_check', is_flag=True, type=bool, default=False,
139
+ help='No CCIP check.', show_default=True)
140
+ def civitai(repository, title, steps, epochs, draft, publish_time, allow_nsfw,
141
+ version_name, force_create, no_ccip_check):
142
+ logging.try_init_root(logging.INFO)
143
+ model_id = civitai_publish_from_hf(
144
+ repository, title,
145
+ step=steps, epoch=epochs, draft=draft,
146
+ publish_at=publish_time, allow_nsfw_images=allow_nsfw,
147
+ version_name=version_name, force_create_model=force_create,
148
+ no_ccip_check=no_ccip_check,
149
+ )
150
+ url = f'https://civitai.com/models/{model_id}'
151
+ if not draft:
152
+ logging.info(f'Deploy success, model now can be seen at {url} .')
153
+ else:
154
+ logging.info(f'Draft created, it can be seed at {url} .')
155
+
156
+
157
+ if __name__ == '__main__':
158
+ cli()
cyberharem/publish/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (660 Bytes). View file
 
cyberharem/publish/__pycache__/__main__.cpython-310.pyc ADDED
Binary file (5.1 kB). View file
 
cyberharem/publish/__pycache__/civitai.cpython-310.pyc ADDED
Binary file (27.9 kB). View file
 
cyberharem/publish/__pycache__/convert.cpython-310.pyc ADDED
Binary file (978 Bytes). View file
 
cyberharem/publish/__pycache__/export.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
cyberharem/publish/__pycache__/huggingface.cpython-310.pyc ADDED
Binary file (4.12 kB). View file
 
cyberharem/publish/__pycache__/steps.cpython-310.pyc ADDED
Binary file (1.1 kB). View file
 
cyberharem/publish/civitai.py ADDED
@@ -0,0 +1,915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import logging
4
+ import math
5
+ import os.path
6
+ import re
7
+ import textwrap
8
+ import uuid
9
+ from typing import Optional, Tuple, List, Union
10
+
11
+ import blurhash
12
+ import numpy as np
13
+ from PIL import Image
14
+ from gchar.games.base import Character
15
+ from gchar.games.dispatch.access import GAME_CHARS
16
+ from gchar.generic import import_generic
17
+ from hbutils.string import plural_word
18
+ from hbutils.system import TemporaryDirectory
19
+ from huggingface_hub import hf_hub_url
20
+ from imgutils.data import load_image
21
+ from imgutils.detect import detect_faces
22
+ from imgutils.metrics import ccip_extract_feature, ccip_batch_same
23
+ from imgutils.validate import anime_rating_score, nsfw_pred
24
+ from pycivitai import civitai_find_online
25
+ from pycivitai.client import ModelNotFound
26
+ from tqdm.auto import tqdm
27
+ from urlobject import URLObject
28
+ from waifuc.source import LocalSource
29
+
30
+ try:
31
+ from typing import Literal
32
+ except (ModuleNotFoundError, ImportError):
33
+ from typing_extensions import Literal
34
+
35
+ import markdown2
36
+
37
+ from ..dataset import load_dataset_for_character
38
+ from ..utils import get_civitai_session, srequest, get_ch_name, get_hf_fs, download_file, parse_time, \
39
+ load_tags_from_directory, repr_tags
40
+
41
+ import_generic()
42
+
43
+
44
+ def _norm(x, keep_space: bool = True):
45
+ return re.sub(r'[\W_]+', ' ' if keep_space else '', x.lower()).strip()
46
+
47
+
48
+ def _model_tag_same(x, y):
49
+ return _norm(x, keep_space=True) == _norm(y, keep_space=True)
50
+
51
+
52
+ def civitai_query_model_tags(tag: str, session=None) -> Tuple[Optional[int], str]:
53
+ session = session or get_civitai_session()
54
+ logging.info(f'Querying tag {tag!r} from civitai ...')
55
+ resp = srequest(session, 'GET', 'https://civitai.com/api/trpc/tag.getAll', params={
56
+ 'input': json.dumps({
57
+ "json": {
58
+ "limit": 20,
59
+ "entityType": ["Model"],
60
+ "categories": False,
61
+ "query": tag,
62
+ "authed": True,
63
+ }
64
+ })
65
+ }, headers={'Referer': 'https://civitai.com/models/create'})
66
+
67
+ data = resp.json()['result']['data']['json']['items']
68
+ for item in data:
69
+ if _model_tag_same(item['name'], tag):
70
+ logging.info(f'Tag {item["name"]}({item["id"]}) found on civitai.')
71
+ return item['id'], item['name']
72
+ else:
73
+ logging.info(f'Tag not found on civitai, new tag {_norm(tag)!r} will be created.')
74
+ return None, _norm(tag)
75
+
76
+
77
+ CommercialUseTyping = Literal['none', 'image', 'rentCivit', 'rent', 'sell']
78
+
79
+
80
+ def civitai_upsert_model(
81
+ name, description_md: str, tags: List[str],
82
+ commercial_use: CommercialUseTyping = 'rent',
83
+ allow_no_credit: bool = True, allow_derivatives: bool = True, allow_different_licence: bool = True,
84
+ nsfw: bool = False, poi: bool = False, exist_model_id: Optional[int] = None,
85
+ session=None
86
+ ) -> Tuple[int, bool]:
87
+ session = session or get_civitai_session()
88
+ _exist_tags, tag_list, _tag_id = set(), [], 0
89
+ _meta_values = {}
90
+ for tag in tags:
91
+ tag_id, tag_name = civitai_query_model_tags(tag, session)
92
+ if tag_name not in _exist_tags:
93
+ tag_list.append({'id': tag_id, 'name': tag_name})
94
+ _meta_values[f"tagsOnModels.{_tag_id}.id"] = ["undefined"]
95
+ _tag_id += 1
96
+
97
+ post_json = {
98
+ "name": name,
99
+ "description": markdown2.markdown(textwrap.dedent(description_md)),
100
+ "type": "LORA",
101
+
102
+ "allowCommercialUse": commercial_use.lower().capitalize(), # None, Image, Rent, Sell
103
+ "allowNoCredit": allow_no_credit,
104
+ "allowDerivatives": allow_derivatives,
105
+ "allowDifferentLicense": allow_different_licence,
106
+
107
+ "nsfw": nsfw,
108
+ "poi": poi,
109
+ "tagsOnModels": tag_list,
110
+
111
+ "authed": True,
112
+ "status": "Draft",
113
+ "checkpointType": None,
114
+ "uploadType": "Created",
115
+ }
116
+ if exist_model_id:
117
+ post_json['id'] = exist_model_id
118
+ post_json["locked"] = False
119
+ post_json["status"] = "Published"
120
+ logging.info(f'Model {name!r}({exist_model_id}) already exist, updating its new information. '
121
+ f'Tags: {[item["name"] for item in tag_list]!r} ...')
122
+ else:
123
+ logging.info(f'Creating model {name!r}, tags: {[item["name"] for item in tag_list]!r} ...')
124
+
125
+ resp = session.post(
126
+ 'https://civitai.com/api/trpc/model.upsert',
127
+ json={
128
+ "json": post_json,
129
+ "meta": {
130
+ "values": _meta_values,
131
+ }
132
+ },
133
+ headers={'Referer': 'https://civitai.com/models/create'},
134
+ )
135
+
136
+ data = resp.json()['result']['data']['json']
137
+ return data['id'], data['nsfw']
138
+
139
+
140
+ def civitai_query_vae_models(session=None, model_id=None):
141
+ session = session or get_civitai_session()
142
+ logging.info('Querying VAE models ...')
143
+ resp = srequest(
144
+ session, 'GET', ' https://civitai.com/api/trpc/modelVersion.getModelVersionsByModelType',
145
+ params={'input': json.dumps({"json": {"type": "VAE", "authed": True}})},
146
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=2'}
147
+ )
148
+
149
+ data = resp.json()['result']['data']['json']
150
+ logging.info(f'{plural_word(len(data), "VAE model")} found.')
151
+ return data
152
+
153
+
154
+ def _vae_model_same(x, y):
155
+ return _norm(x, keep_space=False) == _norm(y, keep_space=False)
156
+
157
+
158
+ def civitai_create_version(
159
+ model_id: int, version_name: str, description_md: str, trigger_words: List[str],
160
+ base_model: str = 'SD 1.5', steps: Optional[int] = None, epochs: Optional[int] = None,
161
+ clip_skip: Optional[int] = 2, vae_name: Optional[str] = None, early_access_time: int = 0,
162
+ session=None
163
+ ):
164
+ session = session or get_civitai_session()
165
+
166
+ vae_id = None
167
+ if vae_name:
168
+ for vae_item in civitai_query_vae_models(session, model_id):
169
+ if _vae_model_same(vae_item['modelName'], vae_name):
170
+ vae_id = vae_item['id']
171
+
172
+ logging.info(f'Creating version {version_name!r} for model {model_id}, with base model {base_model!r} ...')
173
+ resp = srequest(
174
+ session, 'POST', 'https://civitai.com/api/trpc/modelVersion.upsert',
175
+ json={
176
+ "json": {
177
+ "modelId": model_id,
178
+ "name": version_name,
179
+ "baseModel": base_model,
180
+ "description": markdown2.markdown(textwrap.dedent(description_md)),
181
+ "steps": steps,
182
+ "epochs": epochs,
183
+ "clipSkip": clip_skip,
184
+ "vaeId": vae_id,
185
+ "trainedWords": trigger_words,
186
+ "earlyAccessTimeFrame": early_access_time,
187
+ "skipTrainedWords": bool(not trigger_words),
188
+ "authed": True,
189
+ }
190
+ },
191
+ headers={'Referer': f'https://civitai.com/models/{model_id}/wizard?step=2'}
192
+ )
193
+
194
+ return resp.json()['result']['data']['json']
195
+
196
+
197
+ def civitai_upload_file(local_file: str, type_: str = 'model', filename: str = None,
198
+ model_id: int = None, session=None):
199
+ session = session or get_civitai_session()
200
+ filename = filename or os.path.basename(local_file)
201
+
202
+ logging.info(f'Creating uploading request for {filename!r} ...')
203
+ resp = srequest(
204
+ session, 'POST', 'https://civitai.com/api/upload',
205
+ json={
206
+ "filename": filename,
207
+ "type": type_,
208
+ "size": os.path.getsize(local_file),
209
+ },
210
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'}
211
+ )
212
+ upload_data = resp.json()
213
+
214
+ logging.info(f'Uploading file {local_file!r} as {filename!r} ...')
215
+ with open(local_file, 'rb') as f:
216
+ resp = srequest(
217
+ session, 'PUT', upload_data['urls'][0]['url'], data=f,
218
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'},
219
+ )
220
+ etag = resp.headers['ETag']
221
+
222
+ logging.info(f'Completing uploading for {filename!r} ...')
223
+ resp = srequest(
224
+ session, 'POST', 'https://civitai.com/api/upload/complete',
225
+ json={
226
+ "bucket": upload_data['bucket'],
227
+ "key": upload_data['key'],
228
+ "type": type_,
229
+ "uploadId": upload_data['uploadId'],
230
+ "parts": [{"ETag": etag, "PartNumber": 1}],
231
+ },
232
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'},
233
+ )
234
+ resp.raise_for_status()
235
+
236
+ return {
237
+ "url": str(URLObject(upload_data['urls'][0]['url']).without_query()),
238
+ "bucket": upload_data['bucket'],
239
+ "key": upload_data['key'],
240
+ "name": filename,
241
+ "uuid": str(uuid.uuid4()),
242
+ "sizeKB": os.path.getsize(local_file) / 1024.0,
243
+ }
244
+
245
+
246
+ def civitai_upload_models(model_version_id: int, model_files: List[Union[str, Tuple[str, str]]],
247
+ model_id: int = None, session=None):
248
+ session = session or get_civitai_session()
249
+ file_items = []
250
+ for file_item in model_files:
251
+ if isinstance(file_item, str):
252
+ local_file, filename = file_item, file_item
253
+ elif isinstance(file_item, tuple):
254
+ local_file, filename = file_item
255
+ else:
256
+ raise TypeError(f'Unknown file type - {file_item!r}.')
257
+ file_items.append((local_file, filename))
258
+
259
+ for local_file, filename in file_items:
260
+ upload_data = civitai_upload_file(local_file, 'model', filename, model_id, session)
261
+ logging.info(f'Creating {filename!r} as model file of version {model_version_id} ...')
262
+ resp = srequest(
263
+ session, 'POST', 'https://civitai.com/api/trpc/modelFile.create',
264
+ json={
265
+ 'json': {
266
+ **upload_data,
267
+ "modelVersionId": model_version_id,
268
+ "type": "Model",
269
+ "metadata": {
270
+ "size": None,
271
+ "fp": None
272
+ },
273
+ "authed": True
274
+ },
275
+ },
276
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'},
277
+ )
278
+ resp.raise_for_status()
279
+
280
+
281
+ def civitai_get_model_info(model_id: int, session=None):
282
+ session = session or get_civitai_session()
283
+ resp = srequest(
284
+ session, 'GET', 'https://civitai.com/api/trpc/model.getById',
285
+ params={'input': json.dumps({"json": {"id": model_id, "authed": True}})},
286
+ headers={'Referer': f'https://civitai.com/models/{model_id}/wizard?step=4'},
287
+ )
288
+ return resp.json()['result']['data']['json']
289
+
290
+
291
+ def get_clamped_size(width, height, max_val, _type='all'):
292
+ if _type == 'all':
293
+ if width >= height:
294
+ _type = 'width'
295
+ elif height >= width:
296
+ _type = 'height'
297
+
298
+ if _type == 'width' and width > max_val:
299
+ return max_val, int(round((height / width) * max_val))
300
+
301
+ if _type == 'height' and height > max_val:
302
+ return int(round((width / height) * max_val)), max_val
303
+
304
+ return width, height
305
+
306
+
307
+ def parse_publish_at(publish_at: Optional[str] = None, keep_none: bool = True) -> Optional[str]:
308
+ try:
309
+ from zoneinfo import ZoneInfo
310
+ except (ImportError, ModuleNotFoundError):
311
+ from backports.zoneinfo import ZoneInfo
312
+
313
+ if not keep_none and publish_at is None:
314
+ publish_at = 'now'
315
+ if publish_at is not None:
316
+ local_time = parse_time(publish_at)
317
+ publish_at = local_time.astimezone(ZoneInfo('UTC')).isoformat()
318
+
319
+ return publish_at
320
+
321
+
322
+ def _post_create_func(model_version_id):
323
+ return {
324
+ "json": {
325
+ "modelVersionId": model_version_id,
326
+ "authed": True,
327
+ }
328
+ }
329
+
330
+
331
+ def civitai_upload_images(
332
+ model_version_id: int, image_files: List[Union[str, Tuple[str, str], Tuple[str, str, dict]]],
333
+ tags: List[str], nsfw: bool = False, model_id: int = None, pc_func=_post_create_func, session=None
334
+ ):
335
+ session = session or get_civitai_session()
336
+
337
+ image_items = []
338
+ for image_item in image_files:
339
+ if isinstance(image_item, str):
340
+ local_file, filename, meta = image_item, image_item, {}
341
+ elif isinstance(image_item, tuple):
342
+ if len(image_item) == 2:
343
+ (local_file, filename), meta = image_item, {}
344
+ elif len(image_item) == 3:
345
+ local_file, filename, meta = image_item
346
+ else:
347
+ raise ValueError(f'Invalid image file format - {image_item!r}.')
348
+ else:
349
+ raise TypeError(f'Invalid image file type - {image_item!r}.')
350
+ image_items.append((local_file, filename, meta))
351
+
352
+ logging.info(f'Creating post for model version {model_version_id} ...')
353
+ resp = srequest(
354
+ session, 'POST', 'https://civitai.com/api/trpc/post.create',
355
+ json=pc_func(model_version_id),
356
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'},
357
+ )
358
+ post_id = resp.json()['result']['data']['json']['id']
359
+
360
+ for index, (local_file, filename, meta) in enumerate(image_items):
361
+ logging.info(f'Creating image uploading request for image {filename!r} ...')
362
+ resp = srequest(
363
+ session, 'POST', 'https://civitai.com/api/image-upload',
364
+ json={
365
+ "filename": filename,
366
+ "metadata": {}
367
+ },
368
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'},
369
+ )
370
+ upload_id = resp.json()['id']
371
+ upload_url = resp.json()['uploadURL']
372
+
373
+ logging.info(f'Uploading local image {local_file!r} as image {filename!r} ...')
374
+ with open(local_file, 'rb') as f:
375
+ resp = srequest(session, 'PUT', upload_url, data=f)
376
+ resp.raise_for_status()
377
+
378
+ img = load_image(local_file, force_background='white', mode='RGB')
379
+ new_width, new_height = get_clamped_size(img.width, img.height, 32)
380
+ bhash = blurhash.encode(np.array(img.resize((new_width, new_height))))
381
+ logging.info(f'Completing the uploading of {filename!r} ...')
382
+ resp = srequest(
383
+ session, 'POST', 'https://civitai.com/api/trpc/post.addImage',
384
+ json={
385
+ "json": {
386
+ "type": "image",
387
+ "index": index,
388
+ "uuid": str(uuid.uuid4()),
389
+ "name": filename,
390
+ "meta": meta,
391
+ "url": upload_id,
392
+ "mimeType": "image/png",
393
+ "hash": bhash,
394
+ "width": img.width,
395
+ "height": img.height,
396
+ "status": "uploading",
397
+ "message": None,
398
+ "postId": post_id,
399
+ "modelVersionId": model_version_id,
400
+ "authed": True
401
+ },
402
+ "meta": {
403
+ "values": {
404
+ "message": [
405
+ "undefined"
406
+ ]
407
+ }
408
+ }
409
+ },
410
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'},
411
+ )
412
+ resp.raise_for_status()
413
+
414
+ for tag in tags:
415
+ tag_id, tag_name = civitai_query_model_tags(tag, session)
416
+ if tag_id is not None:
417
+ logging.info(f'Adding tag {tag_name!r}({tag_id}) for post {post_id!r} ...')
418
+ resp = srequest(
419
+ session, 'POST', 'https://civitai.com/api/trpc/post.addTag',
420
+ json={
421
+ "json": {
422
+ "id": post_id,
423
+ "tagId": tag_id,
424
+ "name": tag_name,
425
+ "authed": True,
426
+ }
427
+ },
428
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'},
429
+ )
430
+ else:
431
+ logging.info(f'Creating and adding new tag {tag_name!r} for post {post_id!r} ...')
432
+ resp = srequest(
433
+ session, 'POST', 'https://civitai.com/api/trpc/post.addTag',
434
+ json={
435
+ "json": {
436
+ "id": post_id,
437
+ "tagId": None,
438
+ "name": tag_name,
439
+ "authed": True,
440
+ },
441
+ "meta": {
442
+ "values": {
443
+ "tagId": ["undefined"]
444
+ }
445
+ }
446
+ },
447
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'},
448
+ )
449
+
450
+ resp.raise_for_status()
451
+
452
+ logging.info(f'Marking for nsfw ({nsfw!r}) ...')
453
+ resp = srequest(
454
+ session, 'POST', 'https://civitai.com/api/trpc/post.update',
455
+ json={
456
+ 'json': {
457
+ 'id': post_id,
458
+ 'nsfw': nsfw,
459
+ 'authed': True,
460
+ }
461
+ },
462
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'},
463
+ )
464
+ resp.raise_for_status()
465
+
466
+ return post_id
467
+
468
+
469
+ def civiti_publish(model_id: int, model_version_id: int, publish_at=None, session=None):
470
+ session = session or get_civitai_session()
471
+ publish_at = parse_publish_at(publish_at, keep_none=True)
472
+
473
+ if publish_at:
474
+ logging.info(f'Publishing model {model_id!r}\'s version {model_version_id!r}, at {publish_at!r} ...')
475
+ else:
476
+ logging.info(f'Publishing model {model_id!r}\'s version {model_version_id!r} ...')
477
+ resp = srequest(
478
+ session, 'POST', 'https://civitai.com/api/trpc/model.publish',
479
+ json={
480
+ "json": {
481
+ "id": model_id,
482
+ "versionIds": [
483
+ model_version_id
484
+ ],
485
+ "publishedAt": publish_at,
486
+ "authed": True
487
+ },
488
+ "meta": {
489
+ "values": {
490
+ "publishedAt": [
491
+ "undefined" if publish_at is None else "Date",
492
+ ]
493
+ }
494
+ }
495
+ },
496
+ headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'},
497
+ )
498
+ resp.raise_for_status()
499
+
500
+
501
+ def try_find_title(char_name, game_name):
502
+ try:
503
+ game_cls = GAME_CHARS[game_name.lower()]
504
+ ch = game_cls.get(char_name)
505
+ if ch:
506
+ names = []
507
+ if ch.enname:
508
+ names.append(str(ch.enname))
509
+ if ch.jpname:
510
+ names.append(str(ch.jpname))
511
+ if ch.cnname:
512
+ names.append(str(ch.cnname))
513
+ if hasattr(ch, 'krname') and ch.krname:
514
+ names.append(str(ch.krname))
515
+
516
+ return f"{'/'.join(names)} ({game_cls.__official_name__})"
517
+
518
+ else:
519
+ cname = ' '.join(list(map(str.capitalize, char_name.split(' '))))
520
+ return f'{cname} ({game_cls.__official_name__})'
521
+
522
+ except KeyError:
523
+ return None
524
+
525
+
526
+ def try_get_title_from_repo(repo):
527
+ hf_fs = get_hf_fs()
528
+ print(f'datasets/{repo}/meta.json')
529
+ if hf_fs.exists(f'datasets/{repo}/meta.json'):
530
+ data = json.loads(hf_fs.read_text(f'datasets/{repo}/meta.json'))
531
+ character_name = data['name']
532
+
533
+ source_name = repo.split('_')[-1]
534
+ if hf_fs.exists(f'datasets/BangumiBase/{source_name}/meta.json'):
535
+ base_data = json.loads(hf_fs.read_text(f'datasets/BangumiBase/{source_name}/meta.json'))
536
+ source_full_name = base_data['name']
537
+ return f'{character_name} ({source_full_name})'
538
+ else:
539
+ return character_name
540
+ else:
541
+ return None
542
+
543
+
544
+ def _tag_decode(text):
545
+ return re.sub(r'[\s_]+', ' ', re.sub(r'\\([\\()])', r'\1', text)).strip()
546
+
547
+
548
+ def civitai_publish_from_hf(source, model_name: str = None, model_desc_md: str = None,
549
+ version_name: Optional[str] = None, version_desc_md: str = None,
550
+ step: Optional[int] = None, epoch: Optional[int] = None, upload_min_epoch: int = 6,
551
+ draft: bool = False, publish_at=None, allow_nsfw_images: bool = True,
552
+ force_create_model: bool = False, no_ccip_check: bool = False, session=None):
553
+ if isinstance(source, Character):
554
+ repo = f'AppleHarem/{get_ch_name(source)}'
555
+ elif isinstance(source, str):
556
+ repo = source
557
+ else:
558
+ raise TypeError(f'Unknown source type - {source!r}.')
559
+ hf_fs = get_hf_fs()
560
+ meta_json = json.loads(hf_fs.read_text(f'{repo}/meta.json'))
561
+ game_name = repo.split('_')[-1]
562
+
563
+ dataset_info = meta_json.get('dataset')
564
+ ds_size = (384, 512) if not dataset_info or not dataset_info['type'] else dataset_info['type']
565
+ with load_dataset_for_character(repo, size=ds_size) as (_, d):
566
+ if dataset_info and dataset_info['size']:
567
+ dataset_size = dataset_info['size']
568
+ else:
569
+ dataset_size = len(glob.glob(os.path.join(d, '*.png')))
570
+ core_tags, _ = load_tags_from_directory(d)
571
+ logging.info(f'Size of dataset if {dataset_size!r}.')
572
+
573
+ ccip_feats = []
574
+ for item in tqdm(list(LocalSource(d)[:10]), desc='Extracting features'):
575
+ ccip_feats.append(ccip_extract_feature(item.image))
576
+
577
+ version_name = version_name or meta_json.get('mark') or 'v1.0'
578
+ all_steps = meta_json['steps']
579
+ logging.info(f'Available steps: {all_steps!r}.')
580
+ if step is not None:
581
+ if epoch is not None:
582
+ logging.warning(f'Step {step!r} is set, epoch value ({epoch}) will be ignored.')
583
+ else:
584
+ if epoch is not None:
585
+ step = dataset_size * epoch
586
+ else:
587
+ if 'best_step' in meta_json:
588
+ if upload_min_epoch is not None:
589
+ upload_min_step = upload_min_epoch * dataset_size
590
+ else:
591
+ upload_min_step = -1
592
+ best_step, best_score = None, None
593
+ for score_item in meta_json["scores"]:
594
+ if best_step is None or \
595
+ (score_item['step'] >= upload_min_step and score_item['score'] >= best_score):
596
+ best_step, best_score = score_item['step'], score_item['score']
597
+
598
+ if best_step is not None:
599
+ step = best_step
600
+ else:
601
+ step = meta_json['best_step']
602
+ else:
603
+ step = max(all_steps)
604
+
605
+ logging.info(f'Expected step is {step!r}.')
606
+ _, _actual_step = sorted([(abs(s - step), s) for s in all_steps])[0]
607
+ if _actual_step != step:
608
+ logging.info(f'Actual used step is {_actual_step!r}.')
609
+
610
+ step = _actual_step
611
+ epoch = int(math.ceil(step / dataset_size))
612
+ logging.info(f'Using step {step}, epoch {epoch}.')
613
+
614
+ with TemporaryDirectory() as td:
615
+ models_dir = os.path.join(td, 'models')
616
+ os.makedirs(models_dir, exist_ok=True)
617
+
618
+ lora_file = os.path.basename(hf_fs.glob(f'{repo}/{step}/*.safetensors')[0])
619
+ pt_file = os.path.basename(hf_fs.glob(f'{repo}/{step}/*.pt')[0])
620
+ trigger_word = os.path.splitext(lora_file)[0]
621
+ char_name = ' '.join(trigger_word.split('_')[:-1])
622
+
623
+ models = []
624
+ local_lora_file = os.path.join(models_dir, lora_file)
625
+ download_file(hf_hub_url(repo, filename=f'{step}/{lora_file}'), local_lora_file)
626
+ models.append((local_lora_file, lora_file))
627
+ local_pt_file = os.path.join(models_dir, pt_file)
628
+ download_file(hf_hub_url(repo, filename=f'{step}/{pt_file}'), local_pt_file)
629
+ models.append((local_pt_file, pt_file))
630
+
631
+ images_dir = os.path.join(td, 'images')
632
+ os.makedirs(images_dir, exist_ok=True)
633
+
634
+ images = []
635
+ tags_count = {}
636
+ tags_idx = {}
637
+ for img_file in hf_fs.glob(f'{repo}/{step}/previews/*.png'):
638
+ img_filename = os.path.basename(img_file)
639
+ img_name = os.path.splitext(img_filename)[0]
640
+ img_info_filename = f'{img_name}_info.txt'
641
+
642
+ local_img_file = os.path.join(images_dir, img_filename)
643
+ download_file(hf_hub_url(repo, filename=f'{step}/previews/{img_filename}'), local_img_file)
644
+ local_info_file = os.path.join(images_dir, img_info_filename)
645
+ download_file(hf_hub_url(repo, filename=f'{step}/previews/{img_info_filename}'), local_info_file)
646
+
647
+ info = {}
648
+ with open(local_info_file, 'r', encoding='utf-8') as iif:
649
+ for line in iif:
650
+ line = line.strip()
651
+ if line:
652
+ info_name, info_text = line.split(':', maxsplit=1)
653
+ info[info_name.strip()] = info_text.strip()
654
+
655
+ meta = {
656
+ 'cfgScale': int(round(float(info.get('Guidance Scale')))),
657
+ 'negativePrompt': info.get('Neg Prompt'),
658
+ 'prompt': info.get('Prompt'),
659
+ 'sampler': info.get('Sample Method', "Euler a"),
660
+ 'seed': int(info.get('Seed')),
661
+ 'steps': int(info.get('Infer Steps')),
662
+ 'Size': f"{info['Width']}x{info['Height']}",
663
+ }
664
+ if info.get('Clip Skip'):
665
+ meta['clipSkip'] = int(info['Clip Skip'])
666
+ if info.get('Model'):
667
+ meta['Model'] = info['Model']
668
+ pil_img_file = Image.open(local_img_file)
669
+ if pil_img_file.info.get('parameters'):
670
+ png_info_text = pil_img_file.info['parameters']
671
+ find_hash = re.findall(r'Model hash:\s*([a-zA-Z\d]+)', png_info_text, re.IGNORECASE)
672
+ if find_hash:
673
+ model_hash = find_hash[0].lower()
674
+ meta['hashes'] = {"model": model_hash}
675
+ meta["resources"] = [
676
+ {
677
+ "hash": model_hash,
678
+ "name": info['Model'],
679
+ "type": "model"
680
+ }
681
+ ]
682
+ meta["Model hash"] = model_hash
683
+
684
+ nsfw = (info.get('Safe For Word', info.get('Safe For Work')) or '').lower() != 'yes'
685
+ if not nsfw:
686
+ cls_, score_ = nsfw_pred(local_img_file)
687
+ if cls_ not in {'hentai', 'porn', 'sexy'} and score_ >= 0.65:
688
+ pass
689
+ else:
690
+ nsfw = True
691
+
692
+ if nsfw and not allow_nsfw_images:
693
+ logging.info(f'Image {local_img_file!r} skipped due to its nsfw.')
694
+ continue
695
+
696
+ current_feat = ccip_extract_feature(local_img_file)
697
+ similarity = ccip_batch_same([current_feat, *ccip_feats])[0, 1:].mean()
698
+ logging.info(f'Similarity of character on image {local_img_file!r}: {similarity!r}')
699
+ if similarity < 0.6 and not no_ccip_check:
700
+ logging.info(f'Similarity of {local_img_file!r}({similarity!r}) is too low, skipped.')
701
+ continue
702
+
703
+ if not nsfw or allow_nsfw_images:
704
+ rating_score = anime_rating_score(local_img_file)
705
+ safe_v = int(round(rating_score['safe'] * 10))
706
+ safe_r15 = int(round(rating_score['r15'] * 10))
707
+ safe_r18 = int(round(rating_score['r18'] * 10))
708
+ faces = detect_faces(local_img_file)
709
+ if faces:
710
+ if len(faces) > 1:
711
+ logging.warning('Multiple face detected, skipped!')
712
+ continue
713
+
714
+ (x0, y0, x1, y1), _, _ = faces[0]
715
+ width, height = load_image(local_img_file).size
716
+ face_area = abs((x1 - x0) * (y1 - y0))
717
+ face_ratio = face_area * 1.0 / (width * height)
718
+ face_ratio = int(round(face_ratio * 50))
719
+ else:
720
+ logging.warning('No face detected, skipped!')
721
+ continue
722
+
723
+ images.append((
724
+ (-safe_v, -safe_r15, -safe_r18) if False else 0,
725
+ -face_ratio,
726
+ 1 if nsfw else 0,
727
+ 0 if img_name.startswith('pattern_') else 1,
728
+ img_name,
729
+ (local_img_file, img_filename, meta)
730
+ ))
731
+
732
+ for ptag in info.get('Prompt').split(','):
733
+ ptag = ptag.strip()
734
+ tags_count[ptag] = tags_count.get(ptag, 0) + 1
735
+ if ptag not in tags_idx:
736
+ tags_idx[ptag] = len(tags_idx)
737
+
738
+ images = [item[-1] for item in sorted(images)]
739
+ max_tag_cnt = max(tags_count.values())
740
+ recommended_tags = sorted([ptag for ptag, cnt in tags_count.items() if cnt == max_tag_cnt],
741
+ key=lambda x: tags_idx[x])
742
+
743
+ # publish model
744
+ session = session or get_civitai_session(timeout=30)
745
+
746
+ model_desc_default = f"""
747
+ * Thanks to Civitai's TOS, some images cannot be uploaded. **THE FULL PREVIEW IMAGES CAN BE FOUND ON [HUGGINGFACE](https://huggingface.co/{repo})**.
748
+ * **<span style="color:#fa5252">THIS MODEL HAS TWO FILES. YOU NEED TO USE THEM TOGETHER!!!</span>**
749
+ * **The associated trigger words are only for reference, it may need to be adjusted at some times**.
750
+ * Recommended weight of pt file is 0.5-1.0, weight of LoRA is 0.5-0.85.
751
+ * Images were generated using a few fixed prompts and dataset-based clustered prompts. Random seeds were used, ruling out cherry-picking. **What you see here is what you can get.**
752
+ * No specialized training was done for outfits. You can check our provided preview post to get the prompts corresponding to the outfits.
753
+ * This model is trained with **{plural_word(dataset_size, "image")}**.
754
+
755
+ ## How to Use This Model
756
+
757
+ **<span style="color:#fa5252">THIS MODEL HAS TWO FILES. YOU NEED TO USE THEM TOGETHER!!!</span>**.
758
+ In this case, you need to download both `{pt_file}` and
759
+ `{lora_file}`, then **use `{pt_file}` as texture inversion embedding, and use
760
+ `{lora_file}` as LoRA at the same time**.
761
+
762
+ **<span style="color:#fa5252">このモデルには2つのファイルがあります。一緒に使う必要があります!!!</span>**。
763
+ この場合、`{pt_file}`と`{lora_file}`の両方をダウンロード
764
+ する必要があります。`{pt_file}`をテクスチャ反転埋め込みとして使用し、同時に`{lora_file}`をLoRAとして使用してください。
765
+
766
+ **<span style="color:#fa5252">这个模型有两个文件。你需要同时使用它们!!!</span>**。
767
+ 在这种情况下,您需要下载`{pt_file}`和`{lora_file}`这两个文件,然后将`{pt_file}`用作纹理反转嵌入,
768
+ 同时使用`{lora_file}`作为LoRA。
769
+
770
+ **<span style="color:#fa5252">이 모델은 두 개의 파일이 있습니다. 두 파일을 함께 사용해야 합니다!!!</span>**.
771
+ 이 경우에는 `{pt_file}`와 `{lora_file}` 두 파일을 모두 다운로드하신 다음에 **`{pt_file}`을 텍스처 반전 임베딩으로 사용하고,
772
+ 동시에 `{lora_file}`을 LoRA로 사용하셔야 합니다**.
773
+
774
+ (Translated with ChatGPT)
775
+
776
+ The trigger word is `{trigger_word}`, and the recommended tags are `{', '.join(recommended_tags)}`.
777
+
778
+ ## How This Model Is Trained
779
+
780
+ This model is trained with [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion).
781
+ And the auto-training framework is maintained by [DeepGHS Team](https://huggingface.co/deepghs).
782
+ And the WebUI Panel provid by [LittleAppleWebUI](https://github.com/LittleApple-fp16/LittleAppleWebUI)
783
+
784
+ ## Why Some Preview Images Not Look Like {" ".join(map(str.capitalize, trigger_word.split("_")))}
785
+
786
+ **All the prompt texts** used on the preview images (which can be viewed by clicking on the images)
787
+ **are automatically generated using clustering algorithms** based on feature information extracted from the
788
+ training dataset. The seed used during image generation is also randomly generated, and the images have
789
+ not undergone any selection or modification. As a result, there is a possibility of the mentioned
790
+ issues occurring.
791
+
792
+ In practice, based on our internal testing, most models that experience such issues perform better in
793
+ actual usage than what is seen in the preview images. **The only thing you may need to do is adjusting
794
+ the tags you are using**.
795
+
796
+ ## I Felt This Model May Be Overfitting or Underfitting, What Shall I Do
797
+
798
+ Our model has been published on [huggingface repository - {repo}](https://huggingface.co/{repo}), where
799
+ models of all the steps are saved. Also, we published the training dataset on
800
+ [huggingface dataset - {repo}](https://huggingface.co/datasets/{repo}), which may be helpful to you.
801
+
802
+ ## Why Not Just Using The Better-Selected Images
803
+
804
+ Our model's entire process, from data crawling, training, to generating preview images and publishing,
805
+ is **100% automated without any human intervention**. It's an interesting experiment conducted by our team,
806
+ and for this purpose, we have developed a complete set of software infrastructure, including data filtering,
807
+ automatic training, and automated publishing. Therefore, if possible, we would appreciate more feedback or
808
+ suggestions as they are highly valuable to us.
809
+
810
+ ## Why Can't the Desired Character Outfits Be Accurately Generated
811
+
812
+ Our current training data is sourced from various image websites, and for a fully automated pipeline,
813
+ it's challenging to accurately predict which official images a character possesses.
814
+ Consequently, outfit generation relies on clustering based on labels from the training dataset
815
+ in an attempt to achieve the best possible recreation. We will continue to address this issue and attempt
816
+ optimization, but it remains a challenge that cannot be completely resolved. The accuracy of outfit
817
+ recreation is also unlikely to match the level achieved by manually trained models.
818
+
819
+ In fact, this model's greatest strengths lie in recreating the inherent characteristics of the characters
820
+ themselves and its relatively strong generalization capabilities, owing to its larger dataset.
821
+ As such, **this model is well-suited for tasks such as changing outfits, posing characters, and,
822
+ of course, generating NSFW images of characters!**😉".
823
+
824
+ For the following groups, it is not recommended to use this model and we express regret:
825
+
826
+ 1. Individuals who cannot tolerate any deviations from the original character design, even in the slightest detail.
827
+ 2. Individuals who are facing the application scenarios with high demands for accuracy in recreating character outfits.
828
+ 3. Individuals who cannot accept the potential randomness in AI-generated images based on the Stable Diffusion algorithm.
829
+ 4. Individuals who are not comfortable with the fully automated process of training character models using LoRA, or those who believe that training character models must be done purely through manual operations to avoid disrespecting the characters.
830
+ 5. Individuals who finds the generated image content offensive to their values.
831
+ """
832
+ model_name = model_name or try_find_title(char_name, game_name) or \
833
+ try_get_title_from_repo(repo) or trigger_word.replace('_', ' ')
834
+ if not force_create_model:
835
+ try:
836
+ exist_model = civitai_find_online(model_name, creator='narugo1992')
837
+ except ModelNotFound:
838
+ model_id = None
839
+ else:
840
+ logging.info(f'Existing model {exist_model.model_name}({exist_model.model_id}) found.')
841
+ model_id = exist_model.model_id
842
+ else:
843
+ model_id = None
844
+
845
+ model_id, _ = civitai_upsert_model(
846
+ name=model_name,
847
+ description_md=model_desc_md or model_desc_default,
848
+ tags=[
849
+ game_name, f"{game_name} {char_name}", char_name,
850
+ 'female', 'girl', 'character', 'fully-automated',
851
+ *map(_tag_decode, core_tags.keys()),
852
+ ],
853
+ exist_model_id=model_id,
854
+ session=session,
855
+ )
856
+
857
+ version_data = civitai_create_version(
858
+ model_id=model_id,
859
+ version_name=version_name,
860
+ description_md=version_desc_md or '',
861
+ trigger_words=[
862
+ trigger_word,
863
+ repr_tags([key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])]),
864
+ ],
865
+ session=session,
866
+ steps=step,
867
+ epochs=epoch,
868
+ )
869
+ version_id = version_data['id']
870
+
871
+ civitai_upload_models(
872
+ model_version_id=version_id,
873
+ model_files=models,
874
+ model_id=model_id,
875
+ session=session,
876
+ )
877
+ civitai_upload_images(
878
+ model_version_id=version_id,
879
+ image_files=images,
880
+ tags=[
881
+ game_name, f"{game_name} {char_name}", char_name,
882
+ 'female', 'girl', 'character', 'fully-automated', 'random prompt', 'random seed',
883
+ *map(_tag_decode, core_tags.keys()),
884
+ ],
885
+ model_id=model_id,
886
+ session=session,
887
+ )
888
+
889
+ if draft:
890
+ logging.info(f'Draft of model {model_id!r} created.')
891
+ else:
892
+ civiti_publish(model_id, version_id, publish_at, session)
893
+ return civitai_get_model_info(model_id, session)['id']
894
+
895
+
896
+ def get_draft_models(session=None):
897
+ session = session or get_civitai_session()
898
+ resp = srequest(
899
+ session, 'GET', 'https://civitai.com/api/trpc/model.getMyDraftModels',
900
+ params={
901
+ 'input': json.dumps({"json": {"page": 1, "limit": 200, "authed": True}}),
902
+ },
903
+ headers={'Referer': f'https://civitai.com/user'},
904
+ )
905
+ return resp.json()['result']['data']['json']['items']
906
+
907
+
908
+ def delete_model(model_id: int, session=None):
909
+ session = session or get_civitai_session()
910
+ resp = srequest(
911
+ session, 'POST', 'https://civitai.com/api/trpc/model.delete',
912
+ json={"json": {"id": model_id, "permanently": False, "authed": True}},
913
+ headers={'Referer': f'https://civitai.com/models/{model_id}'},
914
+ )
915
+ resp.raise_for_status()
cyberharem/publish/convert.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from hcpdiff.ckpt_manager import auto_manager
4
+ from hcpdiff.tools.lora_convert import LoraConverter
5
+
6
+
7
+ def convert_to_webui_lora(lora_path, lora_path_TE, dump_path, auto_scale_alpha: bool = True):
8
+ converter = LoraConverter()
9
+
10
+ # load lora model
11
+ logging.info(f'Converting lora model {lora_path!r} and {lora_path_TE!r} to {dump_path!r} ...')
12
+ ckpt_manager = auto_manager(lora_path)()
13
+
14
+ sd_unet = ckpt_manager.load_ckpt(lora_path)
15
+ sd_TE = ckpt_manager.load_ckpt(lora_path_TE)
16
+ state = converter.convert_to_webui(sd_unet['lora'], sd_TE['lora'])
17
+ if auto_scale_alpha:
18
+ state = {k: (v * v.shape[1] if 'lora_up' in k else v) for k, v in state.items()}
19
+ ckpt_manager._save_ckpt(state, save_path=dump_path)
cyberharem/publish/cyberharem_publish_huggingface.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import pathlib
4
+ import pytz
5
+ from typing import Optional
6
+
7
+ from ditk import logging
8
+ from hbutils.system import TemporaryDirectory
9
+ from huggingface_hub import CommitOperationAdd, CommitOperationDelete
10
+ from huggingface_hub.utils import RepositoryNotFoundError
11
+
12
+ from .export import export_workdir, _GITLFS
13
+ from .steps import find_steps_in_workdir
14
+ from ..infer.draw import _DEFAULT_INFER_MODEL
15
+ from ..utils import get_hf_client, get_hf_fs
16
+
17
+
18
+ def deploy_to_huggingface(workdir: str, repository=None, revision: str = 'main', n_repeats: int = 3,
19
+ pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
20
+ image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
21
+ lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
22
+ model_hash: Optional[str] = None, ds_dir: str = None):
23
+ name, _ = find_steps_in_workdir(workdir)
24
+ repository = repository or f'AppleHarem/{name}'
25
+
26
+ logging.info(f'Initializing repository {repository!r} ...')
27
+ hf_client = get_hf_client()
28
+ hf_fs = get_hf_fs()
29
+ if not hf_fs.exists(f'{repository}/.gitattributes'):
30
+ hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)
31
+
32
+ if not hf_fs.exists(f'{repository}/.gitattributes') or \
33
+ '*.png filter=lfs diff=lfs merge=lfs -text' not in hf_fs.read_text(f'{repository}/.gitattributes'):
34
+ logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
35
+ with TemporaryDirectory() as td:
36
+ _git_attr_file = os.path.join(td, '.gitattributes')
37
+ with open(_git_attr_file, 'w', encoding='utf-8') as f:
38
+ print(_GITLFS, file=f)
39
+
40
+ operations = [
41
+ CommitOperationAdd(
42
+ path_in_repo='.gitattributes',
43
+ path_or_fileobj=_git_attr_file,
44
+ )
45
+ ]
46
+ tokyo_tz = pytz.timezone('Asia/Tokyo')
47
+ current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')
48
+ commit_message = f'Update {name}\'s .gitattributes, on {current_time}'
49
+ logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
50
+ hf_client.create_commit(
51
+ repository,
52
+ operations,
53
+ commit_message=commit_message,
54
+ repo_type='model',
55
+ revision=revision,
56
+ )
57
+
58
+ with TemporaryDirectory() as td:
59
+ export_workdir(
60
+ workdir, td, n_repeats, pretrained_model,
61
+ clip_skip, image_width, image_height, infer_steps,
62
+ lora_alpha, sample_method, model_hash, ds_repo=ds_dir, # ds_repo: 本地数据集或远端数据集
63
+ )
64
+
65
+ try:
66
+ hf_client.repo_info(repo_id=repository, repo_type='dataset')
67
+ except RepositoryNotFoundError:
68
+ has_dataset_repo = False
69
+ else:
70
+ has_dataset_repo = True
71
+
72
+ readme_text = pathlib.Path(os.path.join(td, 'README.md')).read_text(encoding='utf-8')
73
+ with open(os.path.join(td, 'README.md'), 'w', encoding='utf-8') as f:
74
+ print('---', file=f)
75
+ print('license: mit', file=f)
76
+ if has_dataset_repo:
77
+ print('datasets:', file=f)
78
+ print(f'- {repository}', file=f)
79
+ print('pipeline_tag: text-to-image', file=f)
80
+ print('tags:', file=f)
81
+ print('- art', file=f)
82
+ print('---', file=f)
83
+ print('', file=f)
84
+ print(readme_text, file=f)
85
+
86
+ _exist_files = [os.path.relpath(file, repository) for file in hf_fs.glob(f'{repository}/**')]
87
+ _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1])
88
+ pre_exist_files = set()
89
+ for i, (file, segments) in enumerate(_exist_ps):
90
+ if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]:
91
+ continue
92
+ if file != '.':
93
+ pre_exist_files.add(file)
94
+
95
+ operations = []
96
+ for directory, _, files in os.walk(td):
97
+ for file in files:
98
+ filename = os.path.abspath(os.path.join(td, directory, file))
99
+ file_in_repo = os.path.relpath(filename, td)
100
+ operations.append(CommitOperationAdd(
101
+ path_in_repo=file_in_repo,
102
+ path_or_fileobj=filename,
103
+ ))
104
+ if file_in_repo in pre_exist_files:
105
+ pre_exist_files.remove(file_in_repo)
106
+ logging.info(f'Useless files: {sorted(pre_exist_files)} ...')
107
+ for file in sorted(pre_exist_files):
108
+ operations.append(CommitOperationDelete(path_in_repo=file))
109
+
110
+ tokyo_tz = pytz.timezone('Asia/Tokyo')
111
+ current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')
112
+ commit_message = f'Publish {name}\'s lora, on {current_time}'
113
+ logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
114
+ hf_client.create_commit(
115
+ repository,
116
+ operations,
117
+ commit_message=commit_message,
118
+ repo_type='model',
119
+ revision=revision,
120
+ )
cyberharem/publish/export.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os.path
4
+ import shutil
5
+ import time
6
+ import zipfile
7
+ from textwrap import dedent
8
+ from typing import Optional
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from imgutils.metrics import ccip_extract_feature, ccip_batch_same
13
+ from tqdm.auto import tqdm
14
+ from waifuc.source import LocalSource
15
+
16
+ try:
17
+ import torch
18
+ except (ImportError, ModuleNotFoundError):
19
+ torch = None
20
+
21
+ from .convert import convert_to_webui_lora
22
+ from .steps import find_steps_in_workdir
23
+ from ..dataset import load_dataset_for_character
24
+ from ..dataset.tags import sort_draw_names
25
+ from ..infer.draw import _DEFAULT_INFER_MODEL
26
+ from ..infer.draw import draw_with_workdir
27
+ from ..utils import repr_tags, load_tags_from_directory
28
+
29
+ KNOWN_MODEL_HASHES = {
30
+ 'AIARTCHAN/anidosmixV2': 'EB49192009',
31
+ 'stablediffusionapi/anything-v5': None,
32
+ 'stablediffusionapi/cetusmix': 'B42B09FF12',
33
+ 'Meina/MeinaMix_V10': 'D967BCAE4A',
34
+ 'Meina/MeinaMix_V11': '54EF3E3610',
35
+ 'Lykon/DreamShaper': 'C33104F6',
36
+ 'digiplay/majicMIX_realistic_v6': 'EBDB94D4',
37
+ 'stablediffusionapi/abyssorangemix2nsfw': 'D6992792',
38
+ 'AIARTCHAN/expmixLine_v2': 'D91B18D1',
39
+ 'Yntec/CuteYuki2': 'FBE372BA',
40
+ 'stablediffusionapi/counterfeit-v30': '12047227',
41
+ 'jzli/XXMix_9realistic-v4': '5D22F204',
42
+ 'stablediffusionapi/flat-2d-animerge': 'F279CF76',
43
+ 'redstonehero/cetusmix_v4': '838408E0',
44
+ 'Meina/Unreal_V4.1': '0503BFAD',
45
+ 'Meina/MeinaHentai_V4': '39C0C3B6',
46
+ 'Meina/MeinaPastel_V6': 'DA1D535E',
47
+ 'KBlueLeaf/kohaku-v4-rev1.2': '87F9E45D',
48
+ 'stablediffusionapi/night-sky-yozora-sty': 'D31F707A',
49
+ }
50
+
51
+ EXPORT_MARK = 'v1.4.1'
52
+
53
+ _GITLFS = dedent("""
54
+ *.7z filter=lfs diff=lfs merge=lfs -text
55
+ *.arrow filter=lfs diff=lfs merge=lfs -text
56
+ *.bin filter=lfs diff=lfs merge=lfs -text
57
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
58
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
59
+ *.ftz filter=lfs diff=lfs merge=lfs -text
60
+ *.gz filter=lfs diff=lfs merge=lfs -text
61
+ *.h5 filter=lfs diff=lfs merge=lfs -text
62
+ *.joblib filter=lfs diff=lfs merge=lfs -text
63
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
64
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
65
+ *.model filter=lfs diff=lfs merge=lfs -text
66
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
67
+ *.npy filter=lfs diff=lfs merge=lfs -text
68
+ *.npz filter=lfs diff=lfs merge=lfs -text
69
+ *.onnx filter=lfs diff=lfs merge=lfs -text
70
+ *.ot filter=lfs diff=lfs merge=lfs -text
71
+ *.parquet filter=lfs diff=lfs merge=lfs -text
72
+ *.pb filter=lfs diff=lfs merge=lfs -text
73
+ *.pickle filter=lfs diff=lfs merge=lfs -text
74
+ *.pkl filter=lfs diff=lfs merge=lfs -text
75
+ *.pt filter=lfs diff=lfs merge=lfs -text
76
+ *.pth filter=lfs diff=lfs merge=lfs -text
77
+ *.rar filter=lfs diff=lfs merge=lfs -text
78
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
79
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
80
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
81
+ *.tar filter=lfs diff=lfs merge=lfs -text
82
+ *.tflite filter=lfs diff=lfs merge=lfs -text
83
+ *.tgz filter=lfs diff=lfs merge=lfs -text
84
+ *.wasm filter=lfs diff=lfs merge=lfs -text
85
+ *.xz filter=lfs diff=lfs merge=lfs -text
86
+ *.zip filter=lfs diff=lfs merge=lfs -text
87
+ *.zst filter=lfs diff=lfs merge=lfs -text
88
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
89
+ *.png filter=lfs diff=lfs merge=lfs -text
90
+ """).strip()
91
+
92
+
93
+ def export_workdir(workdir: str, export_dir: str, n_repeats: int = 2,
94
+ pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
95
+ image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
96
+ lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
97
+ model_hash: Optional[str] = None, ds_repo: Optional[str] = None):
98
+ name, steps = find_steps_in_workdir(workdir)
99
+ logging.info(f'Starting export trained artifacts of {name!r}, with steps: {steps!r}')
100
+ model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model, None)
101
+ if model_hash:
102
+ logging.info(f'Model hash {model_hash!r} detected for model {pretrained_model!r}.')
103
+
104
+ if os.path.exists(os.path.join(workdir, 'meta.json')):
105
+ with open(os.path.join(workdir, 'meta.json'), 'r', encoding='utf-8') as f:
106
+ dataset_info = json.load(f)['dataset']
107
+ else:
108
+ dataset_info = None
109
+
110
+ ds_repo = ds_repo or f'AppleHarem/{name}'
111
+ ds_size = (384, 512) if not dataset_info or not dataset_info['type'] else dataset_info['type']
112
+ logging.info(f'Loading dataset {ds_repo!r}, {ds_size!r} ...')
113
+ with load_dataset_for_character(ds_repo, ds_size) as (ch, ds_dir):
114
+ core_tags, _ = load_tags_from_directory(ds_dir)
115
+ ds_source = LocalSource(ds_dir)
116
+ ds_feats = []
117
+ for item in tqdm(list(ds_source), desc='Extract Dataset Feature'):
118
+ ds_feats.append(ccip_extract_feature(item.image))
119
+
120
+ d_names = set()
121
+ all_drawings = {}
122
+ nsfw_count = {}
123
+ all_scores = {}
124
+ all_scores_lst = []
125
+ for step in steps:
126
+ logging.info(f'Exporting for {name}-{step} ...')
127
+ step_dir = os.path.join(export_dir, f'{step}')
128
+ os.makedirs(step_dir, exist_ok=True)
129
+
130
+ preview_dir = os.path.join(step_dir, 'previews')
131
+ os.makedirs(preview_dir, exist_ok=True)
132
+
133
+ while True:
134
+ try:
135
+ drawings = draw_with_workdir(
136
+ workdir, model_steps=step, n_repeats=n_repeats,
137
+ pretrained_model=pretrained_model,
138
+ width=image_width, height=image_height, infer_steps=infer_steps,
139
+ lora_alpha=lora_alpha, clip_skip=clip_skip, sample_method=sample_method,
140
+ model_hash=model_hash,
141
+ )
142
+ except RuntimeError:
143
+ n_repeats += 1
144
+ else:
145
+ break
146
+
147
+ all_image_files = []
148
+ image_feats = []
149
+ for draw in drawings:
150
+ img_file = os.path.join(preview_dir, f'{draw.name}.png')
151
+ image_feats.append(ccip_extract_feature(draw.image))
152
+ draw.image.save(img_file, pnginfo=draw.pnginfo)
153
+ all_image_files.append(img_file)
154
+
155
+ with open(os.path.join(preview_dir, f'{draw.name}_info.txt'), 'w', encoding='utf-8') as f:
156
+ print(draw.preview_info, file=f)
157
+ d_names.add(draw.name)
158
+ all_drawings[(draw.name, step)] = draw
159
+ if not draw.sfw:
160
+ nsfw_count[draw.name] = nsfw_count.get(draw.name, 0) + 1
161
+
162
+ pt_file = os.path.join(workdir, 'ckpts', f'{name}-{step}.pt')
163
+ unet_file = os.path.join(workdir, 'ckpts', f'unet-{step}.safetensors')
164
+ text_encoder_file = os.path.join(workdir, 'ckpts', f'text_encoder-{step}.safetensors')
165
+ raw_dir = os.path.join(step_dir, 'raw')
166
+ os.makedirs(raw_dir, exist_ok=True)
167
+ shutil.copyfile(pt_file, os.path.join(raw_dir, os.path.basename(pt_file)))
168
+ shutil.copyfile(unet_file, os.path.join(raw_dir, os.path.basename(unet_file)))
169
+ shutil.copyfile(text_encoder_file, os.path.join(raw_dir, os.path.basename(text_encoder_file)))
170
+
171
+ shutil.copyfile(pt_file, os.path.join(step_dir, f'{name}.pt'))
172
+ convert_to_webui_lora(unet_file, text_encoder_file, os.path.join(step_dir, f'{name}.safetensors'))
173
+ with zipfile.ZipFile(os.path.join(step_dir, f'{name}.zip'), 'w') as zf:
174
+ zf.write(os.path.join(step_dir, f'{name}.pt'), f'{name}.pt')
175
+ zf.write(os.path.join(step_dir, f'{name}.safetensors'), f'{name}.safetensors')
176
+ for img_file in all_image_files:
177
+ zf.write(img_file, os.path.basename(img_file))
178
+
179
+ same_matrix = ccip_batch_same([*image_feats, *ds_feats])
180
+ score = same_matrix[:len(image_feats), len(image_feats):].mean()
181
+ all_scores[step] = score
182
+ all_scores_lst.append(score)
183
+ logging.info(f'Score of step {step} is {score}.')
184
+
185
+ lst_scores = np.array(all_scores_lst)
186
+ lst_steps = np.array(steps)
187
+ if dataset_info and 'size' in dataset_info:
188
+ min_best_steps = 6 * dataset_info['size']
189
+ _lst_scores = lst_scores[lst_steps >= min_best_steps]
190
+ _lst_steps = lst_steps[lst_steps >= min_best_steps]
191
+ if _lst_scores.shape[0] > 0:
192
+ lst_steps, lst_scores = _lst_steps, _lst_scores
193
+
194
+ best_idx = np.argmax(lst_scores)
195
+ best_step = lst_steps[best_idx].item()
196
+ nsfw_ratio = {name: count * 1.0 / len(steps) for name, count in nsfw_count.items()}
197
+ with open(os.path.join(export_dir, 'meta.json'), 'w', encoding='utf-8') as f:
198
+ json.dump({
199
+ 'name': name,
200
+ 'steps': steps,
201
+ 'mark': EXPORT_MARK,
202
+ 'time': time.time(),
203
+ 'dataset': dataset_info,
204
+ 'scores': [
205
+ {
206
+ 'step': step,
207
+ 'score': score,
208
+ } for step, score in sorted(all_scores.items())
209
+ ],
210
+ 'best_step': best_step,
211
+ }, f, ensure_ascii=False, indent=4)
212
+ with open(os.path.join(export_dir, '.gitattributes'), 'w', encoding='utf-8') as f:
213
+ print(_GITLFS, file=f)
214
+ with open(os.path.join(export_dir, 'README.md'), 'w', encoding='utf-8') as f:
215
+ print(f'# Lora of {name}', file=f)
216
+ print('', file=f)
217
+
218
+ print('This model is trained with [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion). '
219
+ 'And the auto-training framework is maintained by '
220
+ '[DeepGHS Team](https://huggingface.co/deepghs).'
221
+ 'And the WebUI Panel provid by [LittleAppleWebUI](https://github.com/LittleApple-fp16/LittleAppleWebUI)', file=f)
222
+ print('', file=f)
223
+
224
+ print('The base model used during training is [NAI](https://huggingface.co/deepghs/animefull-latest), '
225
+ f'and the base model used for generating preview images is '
226
+ f'[{pretrained_model}](https://huggingface.co/{pretrained_model}).', file=f)
227
+ print('', file=f)
228
+
229
+ print(f'After downloading the pt and safetensors files for the specified step, '
230
+ f'you need to use them simultaneously. The pt file will be used as an embedding, '
231
+ f'while the safetensors file will be loaded for Lora.', file=f)
232
+ print('', file=f)
233
+ print(f'For example, if you want to use the model from step {best_step}, '
234
+ f'you need to download `{best_step}/{name}.pt` as the embedding and '
235
+ f'`{best_step}/{name}.safetensors` for loading Lora. '
236
+ f'By using both files together, you can generate images for the desired characters.', file=f)
237
+ print('', file=f)
238
+
239
+ print(dedent(f"""
240
+ **The best step we recommend is {best_step}**, with the score of {all_scores[best_step]:.3f}. The trigger words are:
241
+ 1. `{name}`
242
+ 2. `{repr_tags([key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])])}`
243
+ """).strip(), file=f)
244
+ print('', file=f)
245
+
246
+ print(dedent("""
247
+ For the following groups, it is not recommended to use this model and we express regret:
248
+ 1. Individuals who cannot tolerate any deviations from the original character design, even in the slightest detail.
249
+ 2. Individuals who are facing the application scenarios with high demands for accuracy in recreating character outfits.
250
+ 3. Individuals who cannot accept the potential randomness in AI-generated images based on the Stable Diffusion algorithm.
251
+ 4. Individuals who are not comfortable with the fully automated process of training character models using LoRA, or those who believe that training character models must be done purely through manual operations to avoid disrespecting the characters.
252
+ 5. Individuals who finds the generated image content offensive to their values.
253
+ """).strip(), file=f)
254
+ print('', file=f)
255
+
256
+ print(f'These are available steps:', file=f)
257
+ print('', file=f)
258
+
259
+ d_names = sort_draw_names(list(d_names))
260
+ columns = ['Steps', 'Score', 'Download', *d_names]
261
+ t_data = []
262
+
263
+ for step in steps[::-1]:
264
+ d_mds = []
265
+ for dname in d_names:
266
+ file = os.path.join(str(step), 'previews', f'{dname}.png')
267
+ if (dname, step) in all_drawings:
268
+ if nsfw_ratio.get(dname, 0.0) < 0.35:
269
+ d_mds.append(f'![{dname}-{step}]({file})')
270
+ else:
271
+ d_mds.append(f'[<NSFW, click to see>]({file})')
272
+ else:
273
+ d_mds.append('')
274
+
275
+ t_data.append((
276
+ str(step) if step != best_step else f'**{step}**',
277
+ f'{all_scores[step]:.3f}' if step != best_step else f'**{all_scores[step]:.3f}**',
278
+ f'[Download]({step}/{name}.zip)' if step != best_step else f'[**Download**]({step}/{name}.zip)',
279
+ *d_mds,
280
+ ))
281
+
282
+ df = pd.DataFrame(columns=columns, data=t_data)
283
+ print(df.to_markdown(index=False), file=f)
284
+ print('', file=f)
cyberharem/publish/huggingface.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import pathlib
4
+ import pytz
5
+ from typing import Optional
6
+
7
+ from ditk import logging
8
+ from hbutils.system import TemporaryDirectory
9
+ from huggingface_hub import CommitOperationAdd, CommitOperationDelete
10
+ from huggingface_hub.utils import RepositoryNotFoundError
11
+
12
+ from .export import export_workdir, _GITLFS
13
+ from .steps import find_steps_in_workdir
14
+ from ..infer.draw import _DEFAULT_INFER_MODEL
15
+ from ..utils import get_hf_client, get_hf_fs
16
+
17
+
18
+ def deploy_to_huggingface(workdir: str, repository=None, revision: str = 'main', n_repeats: int = 3,
19
+ pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
20
+ image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
21
+ lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
22
+ model_hash: Optional[str] = None, ds_dir: str = None):
23
+ name, _ = find_steps_in_workdir(workdir)
24
+ repository = repository or f'AppleHarem/{name}'
25
+
26
+ logging.info(f'Initializing repository {repository!r} ...')
27
+ hf_client = get_hf_client()
28
+ hf_fs = get_hf_fs()
29
+ if not hf_fs.exists(f'{repository}/.gitattributes'):
30
+ hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)
31
+
32
+ if not hf_fs.exists(f'{repository}/.gitattributes') or \
33
+ '*.png filter=lfs diff=lfs merge=lfs -text' not in hf_fs.read_text(f'{repository}/.gitattributes'):
34
+ logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
35
+ with TemporaryDirectory() as td:
36
+ _git_attr_file = os.path.join(td, '.gitattributes')
37
+ with open(_git_attr_file, 'w', encoding='utf-8') as f:
38
+ print(_GITLFS, file=f)
39
+
40
+ operations = [
41
+ CommitOperationAdd(
42
+ path_in_repo='.gitattributes',
43
+ path_or_fileobj=_git_attr_file,
44
+ )
45
+ ]
46
+ tokyo_tz = pytz.timezone('Asia/Tokyo')
47
+ current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')
48
+ commit_message = f'Update {name}\'s .gitattributes, on {current_time}'
49
+ logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
50
+ hf_client.create_commit(
51
+ repository,
52
+ operations,
53
+ commit_message=commit_message,
54
+ repo_type='model',
55
+ revision=revision,
56
+ )
57
+
58
+ with TemporaryDirectory() as td:
59
+ export_workdir(
60
+ workdir, td, n_repeats, pretrained_model,
61
+ clip_skip, image_width, image_height, infer_steps,
62
+ lora_alpha, sample_method, model_hash, ds_repo=ds_dir, # ds_repo: 本地数据集或远端数据集
63
+ )
64
+
65
+ try:
66
+ hf_client.repo_info(repo_id=repository, repo_type='dataset')
67
+ except RepositoryNotFoundError:
68
+ has_dataset_repo = False
69
+ else:
70
+ has_dataset_repo = True
71
+
72
+ readme_text = pathlib.Path(os.path.join(td, 'README.md')).read_text(encoding='utf-8')
73
+ with open(os.path.join(td, 'README.md'), 'w', encoding='utf-8') as f:
74
+ print('---', file=f)
75
+ print('license: mit', file=f)
76
+ if has_dataset_repo:
77
+ print('datasets:', file=f)
78
+ print(f'- {repository}', file=f)
79
+ print('pipeline_tag: text-to-image', file=f)
80
+ print('tags:', file=f)
81
+ print('- art', file=f)
82
+ print('---', file=f)
83
+ print('', file=f)
84
+ print(readme_text, file=f)
85
+
86
+ _exist_files = [os.path.relpath(file, repository) for file in hf_fs.glob(f'{repository}/**')]
87
+ _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1])
88
+ pre_exist_files = set()
89
+ for i, (file, segments) in enumerate(_exist_ps):
90
+ if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]:
91
+ continue
92
+ if file != '.':
93
+ pre_exist_files.add(file)
94
+
95
+ operations = []
96
+ for directory, _, files in os.walk(td):
97
+ for file in files:
98
+ filename = os.path.abspath(os.path.join(td, directory, file))
99
+ file_in_repo = os.path.relpath(filename, td)
100
+ operations.append(CommitOperationAdd(
101
+ path_in_repo=file_in_repo,
102
+ path_or_fileobj=filename,
103
+ ))
104
+ if file_in_repo in pre_exist_files:
105
+ pre_exist_files.remove(file_in_repo)
106
+ logging.info(f'Useless files: {sorted(pre_exist_files)} ...')
107
+ for file in sorted(pre_exist_files):
108
+ operations.append(CommitOperationDelete(path_in_repo=file))
109
+
110
+ tokyo_tz = pytz.timezone('Asia/Tokyo')
111
+ current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')
112
+ commit_message = f'Publish {name}\'s lora, on {current_time}'
113
+ logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
114
+ hf_client.create_commit(
115
+ repository,
116
+ operations,
117
+ commit_message=commit_message,
118
+ repo_type='model',
119
+ revision=revision,
120
+ )
cyberharem/publish/steps.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os.path
3
+ from typing import List, Tuple
4
+
5
+
6
+ def find_steps_in_workdir(workdir: str) -> Tuple[str, List[int]]:
7
+ ckpts_dir = os.path.join(workdir, 'ckpts')
8
+ pt_steps = []
9
+ pt_name = None
10
+ for pt in glob.glob(os.path.join(ckpts_dir, '*-*.pt')):
11
+ name = os.path.basename(pt)
12
+ segs = os.path.splitext(name)[0].split('-')
13
+ if pt_name is None:
14
+ pt_name = '-'.join(segs[:-1])
15
+ else:
16
+ if pt_name != '-'.join(segs[:-1]):
17
+ raise NameError(f'Name not match, {pt_name!r} vs {"-".join(segs[:-1])!r}.')
18
+ pt_steps.append(int(segs[-1]))
19
+
20
+ unet_steps = []
21
+ for unet in glob.glob(os.path.join(ckpts_dir, 'unet-*.safetensors')):
22
+ name = os.path.basename(unet)
23
+ segs = os.path.splitext(name)[0].split('-')
24
+ unet_steps.append(int(segs[-1]))
25
+
26
+ text_encoder_steps = []
27
+ for text_encoder in glob.glob(os.path.join(ckpts_dir, 'text_encoder-*.safetensors')):
28
+ name = os.path.basename(text_encoder)
29
+ segs = os.path.splitext(name)[0].split('-')
30
+ text_encoder_steps.append(int(segs[-1]))
31
+
32
+ return pt_name, sorted(set(pt_steps) & set(unet_steps) & set(text_encoder_steps))