Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .gitignore +68 -0
- .gitmodules +32 -0
- .pylintrc +237 -0
- CHANGELOG.md +0 -0
- CITATION.cff +28 -0
- README.md +277 -8
- SECURITY.md +36 -0
- TODO.md +20 -0
- cli/README.md +108 -0
- cli/clone.py +78 -0
- cli/create-previews.py +346 -0
- cli/download.py +126 -0
- cli/gen-styles.py +79 -0
- cli/generate.json +38 -0
- cli/generate.py +373 -0
- cli/hf-convert.py +35 -0
- cli/hf-search.py +18 -0
- cli/idle.py +60 -0
- cli/image-exif.py +163 -0
- cli/image-grid.py +128 -0
- cli/image-interrogate.py +109 -0
- cli/image-palette.py +129 -0
- cli/image-watermark.py +129 -0
- cli/install-sf.py +87 -0
- cli/latents.py +170 -0
- cli/lcm-convert.py +55 -0
- cli/model-jit.py +177 -0
- cli/model-metadata.py +41 -0
- cli/nvidia-smi.py +35 -0
- cli/options.py +141 -0
- cli/process.py +327 -0
- cli/random.json +31 -0
- cli/requirements.txt +7 -0
- cli/run-benchmark.py +149 -0
- cli/sdapi.py +262 -0
- cli/simple-img2img.py +98 -0
- cli/simple-info.py +57 -0
- cli/simple-mask.py +83 -0
- cli/simple-preprocess.py +76 -0
- cli/simple-txt2img.js +63 -0
- cli/simple-txt2img.py +80 -0
- cli/simple-upscale.py +90 -0
- cli/torch-compile.py +99 -0
- cli/train.py +443 -0
- cli/util.py +113 -0
- cli/validate-locale.py +40 -0
- cli/video-extract.py +71 -0
- configs/alt-diffusion-inference.yaml +72 -0
- configs/instruct-pix2pix.yaml +98 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
extensions-builtin/sd-webui-agent-scheduler/docs/images/walkthrough.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
extensions-builtin/stable-diffusion-webui-rembg/preview.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
javascript/notosans-nerdfont-regular.ttf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# defaults
|
2 |
+
__pycache__
|
3 |
+
.ruff_cache
|
4 |
+
/cache.json
|
5 |
+
/*.json
|
6 |
+
/*.yaml
|
7 |
+
/params.txt
|
8 |
+
/styles.csv
|
9 |
+
/user.css
|
10 |
+
/webui-user.bat
|
11 |
+
/webui-user.sh
|
12 |
+
/html/extensions.json
|
13 |
+
/html/themes.json
|
14 |
+
node_modules
|
15 |
+
pnpm-lock.yaml
|
16 |
+
package-lock.json
|
17 |
+
venv
|
18 |
+
.history
|
19 |
+
cache
|
20 |
+
**/.DS_Store
|
21 |
+
|
22 |
+
# all models and temp files
|
23 |
+
*.log
|
24 |
+
*.log.*
|
25 |
+
*.bak
|
26 |
+
*.ckpt
|
27 |
+
*.safetensors
|
28 |
+
*.pth
|
29 |
+
*.pt
|
30 |
+
*.bin
|
31 |
+
*.optim
|
32 |
+
*.lock
|
33 |
+
*.zip
|
34 |
+
*.rar
|
35 |
+
*.7z
|
36 |
+
*.pyc
|
37 |
+
/*.bat
|
38 |
+
/*.sh
|
39 |
+
/*.txt
|
40 |
+
/*.mp3
|
41 |
+
/*.lnk
|
42 |
+
!webui.bat
|
43 |
+
!webui.sh
|
44 |
+
!package.json
|
45 |
+
|
46 |
+
# all dynamic stuff
|
47 |
+
/extensions/**/*
|
48 |
+
/outputs/**/*
|
49 |
+
/embeddings/**/*
|
50 |
+
/models/**/*
|
51 |
+
/interrogate/**/*
|
52 |
+
/train/log/**/*
|
53 |
+
/textual_inversion/**/*
|
54 |
+
/detected_maps/**/*
|
55 |
+
/tmp
|
56 |
+
/log
|
57 |
+
/cert
|
58 |
+
.vscode/
|
59 |
+
.idea/
|
60 |
+
/localizations
|
61 |
+
|
62 |
+
.*/
|
63 |
+
|
64 |
+
# force included
|
65 |
+
!/models/VAE-approx
|
66 |
+
!/models/VAE-approx/model.pt
|
67 |
+
!/models/Reference
|
68 |
+
!/models/Reference/**/*
|
.gitmodules
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "wiki"]
|
2 |
+
path = wiki
|
3 |
+
url = https://github.com/vladmandic/automatic.wiki
|
4 |
+
ignore = dirty
|
5 |
+
[submodule "modules/k-diffusion"]
|
6 |
+
path = modules/k-diffusion
|
7 |
+
url = https://github.com/crowsonkb/k-diffusion
|
8 |
+
ignore = dirty
|
9 |
+
[submodule "extensions-builtin/sd-extension-system-info"]
|
10 |
+
path = extensions-builtin/sd-extension-system-info
|
11 |
+
url = https://github.com/vladmandic/sd-extension-system-info
|
12 |
+
ignore = dirty
|
13 |
+
[submodule "extensions-builtin/sd-extension-chainner"]
|
14 |
+
path = extensions-builtin/sd-extension-chainner
|
15 |
+
url = https://github.com/vladmandic/sd-extension-chainner
|
16 |
+
ignore = dirty
|
17 |
+
[submodule "extensions-builtin/stable-diffusion-webui-rembg"]
|
18 |
+
path = extensions-builtin/stable-diffusion-webui-rembg
|
19 |
+
url = https://github.com/vladmandic/sd-extension-rembg
|
20 |
+
ignore = dirty
|
21 |
+
[submodule "extensions-builtin/stable-diffusion-webui-images-browser"]
|
22 |
+
path = extensions-builtin/stable-diffusion-webui-images-browser
|
23 |
+
url = https://github.com/AlUlkesh/stable-diffusion-webui-images-browser
|
24 |
+
ignore = dirty
|
25 |
+
[submodule "extensions-builtin/sd-webui-controlnet"]
|
26 |
+
path = extensions-builtin/sd-webui-controlnet
|
27 |
+
url = https://github.com/Mikubill/sd-webui-controlnet
|
28 |
+
ignore = dirty
|
29 |
+
[submodule "extensions-builtin/sd-webui-agent-scheduler"]
|
30 |
+
path = extensions-builtin/sd-webui-agent-scheduler
|
31 |
+
url = https://github.com/ArtVentureX/sd-webui-agent-scheduler
|
32 |
+
ignore = dirty
|
.pylintrc
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[MAIN]
|
2 |
+
analyse-fallback-blocks=no
|
3 |
+
clear-cache-post-run=no
|
4 |
+
#enable-all-extensions=
|
5 |
+
#errors-only=
|
6 |
+
#exit-zero=
|
7 |
+
extension-pkg-allow-list=
|
8 |
+
extension-pkg-whitelist=
|
9 |
+
fail-on=
|
10 |
+
fail-under=10
|
11 |
+
ignore=CVS
|
12 |
+
ignore-paths=/usr/lib/.*$,
|
13 |
+
^repositories/.*$,
|
14 |
+
^extensions/.*$,
|
15 |
+
^extensions-builtin/.*$,
|
16 |
+
^modules/dml/.*$,
|
17 |
+
^modules/tcd/.*$,
|
18 |
+
^modules/xadapters/.*$,
|
19 |
+
ignore-patterns=
|
20 |
+
ignored-modules=
|
21 |
+
jobs=0
|
22 |
+
limit-inference-results=100
|
23 |
+
load-plugins=
|
24 |
+
persistent=yes
|
25 |
+
py-version=3.10
|
26 |
+
recursive=no
|
27 |
+
source-roots=
|
28 |
+
suggestion-mode=yes
|
29 |
+
unsafe-load-any-extension=no
|
30 |
+
#verbose=
|
31 |
+
|
32 |
+
[BASIC]
|
33 |
+
argument-naming-style=snake_case
|
34 |
+
#argument-rgx=
|
35 |
+
attr-naming-style=snake_case
|
36 |
+
#attr-rgx=
|
37 |
+
bad-names=foo, bar, baz, toto, tutu, tata
|
38 |
+
bad-names-rgxs=
|
39 |
+
class-attribute-naming-style=any
|
40 |
+
class-const-naming-style=UPPER_CASE
|
41 |
+
#class-const-rgx=
|
42 |
+
class-naming-style=PascalCase
|
43 |
+
#class-rgx=
|
44 |
+
const-naming-style=snake_case
|
45 |
+
#const-rgx=
|
46 |
+
docstring-min-length=-1
|
47 |
+
function-naming-style=snake_case
|
48 |
+
#function-rgx=
|
49 |
+
# Good variable names which should always be accepted, separated by a comma.
|
50 |
+
good-names=i,j,k,e,ex,ok,p
|
51 |
+
good-names-rgxs=
|
52 |
+
include-naming-hint=no
|
53 |
+
inlinevar-naming-style=any
|
54 |
+
#inlinevar-rgx=
|
55 |
+
method-naming-style=snake_case
|
56 |
+
#method-rgx=
|
57 |
+
module-naming-style=snake_case
|
58 |
+
#module-rgx=
|
59 |
+
name-group=
|
60 |
+
no-docstring-rgx=^_
|
61 |
+
property-classes=abc.abstractproperty
|
62 |
+
#typealias-rgx=
|
63 |
+
#typevar-rgx=
|
64 |
+
variable-naming-style=snake_case
|
65 |
+
#variable-rgx=
|
66 |
+
|
67 |
+
[CLASSES]
|
68 |
+
check-protected-access-in-special-methods=no
|
69 |
+
defining-attr-methods=__init__,
|
70 |
+
__new__,
|
71 |
+
setUp,
|
72 |
+
asyncSetUp,
|
73 |
+
__post_init__
|
74 |
+
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
|
75 |
+
valid-classmethod-first-arg=cls
|
76 |
+
valid-metaclass-classmethod-first-arg=mcs
|
77 |
+
|
78 |
+
[DESIGN]
|
79 |
+
exclude-too-few-public-methods=
|
80 |
+
ignored-parents=
|
81 |
+
max-args=99
|
82 |
+
max-attributes=99
|
83 |
+
max-bool-expr=99
|
84 |
+
max-branches=99
|
85 |
+
max-locals=99
|
86 |
+
max-parents=99
|
87 |
+
max-public-methods=99
|
88 |
+
max-returns=99
|
89 |
+
max-statements=199
|
90 |
+
min-public-methods=1
|
91 |
+
|
92 |
+
[EXCEPTIONS]
|
93 |
+
overgeneral-exceptions=builtins.BaseException,builtins.Exception
|
94 |
+
|
95 |
+
[FORMAT]
|
96 |
+
expected-line-ending-format=
|
97 |
+
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
98 |
+
indent-after-paren=4
|
99 |
+
indent-string=' '
|
100 |
+
max-line-length=200
|
101 |
+
max-module-lines=9999
|
102 |
+
single-line-class-stmt=no
|
103 |
+
single-line-if-stmt=no
|
104 |
+
|
105 |
+
[IMPORTS]
|
106 |
+
allow-any-import-level=
|
107 |
+
allow-reexport-from-package=no
|
108 |
+
allow-wildcard-with-all=no
|
109 |
+
deprecated-modules=
|
110 |
+
ext-import-graph=
|
111 |
+
import-graph=
|
112 |
+
int-import-graph=
|
113 |
+
known-standard-library=
|
114 |
+
known-third-party=enchant
|
115 |
+
preferred-modules=
|
116 |
+
|
117 |
+
[LOGGING]
|
118 |
+
logging-format-style=new
|
119 |
+
logging-modules=logging
|
120 |
+
|
121 |
+
[MESSAGES CONTROL]
|
122 |
+
confidence=HIGH,
|
123 |
+
CONTROL_FLOW,
|
124 |
+
INFERENCE,
|
125 |
+
INFERENCE_FAILURE,
|
126 |
+
UNDEFINED
|
127 |
+
# disable=C,R,W
|
128 |
+
disable=bad-inline-option,
|
129 |
+
bare-except,
|
130 |
+
broad-exception-caught,
|
131 |
+
chained-comparison,
|
132 |
+
consider-iterating-dictionary,
|
133 |
+
consider-using-dict-items,
|
134 |
+
consider-using-generator,
|
135 |
+
consider-using-enumerate,
|
136 |
+
consider-using-sys-exit,
|
137 |
+
consider-using-from-import,
|
138 |
+
consider-using-get,
|
139 |
+
consider-using-in,
|
140 |
+
consider-using-min-builtin,
|
141 |
+
dangerous-default-value,
|
142 |
+
deprecated-pragma,
|
143 |
+
duplicate-code,
|
144 |
+
file-ignored,
|
145 |
+
import-error,
|
146 |
+
import-outside-toplevel,
|
147 |
+
invalid-name,
|
148 |
+
line-too-long,
|
149 |
+
locally-disabled,
|
150 |
+
logging-fstring-interpolation,
|
151 |
+
missing-class-docstring,
|
152 |
+
missing-function-docstring,
|
153 |
+
missing-module-docstring,
|
154 |
+
no-else-return,
|
155 |
+
not-callable,
|
156 |
+
pointless-string-statement,
|
157 |
+
raw-checker-failed,
|
158 |
+
simplifiable-if-expression,
|
159 |
+
suppressed-message,
|
160 |
+
too-many-nested-blocks,
|
161 |
+
too-few-public-methods,
|
162 |
+
too-many-statements,
|
163 |
+
too-many-locals,
|
164 |
+
too-many-instance-attributes,
|
165 |
+
unnecessary-dunder-call,
|
166 |
+
unnecessary-lambda,
|
167 |
+
use-dict-literal,
|
168 |
+
use-symbolic-message-instead,
|
169 |
+
useless-suppression,
|
170 |
+
unidiomatic-typecheck,
|
171 |
+
wrong-import-position
|
172 |
+
enable=c-extension-no-member
|
173 |
+
|
174 |
+
[METHOD_ARGS]
|
175 |
+
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
|
176 |
+
|
177 |
+
[MISCELLANEOUS]
|
178 |
+
notes=FIXME,
|
179 |
+
XXX,
|
180 |
+
TODO
|
181 |
+
notes-rgx=
|
182 |
+
|
183 |
+
[REFACTORING]
|
184 |
+
max-nested-blocks=5
|
185 |
+
never-returning-functions=sys.exit,argparse.parse_error
|
186 |
+
|
187 |
+
[REPORTS]
|
188 |
+
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
|
189 |
+
msg-template=
|
190 |
+
#output-format=
|
191 |
+
reports=no
|
192 |
+
score=no
|
193 |
+
|
194 |
+
[SIMILARITIES]
|
195 |
+
ignore-comments=yes
|
196 |
+
ignore-docstrings=yes
|
197 |
+
ignore-imports=yes
|
198 |
+
ignore-signatures=yes
|
199 |
+
min-similarity-lines=4
|
200 |
+
|
201 |
+
[SPELLING]
|
202 |
+
max-spelling-suggestions=4
|
203 |
+
spelling-dict=
|
204 |
+
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
|
205 |
+
spelling-ignore-words=
|
206 |
+
spelling-private-dict-file=
|
207 |
+
spelling-store-unknown-words=no
|
208 |
+
|
209 |
+
[STRING]
|
210 |
+
check-quote-consistency=no
|
211 |
+
check-str-concat-over-line-jumps=no
|
212 |
+
|
213 |
+
[TYPECHECK]
|
214 |
+
contextmanager-decorators=contextlib.contextmanager
|
215 |
+
generated-members=numpy.*,logging.*,torch.*,cv2.*
|
216 |
+
ignore-none=yes
|
217 |
+
ignore-on-opaque-inference=yes
|
218 |
+
ignored-checks-for-mixins=no-member,
|
219 |
+
not-async-context-manager,
|
220 |
+
not-context-manager,
|
221 |
+
attribute-defined-outside-init
|
222 |
+
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
|
223 |
+
missing-member-hint=yes
|
224 |
+
missing-member-hint-distance=1
|
225 |
+
missing-member-max-choices=1
|
226 |
+
mixin-class-rgx=.*[Mm]ixin
|
227 |
+
signature-mutators=
|
228 |
+
|
229 |
+
[VARIABLES]
|
230 |
+
additional-builtins=
|
231 |
+
allow-global-unused-variables=yes
|
232 |
+
allowed-redefined-builtins=
|
233 |
+
callbacks=cb_,
|
234 |
+
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
235 |
+
ignored-argument-names=_.*|^ignored_|^unused_
|
236 |
+
init-import=no
|
237 |
+
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
|
CHANGELOG.md
ADDED
The diff for this file is too large to render.
See raw diff
|
|
CITATION.cff
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cff-version: 1.2.0
|
2 |
+
title: SD.Next
|
3 |
+
url: 'https://github.com/vladmandic/automatic'
|
4 |
+
message: >-
|
5 |
+
If you use this software, please cite it using the
|
6 |
+
metadata from this file
|
7 |
+
type: software
|
8 |
+
authors:
|
9 |
+
- given-names: Vladimir
|
10 |
+
name-particle: Vlado
|
11 |
+
family-names: Mandic
|
12 |
+
orcid: 'https://orcid.org/0009-0003-4592-5074'
|
13 |
+
identifiers:
|
14 |
+
- type: url
|
15 |
+
value: 'https://github.com/vladmandic'
|
16 |
+
description: GitHub
|
17 |
+
- type: url
|
18 |
+
value: 'https://www.linkedin.com/in/cyan051/'
|
19 |
+
description: LinkedIn
|
20 |
+
repository-code: 'https://github.com/vladmandic/automatic'
|
21 |
+
abstract: >-
|
22 |
+
SD.Next: Advanced Implementation of Stable Diffusion and
|
23 |
+
other diffusion models for text, image and video
|
24 |
+
generation
|
25 |
+
keywords:
|
26 |
+
- stablediffusion diffusers sdnext
|
27 |
+
license: AGPL-3.0
|
28 |
+
date-released: 2022-12-24
|
README.md
CHANGED
@@ -1,12 +1,281 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: green
|
5 |
-
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: test
|
3 |
+
app_file: webui.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 4.29.0
|
|
|
|
|
6 |
---
|
7 |
+
<div align="center">
|
8 |
|
9 |
+
# SD.Next
|
10 |
+
|
11 |
+
**Stable Diffusion implementation with advanced features**
|
12 |
+
|
13 |
+
[![Sponsors](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/vladmandic)
|
14 |
+
![Last Commit](https://img.shields.io/github/last-commit/vladmandic/automatic?svg=true)
|
15 |
+
![License](https://img.shields.io/github/license/vladmandic/automatic?svg=true)
|
16 |
+
[![Discord](https://img.shields.io/discord/1101998836328697867?logo=Discord&svg=true)](https://discord.gg/VjvR2tabEX)
|
17 |
+
|
18 |
+
[Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md)
|
19 |
+
|
20 |
+
</div>
|
21 |
+
</br>
|
22 |
+
|
23 |
+
## Notable features
|
24 |
+
|
25 |
+
All individual features are not listed here, instead check [ChangeLog](CHANGELOG.md) for full list of changes
|
26 |
+
- Multiple backends!
|
27 |
+
ℹ **Diffusers | Original**
|
28 |
+
- Multiple diffusion models!
|
29 |
+
ℹ **Stable Diffusion 1.5/2.1 | SD-XL | LCM | Segmind | Kandinsky | Pixart-α | Stable Cascade | Würstchen | aMUSEd | DeepFloyd IF | UniDiffusion | SD-Distilled | BLiP Diffusion | KOALA | etc.**
|
30 |
+
- Built-in Control for Text, Image, Batch and video processing!
|
31 |
+
ℹ **ControlNet | ControlNet XS | Control LLLite | T2I Adapters | IP Adapters**
|
32 |
+
- Multiplatform!
|
33 |
+
ℹ **Windows | Linux | MacOS with CPU | nVidia | AMD | IntelArc | DirectML | OpenVINO | ONNX+Olive | ZLUDA**
|
34 |
+
- Platform specific autodetection and tuning performed on install
|
35 |
+
- Optimized processing with latest `torch` developments with built-in support for `torch.compile`
|
36 |
+
and multiple compile backends: *Triton, ZLUDA, StableFast, DeepCache, OpenVINO, NNCF, IPEX*
|
37 |
+
- Improved prompt parser
|
38 |
+
- Enhanced *Lora*/*LoCon*/*Lyco* code supporting latest trends in training
|
39 |
+
- Built-in queue management
|
40 |
+
- Enterprise level logging and hardened API
|
41 |
+
- Built in installer with automatic updates and dependency management
|
42 |
+
- Modernized UI with theme support and number of built-in themes *(dark and light)*
|
43 |
+
|
44 |
+
<br>
|
45 |
+
|
46 |
+
*Main text2image interface*:
|
47 |
+
![Screenshot-Dark](html/screenshot-text2image.jpg)
|
48 |
+
|
49 |
+
For screenshots and informations on other available themes, see [Themes Wiki](https://github.com/vladmandic/automatic/wiki/Themes)
|
50 |
+
|
51 |
+
<br>
|
52 |
+
|
53 |
+
## Backend support
|
54 |
+
|
55 |
+
**SD.Next** supports two main backends: *Diffusers* and *Original*:
|
56 |
+
|
57 |
+
- **Diffusers**: Based on new [Huggingface Diffusers](https://huggingface.co/docs/diffusers/index) implementation
|
58 |
+
Supports *all* models listed below
|
59 |
+
This backend is set as default for new installations
|
60 |
+
See [wiki article](https://github.com/vladmandic/automatic/wiki/Diffusers) for more information
|
61 |
+
- **Original**: Based on [LDM](https://github.com/Stability-AI/stablediffusion) reference implementation and significantly expanded on by [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
|
62 |
+
This backend and is fully compatible with most existing functionality and extensions written for *A1111 SDWebUI*
|
63 |
+
Supports **SD 1.x** and **SD 2.x** models
|
64 |
+
All other model types such as *SD-XL, LCM, PixArt, Segmind, Kandinsky, etc.* require backend **Diffusers**
|
65 |
+
|
66 |
+
## Model support
|
67 |
+
|
68 |
+
Additional models will be added as they become available and there is public interest in them
|
69 |
+
|
70 |
+
- [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)*
|
71 |
+
- [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models)
|
72 |
+
- [StabilityAI Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) Base, XT 1.0, XT 1.1
|
73 |
+
- [LCM: Latent Consistency Models](https://github.com/openai/consistency_models)
|
74 |
+
- [Playground](https://huggingface.co/playgroundai/playground-v2-256px-base) *v1, v2 256, v2 512, v2 1024 and latest v2.5*
|
75 |
+
- [Stable Cascade](https://github.com/Stability-AI/StableCascade) *Full* and *Lite*
|
76 |
+
- [aMUSEd 256](https://huggingface.co/amused/amused-256) 256 and 512
|
77 |
+
- [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega)
|
78 |
+
- [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
|
79 |
+
- [Segmind SegMoE](https://github.com/segmind/segmoe) *SD and SD-XL*
|
80 |
+
- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1 and 2.2 and latest 3.0*
|
81 |
+
- [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large*
|
82 |
+
- [Warp Wuerstchen](https://huggingface.co/blog/wuertschen)
|
83 |
+
- [Tsinghua UniDiffusion](https://github.com/thu-ml/unidiffuser)
|
84 |
+
- [DeepFloyd IF](https://github.com/deep-floyd/IF) *Medium and Large*
|
85 |
+
- [ModelScope T2V](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b)
|
86 |
+
- [Segmind SD Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)*
|
87 |
+
- [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/)
|
88 |
+
- [KOALA 700M](https://github.com/youngwanLEE/sdxl-koala)
|
89 |
+
- [VGen](https://huggingface.co/ali-vilab/i2vgen-xl)
|
90 |
+
|
91 |
+
|
92 |
+
Also supported are modifiers such as:
|
93 |
+
- **LCM** and **Turbo** (*adversarial diffusion distillation*) networks
|
94 |
+
- All **LoRA** types such as LoCon, LyCORIS, HADA, IA3, Lokr, OFT
|
95 |
+
- **IP-Adapters** for SD 1.5 and SD-XL
|
96 |
+
- **InstantID**, **FaceSwap**, **FaceID**, **PhotoMerge**
|
97 |
+
- **AnimateDiff** for SD 1.5
|
98 |
+
|
99 |
+
## Examples
|
100 |
+
|
101 |
+
*IP Adapters*:
|
102 |
+
![Screenshot-IPAdapter](html/screenshot-ipadapter.jpg)
|
103 |
+
|
104 |
+
*Color grading*:
|
105 |
+
![Screenshot-Color](html/screenshot-color.jpg)
|
106 |
+
|
107 |
+
*InstantID*:
|
108 |
+
![Screenshot-InstantID](html/screenshot-instantid.jpg)
|
109 |
+
|
110 |
+
> [!IMPORTANT]
|
111 |
+
> - Loading any model other than standard SD 1.x / SD 2.x requires use of backend **Diffusers**
|
112 |
+
> - Loading any other models using **Original** backend is not supported
|
113 |
+
> - Loading manually download model `.safetensors` files is supported for specified models only (typically SD 1.x / SD 2.x / SD-XL models only)
|
114 |
+
> - For all other model types, use backend **Diffusers** and use built in Model downloader or
|
115 |
+
select model from Networks -> Models -> Reference list in which case it will be auto-downloaded and loaded
|
116 |
+
|
117 |
+
## Platform support
|
118 |
+
|
119 |
+
- *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
|
120 |
+
- *AMD* GPUs using **ROCm** libraries on *Linux*
|
121 |
+
Support will be extended to *Windows* once AMD releases ROCm for Windows
|
122 |
+
- *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
|
123 |
+
- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
|
124 |
+
This includes support for AMD GPUs that are not supported by native ROCm libraries
|
125 |
+
- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
|
126 |
+
- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
|
127 |
+
- *ONNX/Olive*
|
128 |
+
|
129 |
+
## Install
|
130 |
+
|
131 |
+
- [Step-by-step install guide](https://github.com/vladmandic/automatic/wiki/Installation)
|
132 |
+
- [Advanced install notes](https://github.com/vladmandic/automatic/wiki/Advanced-Install)
|
133 |
+
- [Common installation errors](https://github.com/vladmandic/automatic/discussions/1627)
|
134 |
+
- [FAQ](https://github.com/vladmandic/automatic/discussions/1011)
|
135 |
+
- If you can't run us locally, try our friends at [RunDiffusion!](https://rundiffusion.com?utm_source=github&utm_medium=referral&utm_campaign=SDNext)
|
136 |
+
|
137 |
+
> [!TIP]
|
138 |
+
> - Server can run with or without virtual environment,
|
139 |
+
Recommended to use `VENV` to avoid library version conflicts with other applications
|
140 |
+
> - **nVidia/CUDA** / **AMD/ROCm** / **Intel/OneAPI** are auto-detected if present and available,
|
141 |
+
For any other use case such as **DirectML**, **ONNX/Olive**, **OpenVINO** specify required parameter explicitly
|
142 |
+
or wrong packages may be installed as installer will assume CPU-only environment
|
143 |
+
> - Full startup sequence is logged in `sdnext.log`,
|
144 |
+
so if you encounter any issues, please check it first
|
145 |
+
|
146 |
+
### Run
|
147 |
+
|
148 |
+
Once SD.Next is installed, simply run `webui.ps1` or `webui.bat` (*Windows*) or `webui.sh` (*Linux or MacOS*)
|
149 |
+
|
150 |
+
List of available parameters, run `webui --help` for the full & up-to-date list:
|
151 |
+
|
152 |
+
Server options:
|
153 |
+
--config CONFIG Use specific server configuration file, default: config.json
|
154 |
+
--ui-config UI_CONFIG Use specific UI configuration file, default: ui-config.json
|
155 |
+
--medvram Split model stages and keep only active part in VRAM, default: False
|
156 |
+
--lowvram Split model components and keep only active part in VRAM, default: False
|
157 |
+
--ckpt CKPT Path to model checkpoint to load immediately, default: None
|
158 |
+
--vae VAE Path to VAE checkpoint to load immediately, default: None
|
159 |
+
--data-dir DATA_DIR Base path where all user data is stored, default:
|
160 |
+
--models-dir MODELS_DIR Base path where all models are stored, default: models
|
161 |
+
--allow-code Allow custom script execution, default: False
|
162 |
+
--share Enable UI accessible through Gradio site, default: False
|
163 |
+
--insecure Enable extensions tab regardless of other options, default: False
|
164 |
+
--use-cpu USE_CPU [USE_CPU ...] Force use CPU for specified modules, default: []
|
165 |
+
--listen Launch web server using public IP address, default: False
|
166 |
+
--port PORT Launch web server with given server port, default: 7860
|
167 |
+
--freeze Disable editing settings
|
168 |
+
--auth AUTH Set access authentication like "user:pwd,user:pwd""
|
169 |
+
--auth-file AUTH_FILE Set access authentication using file, default: None
|
170 |
+
--autolaunch Open the UI URL in the system's default browser upon launch
|
171 |
+
--docs Mount API docs, default: False
|
172 |
+
--api-only Run in API only mode without starting UI
|
173 |
+
--api-log Enable logging of all API requests, default: False
|
174 |
+
--device-id DEVICE_ID Select the default CUDA device to use, default: None
|
175 |
+
--cors-origins CORS_ORIGINS Allowed CORS origins as comma-separated list, default: None
|
176 |
+
--cors-regex CORS_REGEX Allowed CORS origins as regular expression, default: None
|
177 |
+
--tls-keyfile TLS_KEYFILE Enable TLS and specify key file, default: None
|
178 |
+
--tls-certfile TLS_CERTFILE Enable TLS and specify cert file, default: None
|
179 |
+
--tls-selfsign Enable TLS with self-signed certificates, default: False
|
180 |
+
--server-name SERVER_NAME Sets hostname of server, default: None
|
181 |
+
--no-hashing Disable hashing of checkpoints, default: False
|
182 |
+
--no-metadata Disable reading of metadata from models, default: False
|
183 |
+
--disable-queue Disable queues, default: False
|
184 |
+
--subpath SUBPATH Customize the URL subpath for usage with reverse proxy
|
185 |
+
--backend {original,diffusers} force model pipeline type
|
186 |
+
--allowed-paths ALLOWED_PATHS [ALLOWED_PATHS ...] add additional paths to paths allowed for web access
|
187 |
+
|
188 |
+
Setup options:
|
189 |
+
--reset Reset main repository to latest version, default: False
|
190 |
+
--upgrade Upgrade main repository to latest version, default: False
|
191 |
+
--requirements Force re-check of requirements, default: False
|
192 |
+
--quick Bypass version checks, default: False
|
193 |
+
--use-directml Use DirectML if no compatible GPU is detected, default: False
|
194 |
+
--use-openvino Use Intel OpenVINO backend, default: False
|
195 |
+
--use-ipex Force use Intel OneAPI XPU backend, default: False
|
196 |
+
--use-cuda Force use nVidia CUDA backend, default: False
|
197 |
+
--use-rocm Force use AMD ROCm backend, default: False
|
198 |
+
--use-zluda Force use ZLUDA, AMD GPUs only, default: False
|
199 |
+
--use-xformers Force use xFormers cross-optimization, default: False
|
200 |
+
--skip-requirements Skips checking and installing requirements, default: False
|
201 |
+
--skip-extensions Skips running individual extension installers, default: False
|
202 |
+
--skip-git Skips running all GIT operations, default: False
|
203 |
+
--skip-torch Skips running Torch checks, default: False
|
204 |
+
--skip-all Skips running all checks, default: False
|
205 |
+
--skip-env Skips setting of env variables during startup, default: False
|
206 |
+
--experimental Allow unsupported versions of libraries, default: False
|
207 |
+
--reinstall Force reinstallation of all requirements, default: False
|
208 |
+
--test Run test only and exit
|
209 |
+
--version Print version information
|
210 |
+
--ignore Ignore any errors and attempt to continue
|
211 |
+
--safe Run in safe mode with no user extensions
|
212 |
+
|
213 |
+
Logging options:
|
214 |
+
--log LOG Set log file, default: None
|
215 |
+
--debug Run installer with debug logging, default: False
|
216 |
+
--profile Run profiler, default: False
|
217 |
+
|
218 |
+
## Notes
|
219 |
+
|
220 |
+
### Control
|
221 |
+
|
222 |
+
**SD.Next** comes with built-in control for all types of text2image, image2image, video2video and batch processing
|
223 |
+
|
224 |
+
*Control interface*:
|
225 |
+
![Screenshot-Control](html/screenshot-control.jpg)
|
226 |
+
|
227 |
+
*Control processors*:
|
228 |
+
![Screenshot-Process](html/screenshot-processors.jpg)
|
229 |
+
|
230 |
+
*Masking*:
|
231 |
+
![Screenshot-Mask](html/screenshot-mask.jpg)
|
232 |
+
|
233 |
+
### **Extensions**
|
234 |
+
|
235 |
+
SD.Next comes with several extensions pre-installed:
|
236 |
+
|
237 |
+
- [ControlNet](https://github.com/Mikubill/sd-webui-controlnet) (*active in backend: original only*)
|
238 |
+
- [Agent Scheduler](https://github.com/ArtVentureX/sd-webui-agent-scheduler)
|
239 |
+
- [Image Browser](https://github.com/AlUlkesh/stable-diffusion-webui-images-browser)
|
240 |
+
|
241 |
+
### **Collab**
|
242 |
+
|
243 |
+
- We'd love to have additional maintainers (which comes with full repo rights). If you're interested, ping us!
|
244 |
+
- In addition to general cross-platform code, desire is to have a lead for each of the main platforms
|
245 |
+
This should be fully cross-platform, but we'd really love to have additional contributors and/or maintainers to join and help lead the efforts on different platforms
|
246 |
+
|
247 |
+
### **Credits**
|
248 |
+
|
249 |
+
- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for original codebase
|
250 |
+
- Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
|
251 |
+
- Licenses for modules are listed in [Licenses](html/licenses.html)
|
252 |
+
|
253 |
+
### **Evolution**
|
254 |
+
|
255 |
+
<a href="https://star-history.com/#vladmandic/automatic&Date">
|
256 |
+
<picture width=640>
|
257 |
+
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=vladmandic/automatic&type=Date&theme=dark" />
|
258 |
+
<img src="https://api.star-history.com/svg?repos=vladmandic/automatic&type=Date" alt="starts" width="320">
|
259 |
+
</picture>
|
260 |
+
</a>
|
261 |
+
|
262 |
+
- [OSS Stats](https://ossinsight.io/analyze/vladmandic/automatic#overview)
|
263 |
+
|
264 |
+
### **Docs**
|
265 |
+
|
266 |
+
If you're unsure how to use a feature, best place to start is [Wiki](https://github.com/vladmandic/automatic/wiki) and if its not there,
|
267 |
+
check [ChangeLog](CHANGELOG.md) for when feature was first introduced as it will always have a short note on how to use it
|
268 |
+
|
269 |
+
- [Wiki](https://github.com/vladmandic/automatic/wiki)
|
270 |
+
- [ReadMe](README.md)
|
271 |
+
- [ToDo](TODO.md)
|
272 |
+
- [ChangeLog](CHANGELOG.md)
|
273 |
+
- [CLI Tools](cli/README.md)
|
274 |
+
|
275 |
+
### **Sponsors**
|
276 |
+
|
277 |
+
<div align="center">
|
278 |
+
<!-- sponsors --><a href="https://github.com/allangrant"><img src="https://github.com/allangrant.png" width="60px" alt="Allan Grant" /></a><a href="https://github.com/BrentOzar"><img src="https://github.com/BrentOzar.png" width="60px" alt="Brent Ozar" /></a><a href="https://github.com/inktomi"><img src="https://github.com/inktomi.png" width="60px" alt="Matthew Runo" /></a><a href="https://github.com/4joeknight4"><img src="https://github.com/4joeknight4.png" width="60px" alt="" /></a><a href="https://github.com/SaladTechnologies"><img src="https://github.com/SaladTechnologies.png" width="60px" alt="Salad Technologies" /></a><a href="https://github.com/mantzaris"><img src="https://github.com/mantzaris.png" width="60px" alt="a.v.mantzaris" /></a><a href="https://github.com/CurseWave"><img src="https://github.com/CurseWave.png" width="60px" alt="" /></a><!-- sponsors -->
|
279 |
+
</div>
|
280 |
+
|
281 |
+
<br>
|
SECURITY.md
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Security & Privacy Policy
|
2 |
+
|
3 |
+
<br>
|
4 |
+
|
5 |
+
## Issues
|
6 |
+
|
7 |
+
All issues are tracked publicly on GitHub: <https://github.com/vladmandic/automatic/issues>
|
8 |
+
|
9 |
+
<br>
|
10 |
+
|
11 |
+
## Vulnerabilities
|
12 |
+
|
13 |
+
`SD.Next` code base and included dependencies are automatically scanned against known security vulnerabilities
|
14 |
+
|
15 |
+
Any code commit is validated before merge
|
16 |
+
|
17 |
+
- [Dependencies](https://github.com/vladmandic/automatic/security/dependabot)
|
18 |
+
- [Scanning Alerts](https://github.com/vladmandic/automatic/security/code-scanning)
|
19 |
+
|
20 |
+
<br>
|
21 |
+
|
22 |
+
## Privacy
|
23 |
+
|
24 |
+
`SD.Next` app:
|
25 |
+
|
26 |
+
- Is fully self-contained and does not send or share data of any kind with external targets
|
27 |
+
- Does not store any user or system data tracking, user provided inputs (images, video) or detection results
|
28 |
+
- Does not utilize any analytic services (such as Google Analytics)
|
29 |
+
|
30 |
+
`SD.Next` library can establish external connections *only* for following purposes and *only* when explicitly configured by user:
|
31 |
+
|
32 |
+
- Download extensions and themes indexes from automatically updated indexes
|
33 |
+
- Download required packages and repositories from GitHub during installation/upgrade
|
34 |
+
- Download installed/enabled extensions
|
35 |
+
- Download models from CivitAI and/or Huggingface when instructed by user
|
36 |
+
- Submit benchmark info upon user interaction
|
TODO.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO
|
2 |
+
|
3 |
+
Main ToDo list can be found at [GitHub projects](https://github.com/users/vladmandic/projects)
|
4 |
+
|
5 |
+
## Candidates for next release
|
6 |
+
|
7 |
+
- defork
|
8 |
+
- stable diffusion 3.0
|
9 |
+
- ipadapter masking: <https://github.com/huggingface/diffusers/pull/6847>
|
10 |
+
- x-adapter: <https://github.com/showlab/X-Adapter>
|
11 |
+
- async lowvram: <https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855>
|
12 |
+
- init latents: variations, img2img
|
13 |
+
- diffusers public callbacks
|
14 |
+
- remove builtin: controlnet
|
15 |
+
- remove builtin: image-browser
|
16 |
+
|
17 |
+
## Control missing features
|
18 |
+
|
19 |
+
- second pass: <https://github.com/vladmandic/automatic/issues/2783>
|
20 |
+
- control api
|
cli/README.md
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stable-Diffusion Productivity Scripts
|
2 |
+
|
3 |
+
Note: All scripts have built-in `--help` parameter that can be used to get more information
|
4 |
+
|
5 |
+
<br>
|
6 |
+
|
7 |
+
## Main Scripts
|
8 |
+
|
9 |
+
### Generate
|
10 |
+
|
11 |
+
Text-to-image with all of the possible parameters
|
12 |
+
Supports upsampling, face restoration and grid creation
|
13 |
+
> python generate.py
|
14 |
+
|
15 |
+
By default uses parameters from `generate.json`
|
16 |
+
|
17 |
+
Parameters that are not specified will be randomized:
|
18 |
+
|
19 |
+
- Prompt will be dynamically created from template of random samples: `random.json`
|
20 |
+
- Sampler/Scheduler will be randomly picked from available ones
|
21 |
+
- CFG Scale set to 5-10
|
22 |
+
|
23 |
+
### Train
|
24 |
+
|
25 |
+
Combined pipeline for **embeddings**, **lora**, **lycoris**, **dreambooth** and **hypernetwork**
|
26 |
+
Optionally runs several image processing steps before training:
|
27 |
+
|
28 |
+
- keep original image
|
29 |
+
- detect and extract face
|
30 |
+
- detect and extract body
|
31 |
+
- detect blur
|
32 |
+
- detect dynamic range
|
33 |
+
- attempt to upscale low resolution images
|
34 |
+
- attempt to restore quality of low quality images
|
35 |
+
- automatically generate captions using interrogate
|
36 |
+
- resize image
|
37 |
+
- square image
|
38 |
+
- run image segmentation to remove background
|
39 |
+
|
40 |
+
> python train.py
|
41 |
+
|
42 |
+
<br>
|
43 |
+
|
44 |
+
## Auxiliary Scripts
|
45 |
+
|
46 |
+
### Benchmark
|
47 |
+
|
48 |
+
> python run-benchmark.py
|
49 |
+
|
50 |
+
### Create Previews
|
51 |
+
|
52 |
+
Create previews for **embeddings**, **lora**, **lycoris**, **dreambooth** and **hypernetwork**
|
53 |
+
|
54 |
+
> python create-previews.py
|
55 |
+
|
56 |
+
## Image Grid
|
57 |
+
|
58 |
+
> python image-grid.py
|
59 |
+
|
60 |
+
### Image Watermark
|
61 |
+
|
62 |
+
Create invisible image watermark and remove existing EXIF tags
|
63 |
+
|
64 |
+
> python image-watermark.py
|
65 |
+
|
66 |
+
### Image Interrogate
|
67 |
+
|
68 |
+
Runs CLIP and Booru image interrogation
|
69 |
+
|
70 |
+
> python image-interrogate.py
|
71 |
+
|
72 |
+
### Palette Extract
|
73 |
+
|
74 |
+
Extract color palette from image(s)
|
75 |
+
|
76 |
+
> python image-palette.py
|
77 |
+
|
78 |
+
### Prompt Ideas
|
79 |
+
|
80 |
+
Generate complex prompt ideas
|
81 |
+
|
82 |
+
> python prompt-ideas.py
|
83 |
+
|
84 |
+
### Prompt Promptist
|
85 |
+
|
86 |
+
Attempts to beautify the provided prompt
|
87 |
+
|
88 |
+
> python prompt-promptist.py
|
89 |
+
|
90 |
+
### Video Extract
|
91 |
+
|
92 |
+
Extract frames from video files
|
93 |
+
|
94 |
+
> python video-extract.py
|
95 |
+
|
96 |
+
<br>
|
97 |
+
|
98 |
+
## Utility Scripts
|
99 |
+
|
100 |
+
### SDAPI
|
101 |
+
|
102 |
+
Utility module that handles async communication to Automatic API endpoints
|
103 |
+
Note: Requires SD API
|
104 |
+
|
105 |
+
Can be used to manually execute specific commands:
|
106 |
+
> python sdapi.py progress
|
107 |
+
> python sdapi.py interrupt
|
108 |
+
> python sdapi.py shutdown
|
cli/clone.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
import git
|
5 |
+
from rich import console, progress
|
6 |
+
|
7 |
+
|
8 |
+
class GitRemoteProgress(git.RemoteProgress):
    """Rich-based progress reporter for GitPython remote operations.

    Translates GitPython opcode bitmasks back into readable stage names and
    renders a live progress bar with percentage, ETA, the remote URL and the
    latest progress message.
    """
    # names of the git.RemoteProgress opcode constants we can report on
    OP_CODES = ["BEGIN", "CHECKING_OUT", "COMPRESSING", "COUNTING", "END", "FINDING_SOURCES", "RECEIVING", "RESOLVING", "WRITING"]
    # reverse map: opcode int value -> opcode name
    OP_CODE_MAP = { getattr(git.RemoteProgress, _op_code): _op_code for _op_code in OP_CODES }

    def __init__(self, url, folder) -> None:
        super().__init__()
        self.url = url
        self.folder = folder
        self.progressbar = progress.Progress(
            progress.SpinnerColumn(),
            progress.TextColumn("[cyan][progress.description]{task.description}"),
            progress.BarColumn(),
            progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            progress.TimeRemainingColumn(),
            progress.TextColumn("[yellow]<{task.fields[url]}>"),
            progress.TextColumn("{task.fields[message]}"),
            console=console.Console(),
            transient=False,
        )
        self.progressbar.start()
        self.active_task = None

    def __del__(self) -> None:
        # best-effort teardown; rich restores the terminal on stop()
        self.progressbar.stop()

    @classmethod
    def get_curr_op(cls, op_code: int) -> str:
        """Return the human-readable name of the base operation in *op_code*."""
        op_code_masked = op_code & cls.OP_MASK # mask out BEGIN/END stage bits
        return cls.OP_CODE_MAP.get(op_code_masked, "?").title()

    def update(self, op_code: int, cur_count: str | float, max_count: str | float | None = None, message: str | None = "") -> None:
        """GitPython callback: create a task on BEGIN, advance it, dim message on END."""
        if op_code & self.BEGIN:
            self.curr_op = self.get_curr_op(op_code) # pylint: disable=attribute-defined-outside-init
            self.active_task = self.progressbar.add_task(description=self.curr_op, total=max_count, message=message, url=self.url)
        if self.active_task is None:
            # fix: an update/END event arrived before any BEGIN; the original
            # passed task_id=None into rich which raises
            return
        self.progressbar.update(task_id=self.active_task, completed=cur_count, message=message)
        if op_code & self.END:
            self.progressbar.update(task_id=self.active_task, message=f"[bright_black]{message}")
|
45 |
+
|
46 |
+
|
47 |
+
def clone(url: str, folder: str):
    """Shallow-clone the git repository at *url* into *folder* with live progress output."""
    reporter = GitRemoteProgress(url=url, folder=folder)
    # disable compression and delta windows to maximize raw transfer speed
    speed_options = [
        '--config core.compression=0',
        '--config core.loosecompression=0',
        '--config pack.window=0',
    ]
    git.Repo.clone_from(
        url=url,
        to_path=folder,
        progress=reporter,
        multi_options=speed_options,
        allow_unsafe_options=True,
        depth=1,
    )
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description = 'downloader')
    parser.add_argument('--url', required=True, help="download url, required")
    parser.add_argument('--folder', required=False, help="output folder, default: autodetect")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
    log = logging.getLogger(__name__)
    # fix: initialize f before the try block so the exception handlers below can
    # always reference it (previously a KeyboardInterrupt raised before the
    # assignment caused a NameError inside the handler)
    f = args.folder
    try:
        if not args.url.startswith('http'):
            raise ValueError(f'invalid url: {args.url}')
        if f is None: # derive folder from last url segment, e.g. ".../repo.git" -> "repo"
            f = args.url.split('/')[-1].split('.')[0]
        if os.path.exists(f):
            raise FileExistsError(f'folder already exists: {f}')
        log.info(f'Clone start: url={args.url} folder={f}')
        clone(url=args.url, folder=f)
        log.info(f'Clone complete: url={args.url} folder={f}')
    except KeyboardInterrupt:
        log.warning(f'Clone cancelled: url={args.url} folder={f}')
    except Exception as e:
        log.error(f'Clone: url={args.url} {e}')
|
cli/create-previews.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# pylint: disable=no-member
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import json
|
6 |
+
import time
|
7 |
+
import logging
|
8 |
+
import importlib
|
9 |
+
import asyncio
|
10 |
+
import argparse
|
11 |
+
from pathlib import Path
|
12 |
+
from util import Map, log
|
13 |
+
from sdapi import get, post, close
|
14 |
+
from generate import generate # pylint: disable=import-error
|
15 |
+
grid = importlib.import_module('image-grid').grid
|
16 |
+
|
17 |
+
|
18 |
+
# runtime configuration for all preview generators below; Map gives
# attribute-style access (options.generate.prompt etc.)
options = Map({
    # used by extra networks
    'prompt': 'photo of <keyword> <embedding>, photograph, posing, pose, high detailed, intricate, elegant, sharp focus, skin texture, looking forward, facing camera, 135mm, shot on dslr, canon 5d, 4k, modelshoot style, cinematic lighting',
    # used by models
    'prompts': [
        ('photo citiscape', 'cityscape during night, photorealistic, high detailed, sharp focus, depth of field, 4k'),
        ('photo car', 'photo of a sports car, high detailed, sharp focus, dslr, cinematic lighting, realistic'),
        ('photo woman', 'portrait photo of beautiful woman, high detailed, dslr, 35mm'),
        ('photo naked', 'full body photo of beautiful sexy naked woman, high detailed, dslr, 35mm'),

        ('photo taylor', 'portrait photo of beautiful woman taylor swift, high detailed, sharp focus, depth of field, dslr, 35mm <lora:taylor-swift:1>'),
        ('photo ti-mia', 'portrait photo of beautiful woman "ti-mia", naked, high detailed, dslr, 35mm'),
        ('photo ti-vlado', 'portrait photo of man "ti-vlado", high detailed, dslr, 35mm'),
        ('photo lora-vlado', 'portrait photo of man vlado, high detailed, dslr, 35mm <lora:vlado-original:1>'),

        ('wlop', 'a stunning portrait of sexy teen girl in a wet t-shirt, vivid color palette, digital painting, octane render, highly detailed, particles, light effect, volumetric lighting, art by wlop'),
        ('greg rutkowski', 'beautiful woman, high detailed, sharp focus, depth of field, 4k, art by greg rutkowski'),
        ('carne griffiths', 'beautiful woman taylor swift, high detailed, sharp focus, depth of field, art by carne griffiths <lora:taylor-swift:1>'),
        ('carne griffiths', 'man vlado, high detailed, sharp focus, depth of field, art by carne griffiths <lora:vlado-full:1>'),
    ],
    # save format
    'format': '.jpg',
    # used by generate script
    'paths': {
        "root": "/mnt/c/Users/mandi/OneDrive/Generative/Generate",
        "generate": "image",
        "upscale": "upscale",
        "grid": "grid",
    },
    # generate params
    'generate': {
        'restore_faces': True,
        'prompt': '',
        'negative_prompt': 'foggy, blurry, blurred, duplicate, ugly, mutilated, mutation, mutated, out of frame, bad anatomy, disfigured, deformed, censored, low res, low resolution, watermark, text, poorly drawn face, poorly drawn hands, signature',
        'steps': 20,
        'batch_size': 2,
        'n_iter': 1,
        'seed': -1,
        'sampler_name': 'UniPC',
        'cfg_scale': 6,
        'width': 512,
        'height': 512,
    },
    # injection strength used when appending <lora:...> / <lyco:...> tags
    'lora': {
        'strength': 1.0,
    },
    # keyword/strength used when prefixing <hypernet:...> tags
    'hypernetwork': {
        'keyword': '',
        'strength': 1.0,
    },
})
|
69 |
+
|
70 |
+
|
71 |
+
def preview_exists(folder, model):
    """Return True when a preview image already exists for *model* in *folder*.

    Both bare (`name.ext`) and explicit (`name.preview.ext`) naming schemes are
    checked against every supported image extension.
    """
    base = os.path.splitext(model)[0]
    candidates = (
        os.path.join(folder, f'{base}{suffix}{ext}')
        for suffix in ('', '.preview')
        for ext in ('.jpg', '.png', '.webp')
    )
    return any(os.path.exists(candidate) for candidate in candidates)
|
79 |
+
|
80 |
+
|
81 |
+
async def preview_models(params):
    """Generate grid preview images for every installed checkpoint model.

    Skips models matched by params.exclude or models that already have a
    preview (unless explicitly listed in params.input). Restores the default
    checkpoint from params.model when done.
    """
    data = await get('/sdapi/v1/sd-models')
    allmodels = [m['title'] for m in data]
    models = []
    excluded = []
    for m in allmodels: # loop through all registered models
        ok = True
        for e in params.exclude: # check if model is excluded
            if e in m:
                excluded.append(m)
                ok = False
                break
        if ok:
            short = m.split(' [')[0] # drop the "[hash]" suffix from the title
            short = short.replace('.ckpt', '').replace('.safetensors', '')
            models.append(short)
    if len(params.input) > 0: # check if model is included in cmd line
        filtered = []
        for m in params.input:
            if m in models:
                filtered.append(m)
            else:
                log.error({ 'model not found': m })
                return
        models = filtered
    log.info('models preview') # fix: original logged a one-element set literal
    log.info({ 'models': len(models), 'excluded': len(excluded) })
    opt = await get('/sdapi/v1/options')
    log.info({ 'total jobs': len(models) * options.generate.batch_size, 'per-model': options.generate.batch_size })
    log.info(json.dumps(options, indent=2))
    for model in models:
        if preview_exists(opt['ckpt_dir'], model) and len(params.input) == 0: # if model preview exists and not manually included
            log.info({ 'model preview exists': model })
            continue
        fn = os.path.join(opt['ckpt_dir'], os.path.splitext(model)[0] + options.format)
        log.info({ 'model load': model })

        opt['sd_model_checkpoint'] = model
        # fix: pop instead of del so a missing key does not raise KeyError;
        # these keys must not be posted back to the options endpoint
        opt.pop('sd_lora', None)
        opt.pop('sd_lyco', None)
        await post('/sdapi/v1/options', opt)
        opt = await get('/sdapi/v1/options')
        images = []
        labels = []
        t0 = time.time()
        for label, p in options.prompts:
            options.generate.prompt = p
            log.info({ 'model generating': model, 'label': label, 'prompt': options.generate.prompt })
            data = await generate(options = options, quiet=True)
            if 'image' in data:
                for img in data['image']:
                    images.append(img)
                    labels.append(label)
            else:
                log.error({ 'model': model, 'error': data })
        t1 = time.time()
        if len(images) == 0:
            log.error({ 'model': model, 'error': 'no images generated' })
            continue
        image = grid(images = images, labels = labels, border = 8)
        log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
        image.save(fn)
        t = t1 - t0
        its = 1.0 * options.generate.steps * len(images) / t # aggregate iterations/second
        log.info({ 'model preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })

    opt = await get('/sdapi/v1/options')
    if opt['sd_model_checkpoint'] != params.model: # restore user's default checkpoint
        log.info({ 'model set default': params.model })
        opt['sd_model_checkpoint'] = params.model
        opt.pop('sd_lora', None)
        opt.pop('sd_lyco', None)
        await post('/sdapi/v1/options', opt)
|
154 |
+
|
155 |
+
|
156 |
+
async def lora(params):
    """Generate a preview grid for every LoRA network found in the server's lora_dir.

    Existing previews are skipped unless models were explicitly listed in
    params.input.
    """
    opt = await get('/sdapi/v1/options')
    folder = opt['lora_dir']
    if not os.path.exists(folder):
        log.error({ 'lora directory not found': folder })
        return
    models1 = list(Path(folder).glob('**/*.safetensors'))
    models2 = list(Path(folder).glob('**/*.ckpt'))
    models = [os.path.splitext(f)[0] for f in models1 + models2]  # full paths without extension
    log.info({ 'loras': len(models) })
    for model in models:
        if preview_exists('', model) and len(params.input) == 0: # if model preview exists and not manually included
            log.info({ 'lora preview exists': model })
            continue
        fn = model + options.format  # preview saved next to the model file
        model = os.path.basename(model)
        images = []
        labels = []
        t0 = time.time()
        # derive trigger keywords from the filename: strip digits, then split
        # on version markers and dashes
        keywords = re.sub(r'\d', '', model)
        keywords = keywords.replace('-v', ' ').replace('-', ' ').strip().split(' ')
        keyword = '\"' + '\" \"'.join(keywords) + '\"'
        options.generate.prompt = options.prompt.replace('<keyword>', keyword)
        options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
        options.generate.prompt += f' <lora:{model}:{options.lora.strength}>'
        log.info({ 'lora generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
        data = await generate(options = options, quiet=True)
        if 'image' in data:
            for img in data['image']:
                images.append(img)
                labels.append(keyword)
        else:
            log.error({ 'lora': model, 'keyword': keyword, 'error': data })
        t1 = time.time()
        if len(images) == 0:
            log.error({ 'model': model, 'error': 'no images generated' })
            continue
        image = grid(images = images, labels = labels, border = 8)
        log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
        image.save(fn)
        t = t1 - t0
        its = 1.0 * options.generate.steps * len(images) / t  # aggregate iterations/second
        log.info({ 'lora preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
|
199 |
+
|
200 |
+
|
201 |
+
async def lyco(params):
    """Generate a preview grid for every LyCORIS network found in the server's lyco_dir.

    Mirrors lora() but injects a <lyco:...> tag; note it reuses
    options.lora.strength for the injection strength.
    """
    opt = await get('/sdapi/v1/options')
    folder = opt['lyco_dir']
    if not os.path.exists(folder):
        log.error({ 'lyco directory not found': folder })
        return
    models1 = list(Path(folder).glob('**/*.safetensors'))
    models2 = list(Path(folder).glob('**/*.ckpt'))
    models = [os.path.splitext(f)[0] for f in models1 + models2]  # full paths without extension
    log.info({ 'lycos': len(models) })
    for model in models:
        if preview_exists('', model) and len(params.input) == 0: # if model preview exists and not manually included
            log.info({ 'lyco preview exists': model })
            continue
        fn = model + options.format  # preview saved next to the model file
        model = os.path.basename(model)
        images = []
        labels = []
        t0 = time.time()
        # derive trigger keywords from the filename: strip digits, then split
        # on version markers and dashes
        keywords = re.sub(r'\d', '', model)
        keywords = keywords.replace('-v', ' ').replace('-', ' ').strip().split(' ')
        keyword = '\"' + '\" \"'.join(keywords) + '\"'
        options.generate.prompt = options.prompt.replace('<keyword>', keyword)
        options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
        options.generate.prompt += f' <lyco:{model}:{options.lora.strength}>'
        log.info({ 'lyco generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
        data = await generate(options = options, quiet=True)
        if 'image' in data:
            for img in data['image']:
                images.append(img)
                labels.append(keyword)
        else:
            log.error({ 'lyco': model, 'keyword': keyword, 'error': data })
        t1 = time.time()
        if len(images) == 0:
            log.error({ 'model': model, 'error': 'no images generated' })
            continue
        image = grid(images = images, labels = labels, border = 8)
        log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
        image.save(fn)
        t = t1 - t0
        its = 1.0 * options.generate.steps * len(images) / t  # aggregate iterations/second
        log.info({ 'lyco preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
|
244 |
+
|
245 |
+
|
246 |
+
async def hypernetwork(params):
    """Generate a preview grid for every hypernetwork (.pt) in the server's hypernetwork_dir.

    Uses the fixed keyword from options.hypernetwork and prefixes the prompt
    with a <hypernet:...> tag.
    """
    opt = await get('/sdapi/v1/options')
    folder = opt['hypernetwork_dir']
    if not os.path.exists(folder):
        log.error({ 'hypernetwork directory not found': folder })
        return
    models = [os.path.splitext(f)[0] for f in Path(folder).glob('**/*.pt')]
    log.info({ 'hypernetworks': len(models) })
    for model in models:
        if preview_exists(folder, model) and len(params.input) == 0: # if model preview exists and not manually included
            log.info({ 'hypernetwork preview exists': model })
            continue
        # NOTE(review): model comes from a recursive glob so it is a full path;
        # os.path.join returns the absolute second argument unchanged — confirm intended
        fn = os.path.join(folder, model + options.format)
        images = []
        labels = []
        t0 = time.time()
        keyword = options.hypernetwork.keyword
        options.generate.prompt = options.prompt.replace('<keyword>', options.hypernetwork.keyword)
        options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
        options.generate.prompt = f' <hypernet:{model}:{options.hypernetwork.strength}> ' + options.generate.prompt
        log.info({ 'hypernetwork generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
        data = await generate(options = options, quiet=True)
        if 'image' in data:
            for img in data['image']:
                images.append(img)
                labels.append(keyword)
        else:
            log.error({ 'hypernetwork': model, 'keyword': keyword, 'error': data })
        t1 = time.time()
        if len(images) == 0:
            log.error({ 'model': model, 'error': 'no images generated' })
            continue
        image = grid(images = images, labels = labels, border = 8)
        log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
        image.save(fn)
        t = t1 - t0
        its = 1.0 * options.generate.steps * len(images) / t  # aggregate iterations/second
        log.info({ 'hypernetwork preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
|
284 |
+
|
285 |
+
|
286 |
+
async def embedding(params):
    """Generate a preview grid for every textual-inversion embedding (.pt) in embeddings_dir.

    Saves as `<name>.preview<format>` and temporarily raises batch_size to 4.
    """
    opt = await get('/sdapi/v1/options')
    folder = opt['embeddings_dir']
    if not os.path.exists(folder):
        log.error({ 'embeddings directory not found': folder })
        return
    models = [os.path.splitext(f)[0] for f in Path(folder).glob('**/*.pt')]
    log.info({ 'embeddings': len(models) })
    for model in models:
        if preview_exists(folder, model) and len(params.input) == 0: # if model preview exists and not manually included
            log.info({ 'embedding preview exists': model })
            continue
        # NOTE(review): model comes from a recursive glob so it is a full path;
        # os.path.join returns the absolute second argument unchanged — confirm intended
        fn = os.path.join(folder, model + '.preview' + options.format)
        images = []
        labels = []
        t0 = time.time()
        keyword = '\"' + re.sub(r'\d', '', model) + '\"'  # digits stripped; quoted to force literal token
        options.generate.batch_size = 4  # NOTE(review): mutates shared options for subsequent calls
        options.generate.prompt = options.prompt.replace('<keyword>', keyword)
        options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
        log.info({ 'embedding generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
        data = await generate(options = options, quiet=True)
        if 'image' in data:
            for img in data['image']:
                images.append(img)
                labels.append(keyword)
        else:
            log.error({ 'embeding': model, 'keyword': keyword, 'error': data })  # key spelling kept as-is (log output)
        t1 = time.time()
        if len(images) == 0:
            log.error({ 'model': model, 'error': 'no images generated' })
            continue
        image = grid(images = images, labels = labels, border = 8)
        log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
        image.save(fn)
        t = t1 - t0
        its = 1.0 * options.generate.steps * len(images) / t  # aggregate iterations/second
        log.info({ 'embeding preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
|
324 |
+
|
325 |
+
|
326 |
+
async def create_previews(params):
    """Run every preview generator in sequence, then close the sdapi client connection."""
    await preview_models(params)
    await lora(params)
    await lyco(params)
    await hypernetwork(params)
    await embedding(params)
    await close()  # release the shared sdapi HTTP session
|
333 |
+
|
334 |
+
|
335 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = 'generate model previews')
    parser.add_argument('--model', default='best/icbinp-icantbelieveIts-final.safetensors [73f48afbdc]', help="model used to create extra network previews")
    parser.add_argument('--exclude', default=['sd-v20', 'sd-v21', 'inpainting', 'pix2pix'], help="exclude models with keywords")
    parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
    parser.add_argument('input', type = str, nargs = '*')  # optional explicit list of models to process
    args = parser.parse_args()
    if args.debug:
        log.setLevel(logging.DEBUG)
        log.debug({ 'debug': True })
    log.debug({ 'args': args.__dict__ })
    asyncio.run(create_previews(args))
|
cli/download.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import argparse
|
5 |
+
import tempfile
|
6 |
+
import urllib
|
7 |
+
import requests
|
8 |
+
import urllib3
|
9 |
+
import rich.progress as p
|
10 |
+
from rich import print # pylint: disable=redefined-builtin
|
11 |
+
|
12 |
+
|
13 |
+
pbar = p.Progress(p.TextColumn('[cyan]{task.description}'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn())
|
14 |
+
headers = {
|
15 |
+
'Content-type': 'application/json',
|
16 |
+
'User-Agent': 'Mozilla/5.0',
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
def get_filename(args, res):
    """Pick the output filename for a download.

    Precedence: explicit `args.file`, then the filename advertised in the
    response's content-disposition header, then a random temporary name.

    Args:
        args: parsed cli arguments; only `args.file` is read (may be None)
        res: response object exposing a `headers` mapping
    """
    content_fn = None
    disposition = res.headers.get('content-disposition', '')
    if 'filename=' in disposition:
        # fix: take only the token after 'filename=', dropping any trailing
        # header parameters (e.g. 'filename=x.bin; size=1') and surrounding quotes
        content_fn = disposition.split('filename=')[1].split(';')[0].strip().strip('\"').strip("'")
    return args.file or content_fn or next(tempfile._get_candidate_names()) # pylint: disable=protected-access
|
23 |
+
|
24 |
+
|
25 |
+
def download_requests(args):
    """Download args.url to disk with the `requests` library, streaming in args.block sized chunks.

    Returns the output filename.
    """
    # NOTE(review): verify=False disables TLS certificate validation; kept for
    # parity with the other download paths, but unsafe against untrusted hosts
    res = requests.get(args.url, timeout=30, headers=headers, verify=False, allow_redirects=True, stream=True)
    content_length = int(res.headers.get('content-length', 0))
    fn = get_filename(args, res)
    print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=requests block={args.block}')
    with open(fn, 'wb') as f:
        with pbar:
            task = pbar.add_task(description="Download starting", total=content_length)
            for data in res.iter_content(args.block):
                f.write(data)
                # fix: advance by actual bytes received; the final chunk is
                # usually shorter than args.block so advancing by the block
                # size overshot the total
                pbar.update(task, advance=len(data), description="Downloading")
    return fn
|
37 |
+
|
38 |
+
|
39 |
+
def download_urllib(args):
    """Download args.url to disk with stdlib urllib, streaming in args.block sized chunks.

    Returns the output filename.
    """
    req = urllib.request.Request(args.url, headers=headers)
    res = urllib.request.urlopen(req)
    # fix: removed a dead duplicate res.getheader('content-length') call whose
    # result was discarded, and the unused fn='' initializer
    content_length = int(res.getheader('content-length') or 0)
    fn = get_filename(args, res)
    print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=urllib block={args.block}')
    with open(fn, 'wb') as f:
        with pbar:
            task = pbar.add_task(description="Download starting", total=content_length)
            while True:
                buf = res.read(args.block)
                if not buf:
                    break
                f.write(buf)
                # fix: advance by actual bytes read; final chunk may be short
                pbar.update(task, advance=len(buf), description="Downloading")
    return fn
|
57 |
+
|
58 |
+
|
59 |
+
def download_urllib3(args):
    """Download args.url to disk with urllib3, streaming in args.block sized chunks.

    Returns the output filename.
    """
    http_pool = urllib3.PoolManager()
    res = http_pool.request('GET', args.url, preload_content=False, headers=headers)
    fn = get_filename(args, res)
    content_length = int(res.headers.get('content-length', 0))
    print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=urllib3 block={args.block}')
    with open(fn, 'wb') as f:
        with pbar:
            task = pbar.add_task(description="Download starting", total=content_length)
            while True:
                buf = res.read(args.block)
                if not buf:
                    break
                f.write(buf)
                # fix: advance by actual bytes read; final chunk may be short
                pbar.update(task, advance=len(buf), description="Downloading")
    res.release_conn() # fix: return the connection to the pool when done
    return fn
|
75 |
+
|
76 |
+
|
77 |
+
def download_httpx(args):
    """Download args.url to disk using httpx, if installed.

    httpx chooses its own chunk size for iter_bytes(), so progress is advanced
    by the actual chunk length. Returns the output filename, or None when the
    httpx package is not available.
    """
    try:
        import httpx
    except ImportError:
        print('httpx is not installed')
        return None
    with httpx.stream("GET", args.url, headers=headers, verify=False, follow_redirects=True) as res:
        fn = get_filename(args, res)
        content_length = int(res.headers.get('content-length', 0))
        print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=httpx block=internal')
        with open(fn, 'wb') as f:
            with pbar:
                task = pbar.add_task(description="Download starting", total=content_length)
                for buf in res.iter_bytes():
                    f.write(buf)
                    pbar.update(task, advance=len(buf), description="Downloading")  # fixed: was advancing by args.block, which httpx does not honor
    return fn
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = 'downloader')
    parser.add_argument('--url', required=True, help="download url, required")
    parser.add_argument('--file', required=False, help="output file, default: autodetect")
    parser.add_argument('--lib', required=False, default='requests', choices=['urllib', 'urllib3', 'requests', 'httpx'], help="download mode, default: %(default)s")
    parser.add_argument('--block', required=False, type=int, default=16384, help="download block size, default: %(default)s")
    parsed = parser.parse_args()
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)  # downloads run with verify=False
    try:
        t0 = time.time()
        if parsed.lib == 'requests':
            filename = download_requests(parsed)
        elif parsed.lib == 'urllib':
            filename = download_urllib(parsed)
        elif parsed.lib == 'urllib3':
            filename = download_urllib3(parsed)
        elif parsed.lib == 'httpx':
            filename = download_httpx(parsed)
        else:
            print(f'unknown download library: {parsed.lib}')
            exit(1)
        t1 = time.time()
        if filename is None:
            print(f'download error: args={parsed}')
            exit(1)
        speed = round(os.path.getsize(filename) / (t1 - t0) / 1024 / 1024, 3)
        print(f'download complete: url={parsed.url} file={filename} speed={speed} mb/s')  # fixed: report the actual output filename instead of the literal '(unknown)'
    except KeyboardInterrupt:
        print(f'download cancelled: args={parsed}')
    except Exception as e:
        print(f'download error: args={parsed} {e}')
|
cli/gen-styles.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/env python
|
2 |
+
|
3 |
+
import io
|
4 |
+
import json
|
5 |
+
import base64
|
6 |
+
import argparse
|
7 |
+
import requests
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
# default txt2img payload used when generating style previews;
# the 'prompt' key is filled in per style line before each request
options = {
    "negative_prompt": "",
    "steps": 20,
    "batch_size": 1,
    "n_iter": 1,
    "seed": -1,
    "sampler_name": "UniPC",
    "cfg_scale": 6,
    "width": 512,
    "height": 512,
    "save_images": False,
    "send_images": True,
}
# accumulated style records, dumped to the output json file at the end
styles = []
|
25 |
+
|
26 |
+
|
27 |
+
def pil_to_b64(img: Image, size: int, quality: int):
    """Encode *img* as a size x size JPEG and return it as a base64 data-URL string."""
    jpeg_buffer = io.BytesIO()
    img.convert('RGB').resize((size, size)).save(jpeg_buffer, format="JPEG", quality=quality)
    encoded = base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
    return f'data:image/jpeg;base64,{encoded}'
|
34 |
+
|
35 |
+
|
36 |
+
def post(endpoint: str, dct: dict = None):
    """POST *dct* as JSON to *endpoint* and return the parsed JSON response.

    On any non-200 status a dict with 'error', 'reason' and 'url' keys is
    returned instead of raising.
    """
    req = requests.post(endpoint, json = dct, timeout=300, verify=False)
    if req.status_code == 200:
        return req.json()
    return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = 'gen-styles.py')
    parser.add_argument('--input', type=str, required=True, help="input text file with one line per prompt")
    parser.add_argument('--output', type=str, required=True, help="output json file")
    parser.add_argument('--nopreviews', default=False, action='store_true', help = 'generate previews')
    parser.add_argument('--prompt', type=str, required=False, default='girl walking in a city', help="applied prompt when generating previews")
    parser.add_argument('--size', type=int, default=128, help="image size for previews")
    parser.add_argument('--quality', type=int, default=35, help="image quality for previews")
    parser.add_argument('--url', type=str, required=False, default='http://127.0.0.1:7860', help="sd.next server url")
    args = parser.parse_args()
    with open(args.input, encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip().replace('\n', '')
        if len(line) == 0:
            continue
        print(f'processing: {line}')
        if not args.nopreviews:
            options['prompt'] = f'{line} {args.prompt}'
            data = post(f'{args.url}/sdapi/v1/txt2img', options)
            if 'error' in data:
                print(f'error: {data}')
                continue
            # fixed: keep the payload after an optional 'data:...;base64,' prefix;
            # index [0] returned the prefix itself when a data-URL was sent back
            b64str = data['images'][0].split(',',1)[-1]
            image = Image.open(io.BytesIO(base64.b64decode(b64str)))
        else:
            image = None
        styles.append({
            'name': line,
            'prompt': line + ' {prompt}',
            'negative': '',
            'extra': '',
            'preview': pil_to_b64(image, args.size, args.quality) if image is not None else '',
        })
    with open(args.output, 'w', encoding='utf-8') as outfile:
        json.dump(styles, outfile, indent=2)
|
cli/generate.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"paths":
|
3 |
+
{
|
4 |
+
"root": "/mnt/c/Users/mandi/OneDrive/Generative/Generate",
|
5 |
+
"generate": "image",
|
6 |
+
"upscale": "upscale",
|
7 |
+
"grid": "grid"
|
8 |
+
},
|
9 |
+
"generate":
|
10 |
+
{
|
11 |
+
"restore_faces": true,
|
12 |
+
"prompt": "dynamic",
|
13 |
+
"negative_prompt": "foggy, blurry, blurred, duplicate, ugly, mutilated, mutation, mutated, out of frame, bad anatomy, disfigured, deformed, censored, low res, watermark, text, poorly drawn face, signature",
|
14 |
+
"steps": 30,
|
15 |
+
"batch_size": 2,
|
16 |
+
"n_iter": 1,
|
17 |
+
"seed": -1,
|
18 |
+
"sampler_name": "DPM2 Karras",
|
19 |
+
"cfg_scale": 6,
|
20 |
+
"width": 512,
|
21 |
+
"height": 512
|
22 |
+
},
|
23 |
+
"upscale":
|
24 |
+
{
|
25 |
+
"upscaler_1": "SwinIR_4x",
|
26 |
+
"upscaler_2": "None",
|
27 |
+
"upscale_first": false,
|
28 |
+
"upscaling_resize": 0,
|
29 |
+
"gfpgan_visibility": 0,
|
30 |
+
"codeformer_visibility": 0,
|
31 |
+
"codeformer_weight": 0.5
|
32 |
+
},
|
33 |
+
"options":
|
34 |
+
{
|
35 |
+
"sd_model_checkpoint": "sd-v15-runwayml",
|
36 |
+
"sd_vae": "vae-ft-mse-840000-ema-pruned.ckpt"
|
37 |
+
}
|
38 |
+
}
|
cli/generate.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# pylint: disable=no-member
|
3 |
+
"""generate batches of images from prompts and upscale them
|
4 |
+
|
5 |
+
params: run with `--help`
|
6 |
+
|
7 |
+
default workflow runs infinite loop and prints stats when interrupted:
|
8 |
+
1. choose random scheduler lookup all available and pick one
|
9 |
+
2. generate dynamic prompt based on styles, embeddings, places, artists, suffixes
|
10 |
+
3. beautify prompt
|
11 |
+
4. generate 3x3 images
|
12 |
+
5. create image grid
|
13 |
+
6. upscale images with face restoration
|
14 |
+
"""
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
import asyncio
|
18 |
+
import base64
|
19 |
+
import io
|
20 |
+
import json
|
21 |
+
import logging
|
22 |
+
import math
|
23 |
+
import os
|
24 |
+
import pathlib
|
25 |
+
import secrets
|
26 |
+
import time
|
27 |
+
import sys
|
28 |
+
import importlib
|
29 |
+
|
30 |
+
from random import randrange
|
31 |
+
from PIL import Image
|
32 |
+
from PIL.ExifTags import TAGS
|
33 |
+
from PIL.TiffImagePlugin import ImageFileDirectory_v2
|
34 |
+
|
35 |
+
from sdapi import close, get, interrupt, post, session
|
36 |
+
from util import Map, log, safestring
|
37 |
+
|
38 |
+
|
39 |
+
sd = {}
|
40 |
+
random = {}
|
41 |
+
stats = Map({ 'images': 0, 'wall': 0, 'generate': 0, 'upscale': 0 })
|
42 |
+
avg = {}
|
43 |
+
|
44 |
+
|
45 |
+
def grid(data):
    """Compose a batch of generated images into one grid image and save it as JPEG.

    Returns the composed grid image, or the original image list unchanged when
    there are fewer than two images to combine.
    """
    count = len(data.image)
    if count <= 1:
        return data.image
    w, h = data.image[0].size
    rows = round(math.sqrt(count))
    cols = math.ceil(count / rows)
    composite = Image.new('RGB', size = (cols * w, rows * h), color = 'black')
    for idx, img in enumerate(data.image):
        composite.paste(img, box=(idx % cols * w, idx // cols * h))
    short = data.info.prompt[:min(len(data.info.prompt), 96)]  # limit prompt part of filename to 96 chars
    name = '{seed:0>9} {short}'.format(short = short, seed = data.info.all_seeds[0])  # pylint: disable=consider-using-f-string
    name = safestring(name) + '.jpg'
    f = os.path.join(sd.paths.root, sd.paths.grid, name)
    log.info({ 'grid': { 'name': f, 'size': composite.size, 'images': count } })
    composite.save(f, 'JPEG', exif = exif(data.info, None, 'grid'), optimize = True, quality = 70)
    return composite
|
61 |
+
|
62 |
+
|
63 |
+
def exif(info, i = None, op = 'generate'):
    """Build an EXIF blob whose ImageDescription field records the generation parameters.

    *i* selects a single seed from info.all_seeds (None keeps the whole list);
    *op* ('generate'/'upscale'/'grid') controls which extra fields are appended.
    Returns the raw bytes suitable for PIL's Image.save(exif=...).
    """
    if len(info.all_seeds) > 0 and i is not None:
        seeds = [info.all_seeds[i]]  # single seed requested
    else:
        seeds = info.all_seeds
    seed = ', '.join([str(x) for x in seeds])  # int list to single comma-separated str
    template = '{prompt} | negative {negative_prompt} | seed {s} | steps {steps} | cfgscale {cfg_scale} | sampler {sampler_name} | batch {batch_size} | timestamp {job_timestamp} | model {model} | vae {vae}'.format(s = seed, model = sd.options['sd_model_checkpoint'], vae = sd.options['sd_vae'], **info) # pylint: disable=consider-using-f-string
    if op == 'upscale':
        template += ' | faces gfpgan' if sd.upscale.gfpgan_visibility > 0 else ''
        template += ' | faces codeformer' if sd.upscale.codeformer_visibility > 0 else ''
        template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_1) if sd.upscale.upscaler_1 != 'None' else '' # pylint: disable=consider-using-f-string
        template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_2) if sd.upscale.upscaler_2 != 'None' else '' # pylint: disable=consider-using-f-string
    if op == 'grid':
        template += ' | grid {num}'.format(num = sd.generate.batch_size * sd.generate.n_iter) # pylint: disable=consider-using-f-string
    ifd = ImageFileDirectory_v2()
    stream = io.BytesIO()
    tag_ids = {v: k for k, v in TAGS.items()}  # reverse map: exif tag name -> numeric tag id
    ifd[tag_ids['ImageDescription']] = template
    ifd.save(stream)
    return b'Exif\x00\x00' + stream.getvalue()
|
81 |
+
|
82 |
+
|
83 |
+
def randomize(lst):
    """Return a cryptographically random element of *lst*, or '' when the list is empty."""
    return secrets.choice(lst) if len(lst) > 0 else ''
|
88 |
+
|
89 |
+
|
90 |
+
def prompt(params): # generate dynamic prompt or use one if provided
    """Resolve the generation prompt into sd.generate.prompt.

    Uses the explicit --prompt/--negative when given, otherwise draws from the
    random template; then expands the <embedding>, <artist>, <style>, <suffix>
    and <place> placeholders from either the CLI override or a random pick.
    Returns the final prompt string.
    """
    sd.generate.prompt = params.prompt if params.prompt != 'dynamic' else randomize(random.prompts)
    sd.generate.negative_prompt = params.negative if params.negative != 'dynamic' else randomize(random.negative)
    embedding = params.embedding if params.embedding != 'random' else randomize(random.embeddings)
    sd.generate.prompt = sd.generate.prompt.replace('<embedding>', embedding)
    artist = params.artist if params.artist != 'random' else randomize(random.artists)
    sd.generate.prompt = sd.generate.prompt.replace('<artist>', artist)
    style = params.style if params.style != 'random' else randomize(random.styles)
    sd.generate.prompt = sd.generate.prompt.replace('<style>', style)
    suffix = params.suffix if params.suffix != 'random' else randomize(random.suffixes)
    sd.generate.prompt = sd.generate.prompt.replace('<suffix>', suffix)
    place = params.place if params.place != 'random' else randomize(random.places)  # fixed: was reading params.suffix for the <place> placeholder
    sd.generate.prompt = sd.generate.prompt.replace('<place>', place)
    if params.prompts or params.debug:
        log.info({ 'random initializers': random })
    if params.prompt == 'dynamic':
        log.info({ 'dynamic prompt': sd.generate.prompt })
    return sd.generate.prompt
|
108 |
+
|
109 |
+
|
110 |
+
def sampler(params, options): # find sampler
    """Resolve the sampler name into sd.generate.sampler_name.

    'random' picks one of the server's registered samplers; otherwise the
    first prefix match of params.sampler is used. Exits when nothing matches.
    Returns the resolved sampler name.
    """
    if params.sampler == 'random':
        sd.generate.sampler_name = randomize(options.samplers)
        log.info({ 'random sampler': sd.generate.sampler_name })
    else:
        found = [i for i in options.samplers if i.startswith(params.sampler)]
        if len(found) == 0:
            log.error({ 'sampler error': params.sampler, 'available': options.samplers})  # fixed: log the requested sampler, not the stale sd.generate value
            exit()
        sd.generate.sampler_name = found[0]
    return sd.generate.sampler_name
|
121 |
+
|
122 |
+
|
123 |
+
async def generate(prompt = None, options = None, quiet = False): # pylint: disable=redefined-outer-name
    """Call the txt2img endpoint and save the resulting images as JPEGs.

    Optional overrides: *prompt* replaces sd.generate.prompt; *options*
    replaces the whole sd configuration. Returns a Map with 'name' (saved
    file paths), 'image' (decoded PIL images), 'b64' (original payloads) and
    'info' (server info block); returns an empty Map on server error.
    """
    global sd # pylint: disable=global-statement
    if options:
        sd = Map(options)
    if prompt is not None:
        sd.generate.prompt = prompt
    if not quiet:
        log.info({ 'generate': sd.generate })
    if sd.get('options', None) is None:
        sd['options'] = await get('/sdapi/v1/options')
    names = []
    b64s = []
    images = []
    info = Map({})
    data = await post('/sdapi/v1/txt2img', sd.generate)
    if 'error' in data:
        log.error({ 'generate': data['error'], 'reason': data['reason'] })
        return Map({})
    info = Map(json.loads(data['info']))
    log.debug({ 'info': info })
    images = data['images']
    short = info.prompt[:min(len(info.prompt), 96)] # limit prompt part of filename to 96 chars
    for i in range(len(images)):
        b64s.append(images[i])
        # fixed: take the payload after an optional 'data:...;base64,' prefix; index [0] kept the prefix
        images[i] = Image.open(io.BytesIO(base64.b64decode(images[i].split(',',1)[-1])))
        name = '{seed:0>9} {short}'.format(short = short, seed = info.all_seeds[i]) # pylint: disable=consider-using-f-string
        name = safestring(name) + '.jpg'
        f = os.path.join(sd.paths.root, sd.paths.generate, name)
        names.append(f)
        if not quiet:
            log.info({ 'image': { 'name': f, 'size': images[i].size } })
        images[i].save(f, 'JPEG', exif = exif(info, i), optimize = True, quality = 70)
    return Map({ 'name': names, 'image': images, 'b64': b64s, 'info': info })
|
156 |
+
|
157 |
+
|
158 |
+
async def upscale(data):
    """Run every generated image through the extras endpoint (upscale / face restore).

    No-op when upscaling_resize <= 1. Saves the results next to the originals
    under the upscale path, fills data.upscaled, and returns data.
    """
    data.upscaled = []
    if sd.upscale.upscaling_resize <= 1:
        return data  # upscaling disabled
    sd.upscale.image = ''
    log.info({ 'upscale': sd.upscale })
    for idx in range(len(data.image)):
        target = data.name[idx].replace(sd.paths.generate, sd.paths.upscale)
        sd.upscale.image = data.b64[idx]
        res = await post('/sdapi/v1/extra-single-image', sd.upscale)
        upscaled_img = Image.open(io.BytesIO(base64.b64decode(res['image'].split(',',1)[0])))
        data.upscaled.append(upscaled_img)
        log.info({ 'image': { 'name': target, 'size': upscaled_img.size } })
        upscaled_img.save(target, 'JPEG', exif = exif(data.info, idx, 'upscale'), optimize = True, quality = 70)
    return data
|
173 |
+
|
174 |
+
|
175 |
+
async def init():
    """Query server capabilities, validate the configured model, apply options and create output folders.

    Returns a Map holding server flags, model/sampler/upscaler/face-restorer
    lists, the applied options and the queue status.
    """
    options = Map({})
    options.flags = await get('/sdapi/v1/cmd-flags')
    log.debug({ 'flags': options.flags })
    data = await get('/sdapi/v1/sd-models')
    options.models = [obj['title'] for obj in data]
    log.debug({ 'registered models': options.models })
    # accept either an exact model title or the first prefix match
    found = sd.options.sd_model_checkpoint if sd.options.sd_model_checkpoint in options.models else None
    if found is None:
        found = [i for i in options.models if i.startswith(sd.options.sd_model_checkpoint)]
        if len(found) == 0:
            log.error({ 'model error': sd.options.sd_model_checkpoint, 'available': options.models})  # fixed: checkpoint is configured on sd.options, not sd.generate
            exit()
        sd.options.sd_model_checkpoint = found[0]
    data = await get('/sdapi/v1/samplers')
    options.samplers = [obj['name'] for obj in data]
    log.debug({ 'registered samplers': options.samplers })
    data = await get('/sdapi/v1/upscalers')
    options.upscalers = [obj['name'] for obj in data]
    log.debug({ 'registered upscalers': options.upscalers })
    data = await get('/sdapi/v1/face-restorers')
    options.restorers = [obj['name'] for obj in data]
    log.debug({ 'registered face restorers': options.restorers })
    await interrupt()  # stop any job already running before changing options
    await post('/sdapi/v1/options', sd.options)
    options.options = await get('/sdapi/v1/options')
    log.info({ 'target models': { 'diffuser': options.options['sd_model_checkpoint'], 'vae': options.options['sd_vae'] } })
    log.info({ 'paths': sd.paths })
    options.queue = await get('/queue/status')
    log.info({ 'queue': options.queue })
    pathlib.Path(sd.paths.root).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.generate)).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.upscale)).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.grid)).mkdir(parents = True, exist_ok = True)
    return options
|
217 |
+
|
218 |
+
|
219 |
+
def _load_json(path, home, label):
    """Load a JSON file from *path*, or the same name relative to the script folder *home*; exits the process on missing file or parse error."""
    candidate = path if os.path.isfile(path) else os.path.join(home, path)
    if not os.path.isfile(candidate):
        log.error({ f'{label} file not found': path })
        exit()
    try:
        with open(candidate, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        log.error({ f'{label} error': path, 'exception': e })
        exit()


def args(): # parse cmd arguments
    """Parse CLI arguments, load the config and random-prompt templates into
    the module globals, then overlay CLI overrides onto the config. Returns
    the parsed argparse namespace.
    """
    global sd # pylint: disable=global-statement
    global random # pylint: disable=global-statement
    parser = argparse.ArgumentParser(description = 'sd pipeline')
    parser.add_argument('--config', type = str, default = 'generate.json', required = False, help = 'configuration file')
    parser.add_argument('--random', type = str, default = 'random.json', required = False, help = 'prompt file with randomized sections')
    parser.add_argument('--max', type = int, default = 1, required = False, help = 'maximum number of generated images')
    parser.add_argument('--prompt', type = str, default = 'dynamic', required = False, help = 'prompt')
    parser.add_argument('--negative', type = str, default = 'dynamic', required = False, help = 'negative prompt')
    parser.add_argument('--artist', type = str, default = 'random', required = False, help = 'artist style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--embedding', type = str, default = 'random', required = False, help = 'use embedding, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--style', type = str, default = 'random', required = False, help = 'image style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--suffix', type = str, default = 'random', required = False, help = 'style suffix, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--place', type = str, default = 'random', required = False, help = 'place locator, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--faces', default = False, action='store_true', help = 'restore faces during upscaling')
    parser.add_argument('--steps', type = int, default = 0, required = False, help = 'number of steps')
    parser.add_argument('--batch', type = int, default = 0, required = False, help = 'batch size, limited by gpu vram')
    parser.add_argument('--n', type = int, default = 0, required = False, help = 'number of iterations')
    parser.add_argument('--cfg', type = int, default = 0, required = False, help = 'classifier free guidance scale')
    parser.add_argument('--sampler', type = str, default = 'random', required = False, help = 'sampler')
    parser.add_argument('--seed', type = int, default = 0, required = False, help = 'seed, default is random')
    parser.add_argument('--upscale', type = int, default = 0, required = False, help = 'upscale factor, disabled if 0')
    parser.add_argument('--model', type = str, default = '', required = False, help = 'diffusion model')
    parser.add_argument('--vae', type = str, default = '', required = False, help = 'vae model')
    parser.add_argument('--path', type = str, default = '', required = False, help = 'output path')
    parser.add_argument('--width', type = int, default = 0, required = False, help = 'width')
    parser.add_argument('--height', type = int, default = 0, required = False, help = 'height')
    parser.add_argument('--beautify', default = False, action='store_true', help = 'beautify prompt')
    parser.add_argument('--prompts', default = False, action='store_true', help = 'print dynamic prompt templates')
    parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
    params = parser.parse_args()
    if params.debug:
        log.setLevel(logging.DEBUG)
        log.debug({ 'debug': True })
    log.debug({ 'args': params.__dict__ })
    home = pathlib.Path(sys.argv[0]).parent
    # refactored: both the config and the random template used two duplicated
    # load-or-exit stanzas; consolidated into _load_json
    sd = Map(_load_json(params.config, home, 'config'))
    log.debug({ 'config': sd })
    if params.prompt == 'dynamic':
        log.info({ 'prompt template': params.random })
        random = Map(_load_json(params.random, home, 'random template'))
        log.debug({ 'random template': random })  # fixed: was logging sd instead of the loaded template
        _dynamic = prompt(params)

    sd.paths.root = params.path if params.path != '' else sd.paths.root
    sd.generate.restore_faces = params.faces if params.faces is not None else sd.generate.restore_faces
    sd.generate.seed = params.seed if params.seed > 0 else sd.generate.seed
    sd.generate.sampler_name = params.sampler if params.sampler != 'random' else sd.generate.sampler_name
    sd.generate.batch_size = params.batch if params.batch > 0 else sd.generate.batch_size
    sd.generate.cfg_scale = params.cfg if params.cfg > 0 else sd.generate.cfg_scale
    sd.generate.n_iter = params.n if params.n > 0 else sd.generate.n_iter
    sd.generate.width = params.width if params.width > 0 else sd.generate.width
    sd.generate.height = params.height if params.height > 0 else sd.generate.height
    sd.generate.steps = params.steps if params.steps > 0 else sd.generate.steps
    sd.upscale.upscaling_resize = params.upscale if params.upscale > 0 else sd.upscale.upscaling_resize
    sd.upscale.codeformer_visibility = 1 if params.faces else sd.upscale.codeformer_visibility
    sd.options.sd_vae = params.vae if params.vae != '' else sd.options.sd_vae
    sd.options.sd_model_checkpoint = params.model if params.model != '' else sd.options.sd_model_checkpoint
    sd.upscale.upscaler_1 = 'SwinIR_4x' if params.upscale > 1 else sd.upscale.upscaler_1
    if sd.generate.cfg_scale == 0:
        sd.generate.cfg_scale = randrange(5, 10)
    return params
|
319 |
+
|
320 |
+
|
321 |
+
async def main():
    """Driver loop: generate batches, build grids, upscale results, track timing stats until --max images are produced."""
    params = args()
    sess = await session()
    if sess is None:
        await close()
        exit()
    options = await init()
    iteration = 0
    while True:
        iteration += 1
        log.info('')
        log.info({ 'iteration': iteration, 'batch': sd.generate.batch_size, 'n': sd.generate.n_iter, 'total': sd.generate.n_iter * sd.generate.batch_size })
        dynamic = prompt(params)
        if params.beautify:
            try:
                promptist = importlib.import_module('modules.promptist')
                sd.generate.prompt = promptist.beautify(dynamic)
            except Exception as e:
                log.error({ 'beautify': e })
        scheduler = sampler(params, options)
        t_start = time.perf_counter()
        data = await generate() # generate returns list of images
        if 'image' not in data:
            break
        batch = len(data.image)
        stats.images += batch
        t_generated = time.perf_counter()
        if batch > 0:
            avg[scheduler] = (t_generated - t_start) / batch  # per-image time, keyed by sampler
        stats.generate += t_generated - t_start
        _image = grid(data)
        data = await upscale(data)
        t_done = time.perf_counter()
        stats.upscale += t_done - t_generated
        stats.wall += t_done - t_start
        its = sd.generate.steps / ((t_generated - t_start) / batch) if batch > 0 else 0
        avg_time = round((t_generated - t_start) / batch) if batch > 0 else 0
        log.info({ 'time' : { 'wall': round(t_generated - t_start), 'average': avg_time, 'upscale': round(t_done - t_generated), 'its': round(its, 2) } })
        log.info({ 'generated': stats.images, 'max': params.max, 'progress': round(100 * stats.images / params.max, 1) })
        if params.max != 0 and stats.images >= params.max:
            break
|
361 |
+
|
362 |
+
|
363 |
+
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        asyncio.run(interrupt())
        log.info({ 'interrupt': True })
    finally:
        log.info({ 'sampler performance': avg })
        log.info({ 'stats' : stats })
        asyncio.run(close())  # fixed: close once here; the except branch previously closed too, closing the session twice on interrupt
|
cli/hf-convert.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import logging
|
6 |
+
import torch
|
7 |
+
import diffusers
|
8 |
+
import safetensors
|
9 |
+
import safetensors.torch as sf
|
10 |
+
|
11 |
+
log = logging.getLogger("sd")
|
12 |
+
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s | %(message)s')
|
13 |
+
|
14 |
+
|
15 |
+
def convert(model_id, output_name):
    """Download a diffusers pipeline and repack all of its components into a single safetensors file; skips when the output already exists."""
    if os.path.exists(output_name):
        log.error(f'Output already exists: {output_name}')
        return
    pipe = diffusers.DiffusionPipeline.from_pretrained(model_id)
    metadata = { 'model_id': model_id }
    # the pipeline's internal dict lists its registered components
    model = { 'state_dict': vars(pipe)['_internal_dict'] }
    for component in model['state_dict'].keys():
        # print(component, getattr(pipe, component))
        model[component] = getattr(pipe, component)
    sf.save_model(model, output_name, metadata=metadata)
    # log.info(f'Saved model: {output_name}')
|
28 |
+
|
29 |
+
if __name__ == "__main__":
    sys.argv.pop(0)  # drop the script name so positional args start at index 0
    if len(sys.argv) < 2:
        log.info('Usage: hf-convert.py <model_id> <output_name>')
        sys.exit(1)
    log.debug(f'Packages: torch={torch.__version__} diffusers={diffusers.__version__} safetensors={safetensors.__version__}')
    model_arg, output_arg = sys.argv[0], sys.argv[1]
    convert(model_arg, output_arg)
|
cli/hf-search.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import huggingface_hub as hf
|
5 |
+
from rich import print # pylint: disable=redefined-builtin
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
sys.argv.pop(0)
|
9 |
+
keyword = sys.argv[0] if len(sys.argv) > 0 else ''
|
10 |
+
hf_api = hf.HfApi()
|
11 |
+
model_filter = hf.ModelFilter(
|
12 |
+
model_name=keyword,
|
13 |
+
# task='text-to-image',
|
14 |
+
library=['diffusers'],
|
15 |
+
)
|
16 |
+
res = hf_api.list_models(filter=model_filter, full=True, limit=50, sort="downloads", direction=-1)
|
17 |
+
models = [{ 'name': m.modelId, 'downloads': m.downloads, 'mtime': m.lastModified, 'url': f'https://huggingface.co/{m.modelId}', 'pipeline': m.pipeline_tag, 'tags': m.tags } for m in res]
|
18 |
+
print(models)
|
cli/idle.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import datetime
|
6 |
+
import logging
|
7 |
+
import urllib3
|
8 |
+
import requests
|
9 |
+
|
10 |
+
class Dot(dict):
    """Dictionary subclass that exposes keys as attributes.

    Missing attributes resolve to None (dict.get semantics) rather than
    raising AttributeError.
    """
    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
|
14 |
+
|
15 |
+
opts = Dot({
|
16 |
+
"timeout": 3600,
|
17 |
+
"frequency": 60,
|
18 |
+
"action": "sudo shutdown now",
|
19 |
+
"url": "https://127.0.0.1:7860",
|
20 |
+
"user": "",
|
21 |
+
"password": "",
|
22 |
+
})
|
23 |
+
|
24 |
+
log_format = '%(asctime)s %(levelname)s: %(message)s'
|
25 |
+
logging.basicConfig(level = logging.INFO, format = log_format)
|
26 |
+
log = logging.getLogger("sd")
|
27 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
28 |
+
status = None
|
29 |
+
|
30 |
+
def progress():
    """Poll the server's progress endpoint and return its parsed response.

    Returns a Dot wrapping the JSON body on success; on a non-200 response the
    previous module-level `status` is returned unchanged so the monitor loop
    keeps running through transient server errors.
    """
    # only attach basic-auth when both user and password are configured
    auth = requests.auth.HTTPBasicAuth(opts.user, opts.password) if opts.user is not None and len(opts.user) > 0 and opts.password is not None and len(opts.password) > 0 else None
    # verify=False: the default endpoint is local https with a self-signed cert
    req = requests.get(f'{opts.url}/sdapi/v1/progress?skip_current_image=true', verify=False, auth=auth, timeout=60)
    if req.status_code != 200:
        log.error({ 'url': req.url, 'request': req.status_code, 'reason': req.reason })
        return status
    else:
        res = Dot(req.json())
        log.debug({ 'url': req.url, 'request': req.status_code, 'result': res })
        return res
|
40 |
+
|
41 |
+
log.info(f'sdnext monitor started: {opts}')
|
42 |
+
while True:
|
43 |
+
try:
|
44 |
+
status = progress()
|
45 |
+
state = status.get('state', {})
|
46 |
+
last_job = state.get('job_timestamp', None)
|
47 |
+
if last_job is None:
|
48 |
+
log.warning(f'sdnext montoring cannot get last job info: {status}')
|
49 |
+
else:
|
50 |
+
last_job = datetime.datetime.strptime(last_job, "%Y%m%d%H%M%S")
|
51 |
+
elapsed = datetime.datetime.now() - last_job
|
52 |
+
timeout = round(opts.timeout - elapsed.total_seconds())
|
53 |
+
log.info(f'sdnext: last_job={last_job} elapsed={elapsed} timeout={timeout}')
|
54 |
+
if timeout < 0:
|
55 |
+
log.warning(f'sdnext reached: timeout={opts.timeout} action={opts.action}')
|
56 |
+
os.system(opts.action)
|
57 |
+
except Exception as e:
|
58 |
+
log.error(f'sdnext monitor error: {e}')
|
59 |
+
finally:
|
60 |
+
time.sleep(opts.frequency)
|
cli/image-exif.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/env python
|
2 |
+
|
3 |
+
import os
|
4 |
+
import io
|
5 |
+
import re
|
6 |
+
import sys
|
7 |
+
import json
|
8 |
+
from PIL import Image, ExifTags, TiffImagePlugin, PngImagePlugin
|
9 |
+
from rich import print # pylint: disable=redefined-builtin
|
10 |
+
|
11 |
+
|
12 |
+
def unquote(text):
    """Strip JSON-style double quotes from *text*.

    Returns the decoded string when *text* is a well-formed quoted JSON
    string; otherwise returns *text* unchanged (including on parse failure).
    """
    looks_quoted = len(text) > 0 and text[0] == '"' and text[-1] == '"'
    if not looks_quoted:
        return text
    try:
        return json.loads(text)
    except Exception:
        return text
|
19 |
+
|
20 |
+
|
21 |
+
def parse_generation_parameters(infotext): # copied from modules.generation_parameters_copypaste
    """Parse an SD 'infotext' string into a dict of generation parameters.

    Returns {} for non-string input; otherwise a dict containing 'Prompt',
    'Negative prompt' and one entry per 'Key: value' pair, with numeric,
    boolean and 'WxH' size values coerced into typed entries.
    """
    if not isinstance(infotext, str):
        return {}

    re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)') # multi-word: value
    re_size = re.compile(r"^(\d+)x(\d+)$") # int x int
    # normalize prompt markers, then blank out <...>, (...) and {...} spans so
    # re_param cannot match key:value pairs inside attention/LoRA brackets;
    # replacements are same-length whitespace to keep string offsets stable
    sanitized = infotext.replace('prompt:', 'Prompt:').replace('negative prompt:', 'Negative prompt:').replace('Negative Prompt', 'Negative prompt') # cleanup everything in brackets so re_params can work
    sanitized = re.sub(r'<[^>]*>', lambda match: ' ' * len(match.group()), sanitized)
    sanitized = re.sub(r'\([^)]*\)', lambda match: ' ' * len(match.group()), sanitized)
    sanitized = re.sub(r'\{[^}]*\}', lambda match: ' ' * len(match.group()), sanitized)

    params = dict(re_param.findall(sanitized))
    params = { k.strip():params[k].strip() for k in params if k.lower() not in ['hashes', 'lora', 'embeddings', 'prompt', 'negative prompt']} # remove some keys
    first_param = next(iter(params)) if params else None
    # everything before the first recognized 'Key:' is treated as prompt text
    params_idx = sanitized.find(f'{first_param}:') if first_param else -1
    negative_idx = infotext.find("Negative prompt:")

    prompt = infotext[:params_idx] if negative_idx == -1 else infotext[:negative_idx] # prompt can be with or without negative prompt
    negative = infotext[negative_idx:params_idx] if negative_idx >= 0 else ''

    for k, v in params.copy().items(): # avoid dict-has-changed
        if len(v) > 0 and v[0] == '"' and v[-1] == '"':
            v = unquote(v)
        m = re_size.match(v)
        if v.replace('.', '', 1).isdigit():
            # plain number: int when no decimal point, float otherwise
            params[k] = float(v) if '.' in v else int(v)
        elif v == "True":
            params[k] = True
        elif v == "False":
            params[k] = False
        elif m is not None:
            # 'WxH' size values are split into two int entries: Key-1 / Key-2
            params[f"{k}-1"] = int(m.group(1))
            params[f"{k}-2"] = int(m.group(2))
        elif k == 'VAE' and v == 'TAESD':
            params["Full quality"] = False
        else:
            params[k] = v
    params["Prompt"] = prompt.replace('Prompt:', '').strip()
    params["Negative prompt"] = negative.replace('Negative prompt:', '').strip()
    return params
|
61 |
+
|
62 |
+
|
63 |
+
class Exif: # pylint: disable=single-string-used-for-slots
    """Wrapper around PIL's exif handling.

    Collects tags from an image into a PIL Image.Exif mapping plus a parallel
    PngInfo block, decoding bytestring values, and can parse SD generation
    parameters out of the 'parameters'/'UserComment' tags.
    """
    __slots__ = ('__dict__') # pylint: disable=superfluous-parens
    def __init__(self, image = None):
        # bypass our own __setattr__-free slots machinery for the base mapping
        super(Exif, self).__setattr__('exif', Image.Exif()) # pylint: disable=super-with-arguments
        self.pnginfo = PngImagePlugin.PngInfo()
        # tag-id -> name and name -> tag-id lookup tables (exif + GPS tags)
        self.tags = {**dict(ExifTags.TAGS.items()), **dict(ExifTags.GPSTAGS.items())}
        self.ids = {**{v: k for k, v in ExifTags.TAGS.items()}, **{v: k for k, v in ExifTags.GPSTAGS.items()}}
        if image is not None:
            self.load(image)

    def __getattr__(self, attr):
        # instance attributes first, then fall through to exif values (None if absent)
        if attr in self.__dict__:
            return self.__dict__[attr]
        return self.exif.get(attr, None)

    def load(self, img: Image):
        """Populate self.exif/self.pnginfo from an opened PIL image."""
        img.load() # exif may not be ready
        exif_dict = {}
        try:
            exif_dict = dict(img._getexif().items()) # pylint: disable=protected-access
        except Exception:
            # e.g. PNG: fall back to the text-chunk info dict
            exif_dict = dict(img.info.items())
        for key, val in exif_dict.items():
            if isinstance(val, bytes): # decode bytestring
                val = self.decode(val)
            if val is not None:
                if isinstance(key, str):
                    self.exif[key] = val
                    self.pnginfo.add_text(key, str(val), zip=False)
                elif isinstance(key, int) and key in ExifTags.TAGS: # add known tags
                    if self.tags[key] in ['ExifOffset']:
                        continue  # structural pointer tag, not user data
                    self.exif[self.tags[key]] = val
                    self.pnginfo.add_text(self.tags[key], str(val), zip=False)
                # if self.tags[key] == 'UserComment': # add geninfo from UserComment
                #     self.geninfo = val
                else:
                    print('metadata unknown tag:', key, val)
        # second pass: decode any bytes values that arrived via known-tag mapping
        for key, val in self.exif.items():
            if isinstance(val, bytes): # decode bytestring
                self.exif[key] = self.decode(val)

    def decode(self, s: bytes):
        """Best-effort decode of an exif bytestring; None when undecodable/empty."""
        remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment
        for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings
            try:
                # strip exif charset markers before decoding
                s = remove_prefix(s, b'UNICODE')
                s = remove_prefix(s, b'ASCII')
                s = remove_prefix(s, b'\x00')
                val = s.decode(encoding, errors="strict")
                val = re.sub(r'[\x00-\x09]', '', val).strip() # remove remaining special characters
                if len(val) == 0: # remove empty strings
                    val = None
                return val
            except Exception:
                pass
        return None

    def parse(self):
        """Parse SD generation parameters out of the stored metadata.

        NOTE: pops 'parameters'/'UserComment' out of self.exif as a side effect.
        """
        x = self.exif.pop('parameters', None) or self.exif.pop('UserComment', None)
        res = parse_generation_parameters(x)
        return res

    def get_bytes(self):
        """Serialize the collected tags back into a raw exif (APP1) byte blob."""
        ifd = TiffImagePlugin.ImageFileDirectory_v2()
        exif_stream = io.BytesIO()
        for key, val in self.exif.items():
            if key in self.ids:
                ifd[self.ids[key]] = val
            else:
                print('metadata unknown exif tag:', key, val)
        ifd.save(exif_stream)
        raw = b'Exif\x00\x00' + exif_stream.getvalue()
        return raw
|
137 |
+
|
138 |
+
|
139 |
+
def read_exif(filename: str):
    """Open one image file and print its exif data plus parsed generation info."""
    if filename.lower().endswith('.heic'):
        # HEIC support is optional; register its opener lazily
        from pi_heif import register_heif_opener
        register_heif_opener()
    try:
        img = Image.open(filename)
        metadata = Exif(img)
        print('image:', filename, 'format:', img)
        print('exif:', vars(metadata.exif)['_data'])
        print('info:', metadata.parse())
    except Exception as e:
        print('metadata error reading:', filename, e)
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == '__main__':
|
154 |
+
sys.argv.pop(0)
|
155 |
+
if len(sys.argv) == 0:
|
156 |
+
print('metadata:', 'no files specified')
|
157 |
+
for fn in sys.argv:
|
158 |
+
if os.path.isfile(fn):
|
159 |
+
read_exif(fn)
|
160 |
+
elif os.path.isdir(fn):
|
161 |
+
for root, _dirs, files in os.walk(fn):
|
162 |
+
for file in files:
|
163 |
+
read_exif(os.path.join(root, file))
|
cli/image-grid.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Create image grid
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import argparse
|
8 |
+
import math
|
9 |
+
import logging
|
10 |
+
from pathlib import Path
|
11 |
+
import filetype
|
12 |
+
from PIL import Image, ImageDraw, ImageFont
|
13 |
+
from util import log
|
14 |
+
|
15 |
+
|
16 |
+
params = None
|
17 |
+
|
18 |
+
|
19 |
+
def wrap(text: str, font: ImageFont.ImageFont, length: int):
    """Greedy word-wrap: keep packing words onto the current line while its
    rendered width (per *font*) still fits within *length* pixels."""
    wrapped = ['']
    for token in text.split():
        candidate = f'{wrapped[-1]} {token}'.strip()
        if font.getlength(candidate) > length:
            wrapped.append(token)  # start a new line with the overflowing word
        else:
            wrapped[-1] = candidate
    return '\n'.join(wrapped)
|
28 |
+
|
29 |
+
|
30 |
+
def grid(images, labels = None, width = 0, height = 0, border = 0, square = False, horizontal = False, vertical = False): # pylint: disable=redefined-outer-name
    """Compose a list of PIL images into a single labeled grid image.

    images:        list of PIL.Image to place
    labels:        optional per-image captions (drawn only when lengths match)
    width/height:  fixed output size; 0 derives the size from the largest input
    border:        pixel gap added around each cell
    square/horizontal/vertical: layout hints (first truthy one wins)
    Returns the composed RGB image.
    """
    # pick the row count from the requested layout
    if horizontal:
        rows = 1
    elif vertical:
        rows = len(images)
    elif square:
        rows = round(math.sqrt(len(images)))
    else:
        rows = math.floor(math.sqrt(len(images)))
    cols = math.ceil(len(images) / rows)
    size = [0, 0]
    if width == 0:
        # auto width: the widest input defines the cell width
        w = max([i.size[0] for i in images])
        size[0] = cols * w + cols * border
    else:
        size[0] = width
        w = round(width / cols)
    if height == 0:
        h = max([i.size[1] for i in images])
        size[1] = rows * h + rows * border
    else:
        size[1] = height
        h = round(height / rows)
    size = tuple(size)
    image = Image.new('RGB', size = size, color = 'black') # pylint: disable=redefined-outer-name
    # label font scales with cell width; assumes DejaVuSansMono is installed -- TODO confirm
    font = ImageFont.truetype('DejaVuSansMono', round(w / 40))
    for i, img in enumerate(images): # pylint: disable=redefined-outer-name
        # cell origin for image i (row-major placement)
        x = (i % cols * w) + (i % cols * border)
        y = (i // cols * h) + (i // cols * border)
        img.thumbnail((w, h), Image.Resampling.HAMMING)  # in-place downscale to cell size
        image.paste(img, box=(x + int(border / 2), y + int(border / 2)))
        if labels is not None and len(images) == len(labels):
            ctx = ImageDraw.Draw(image)
            label = wrap(labels[i], font, w)
            # draw a dark offset shadow first, then white text, for readability
            ctx.text((x + 1 + round(w / 200), y + 1 + round(w / 200)), label, font = font, fill = (0, 0, 0))
            ctx.text((x, y), label, font = font, fill = (255, 255, 255))
    log.info({ 'grid': { 'images': len(images), 'rows': rows, 'cols': cols, 'cell': [w, h] } })
    return image
|
68 |
+
|
69 |
+
|
70 |
+
if __name__ == '__main__':
|
71 |
+
log.info({ 'create grid' })
|
72 |
+
parser = argparse.ArgumentParser(description='image grid utility')
|
73 |
+
parser.add_argument("--square", default = False, action='store_true', help = "create square grid")
|
74 |
+
parser.add_argument("--horizontal", default = False, action='store_true', help = "create horizontal grid")
|
75 |
+
parser.add_argument("--vertical", default = False, action='store_true', help = "create vertical grid")
|
76 |
+
parser.add_argument("--width", type = int, default = 0, required = False, help = "fixed grid width")
|
77 |
+
parser.add_argument("--height", type = int, default = 0, required = False, help = "fixed grid height")
|
78 |
+
parser.add_argument("--border", type = int, default = 0, required = False, help = "image border")
|
79 |
+
parser.add_argument('--nolabels', default = False, action='store_true', help = "do not print image labels")
|
80 |
+
parser.add_argument('--debug', default = False, action='store_true', help = "print extra debug information")
|
81 |
+
parser.add_argument('output', type = str)
|
82 |
+
parser.add_argument('input', type = str, nargs = '*')
|
83 |
+
params = parser.parse_args()
|
84 |
+
output = params.output if params.output.lower().endswith('.jpg') else params.output + '.jpg'
|
85 |
+
if params.debug:
|
86 |
+
log.setLevel(logging.DEBUG)
|
87 |
+
log.debug({ 'debug': True })
|
88 |
+
log.debug({ 'args': params.__dict__ })
|
89 |
+
images = []
|
90 |
+
labels = []
|
91 |
+
for f in params.input:
|
92 |
+
path = Path(f)
|
93 |
+
if path.is_dir():
|
94 |
+
files = [os.path.join(f, file) for file in os.listdir(f) if os.path.isfile(os.path.join(f, file))]
|
95 |
+
elif path.is_file():
|
96 |
+
files = [f]
|
97 |
+
else:
|
98 |
+
log.warning({ 'grid not a valid file/folder', f})
|
99 |
+
continue
|
100 |
+
files.sort()
|
101 |
+
for file in files:
|
102 |
+
if not filetype.is_image(file):
|
103 |
+
continue
|
104 |
+
if file.lower().endswith('.heic'):
|
105 |
+
from pi_heif import register_heif_opener
|
106 |
+
register_heif_opener()
|
107 |
+
log.debug(file)
|
108 |
+
img = Image.open(file)
|
109 |
+
# img.verify()
|
110 |
+
images.append(img)
|
111 |
+
fp = Path(file)
|
112 |
+
if not params.nolabels:
|
113 |
+
labels.append(fp.stem)
|
114 |
+
# log.info({ 'folder': path.parent, 'labels': labels })
|
115 |
+
if len(images) > 0:
|
116 |
+
image = grid(
|
117 |
+
images = images,
|
118 |
+
labels = labels,
|
119 |
+
width = params.width,
|
120 |
+
height = params.height,
|
121 |
+
border = params.border,
|
122 |
+
square = params.square,
|
123 |
+
horizontal = params.horizontal,
|
124 |
+
vertical = params.vertical)
|
125 |
+
image.save(output, 'JPEG', optimize = True, quality = 60)
|
126 |
+
log.info({ 'grid': { 'file': output, 'size': list(image.size) } })
|
127 |
+
else:
|
128 |
+
log.info({ 'grid': 'nothing to do' })
|
cli/image-interrogate.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
use clip to interrogate image(s)
|
4 |
+
"""
|
5 |
+
|
6 |
+
import io
|
7 |
+
import base64
|
8 |
+
import sys
|
9 |
+
import os
|
10 |
+
import asyncio
|
11 |
+
import filetype
|
12 |
+
from PIL import Image
|
13 |
+
from util import log, Map
|
14 |
+
import sdapi
|
15 |
+
|
16 |
+
|
17 |
+
stats = { 'captions': {}, 'keywords': {} }
|
18 |
+
exclude = ['a', 'in', 'on', 'out', 'at', 'the', 'and', 'with', 'next', 'to', 'it', 'for', 'of', 'into', 'that']
|
19 |
+
|
20 |
+
|
21 |
+
def decode(encoding):
    """Turn a base64 string (optionally a data-URL) into a PIL image."""
    if encoding.startswith("data:image/"):
        # strip the 'data:image/...;base64,' prefix
        encoding = encoding.split(";")[1].split(",")[1]
    raw = base64.b64decode(encoding)
    return Image.open(io.BytesIO(raw))
|
25 |
+
|
26 |
+
|
27 |
+
def encode(f):
    """Load an image file and return it as a base64-encoded JPEG string,
    carrying over its exif data."""
    img = Image.open(f)
    exif = img.getexif()
    if img.mode == 'RGBA':
        img = img.convert('RGB')  # JPEG cannot store an alpha channel
    with io.BytesIO() as buffer:
        img.save(buffer, 'JPEG', exif = exif)
        return base64.b64encode(buffer.getvalue()).decode()
|
37 |
+
|
38 |
+
|
39 |
+
def print_summary():
    """Log the accumulated caption and keyword frequencies, most frequent first."""
    by_count = lambda item: item[1]  # pylint: disable=unnecessary-lambda-assignment
    sorted_captions = dict(sorted(stats['captions'].items(), key=by_count, reverse=True))
    log.info({ 'caption stats': sorted_captions })
    sorted_keywords = dict(sorted(stats['keywords'].items(), key=by_count, reverse=True))
    log.info({ 'keyword stats': sorted_keywords })
|
44 |
+
|
45 |
+
|
46 |
+
async def interrogate(f):
    """Interrogate one image file through the server API.

    Runs both the 'clip' and 'deepdanbooru' models, updates the module-level
    `stats` counters, and returns (caption, keywords, style).

    NOTE(review): returns None (bare return) for non-image files while the
    normal path returns a 3-tuple -- callers that unpack the result will fail
    on non-images; verify against main().
    """
    if not filetype.is_image(f):
        log.info({ 'interrogate skip': f })
        return
    json = Map({ 'image': encode(f) })
    log.info({ 'interrogate': f })
    # run clip
    json.model = 'clip'
    res = await sdapi.post('/sdapi/v1/interrogate', json)
    caption = ""
    style = ""
    if 'caption' in res:
        caption = res.caption
        log.info({ 'interrogate caption': caption })
        if ', by' in caption:
            # clip captions end with ', by <artist>' -- treat that as style
            style = caption.split(', by')[1].strip()
            log.info({ 'interrogate style': style })
        for word in caption.split(' '):
            if word not in exclude:  # skip stop-words
                stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1
    else:
        log.error({ 'interrogate clip error': res })
    # run booru
    json.model = 'deepdanbooru'
    res = await sdapi.post('/sdapi/v1/interrogate', json)
    keywords = {}
    if 'caption' in res:
        for term in res.caption.split(', '):
            # terms look like '(keyword:score)'; split into name and score
            term = term.replace('(', '').replace(')', '').replace('\\', '').split(':')
            if len(term) < 2:
                continue
            keywords[term[0]] = term[1]
        keywords = dict(sorted(keywords.items(), key=lambda x:x[1], reverse=True))
        for word in keywords.items():
            stats['keywords'][word[0]] = stats['keywords'][word[0]] + 1 if word[0] in stats['keywords'] else 1
        log.info({ 'interrogate keywords': keywords })
    else:
        log.error({ 'interrogate booru error': res })
    return caption, keywords, style
|
85 |
+
|
86 |
+
|
87 |
+
async def main():
    """Entry point: interrogate every file/folder given on the command line."""
    sys.argv.pop(0)  # drop the script name
    await sdapi.session()
    if len(sys.argv) == 0:
        log.error({ 'interrogate': 'no files specified' })
    for arg in sys.argv:
        if os.path.exists(arg):
            if os.path.isfile(arg):
                await interrogate(arg)
            elif os.path.isdir(arg):
                for root, _dirs, files in os.walk(arg):
                    for f in files:
                        # NOTE(review): interrogate returns None for non-image
                        # files, which would make this unpacking raise -- verify
                        _caption, _keywords, _style = await interrogate(os.path.join(root, f))
            else:
                log.error({ 'interrogate unknown file type': arg })
        else:
            log.error({ 'interrogate file missing': arg })
    await sdapi.close()
    print_summary()
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
asyncio.run(main())
|
cli/image-palette.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# based on <https://towardsdatascience.com/image-color-extraction-with-python-in-4-steps-8d9370d9216e>
|
3 |
+
|
4 |
+
import os
|
5 |
+
import io
|
6 |
+
import pathlib
|
7 |
+
import argparse
|
8 |
+
import importlib
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
import extcolors
|
12 |
+
import filetype
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import matplotlib.patches as patches
|
15 |
+
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
16 |
+
from colormap import rgb2hex
|
17 |
+
from PIL import Image
|
18 |
+
from util import log
|
19 |
+
grid = importlib.import_module('image-grid').grid
|
20 |
+
|
21 |
+
|
22 |
+
def color_to_df(param):
    """Convert extcolors' extraction result into a DataFrame with hex color
    codes ('c_code') and occurrence counts ('occurence').

    NOTE(review): this parses str(param) -- the repr of the extcolors result --
    rather than the structure itself, so it is coupled to that repr format.
    """
    colors_pre_list = str(param).replace('([(','').split(', (')[0:-1]
    df_rgb = [i.split('), ')[0] + ')' for i in colors_pre_list]
    df_percent = [i.split('), ')[1].replace(')','') for i in colors_pre_list]
    #convert RGB to HEX code
    df_color_up = [rgb2hex(int(i.split(", ")[0].replace("(","")),
                    int(i.split(", ")[1]),
                    int(i.split(", ")[2].replace(")",""))) for i in df_rgb]
    df = pd.DataFrame(zip(df_color_up, df_percent), columns = ['c_code','occurence'])
    return df
|
32 |
+
|
33 |
+
|
34 |
+
def palette(img, params, output):
    """Render a color-palette infographic (donut chart plus swatch columns)
    for an image and save it to *output* when given.

    img:    PIL image; downscaled in place to at most 1024px
    params: namespace providing `color` (extraction tolerance)
    output: destination path, or None to skip saving
    """
    size = 1024
    img.thumbnail((size, size), Image.HAMMING)  # in-place downscale before extraction

    #crate dataframe
    colors_x = extcolors.extract_from_image(img, tolerance = params.color, limit = 13)
    df_color = color_to_df(colors_x)

    #annotate text
    list_color = list(df_color['c_code'])
    list_precent = [int(i) for i in list(df_color['occurence'])]
    text_c = [c + ' ' + str(round(p * 100 / sum(list_precent), 1)) +'%' for c, p in zip(list_color, list_precent)]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(120,60), dpi=10)
    fig.set_facecolor('black')

    #donut plot
    wedges, _text = ax1.pie(list_precent, labels= text_c, labeldistance= 1.05, colors = list_color, textprops={'fontsize': 100, 'color':'white'})
    plt.setp(wedges, width=0.3)

    #add image in the center of donut plot
    data = np.asarray(img)
    imagebox = OffsetImage(data, zoom=2.5)
    ab = AnnotationBbox(imagebox, (0, 0))
    ax1.add_artist(ab)

    #color palette
    # first six swatches go in the left column, the rest in the right column
    # NOTE(review): list.index(c) returns the first match, so duplicate hex
    # codes would reuse a slot -- assumes extracted colors are unique
    x_posi, y_posi, y_posi2 = 160, -260, -260
    for c in list_color:
        if list_color.index(c) <= 5:
            y_posi += 240
            rect = patches.Rectangle((x_posi, y_posi), 540, 230, facecolor = c)
            ax2.add_patch(rect)
            ax2.text(x = x_posi + 100, y = y_posi + 140, s = c, fontdict={'fontsize': 140}, color = 'white')
        else:
            y_posi2 += 240
            rect = patches.Rectangle((x_posi + 600, y_posi2), 540, 230, facecolor = c)
            ax2.add_artist(rect)
            ax2.text(x = x_posi + 700, y = y_posi2 + 140, s = c, fontdict={'fontsize': 140}, color = 'white')

    # add background to force layout
    fig.set_facecolor('black')
    ax2.axis('off')
    tmp = Image.new('RGB', (2000, 1400), (0, 0, 0))
    plt.imshow(tmp)
    plt.tight_layout(rect = (-0.08, -0.2, 1.18, 1.05))

    # save image
    if output is not None:
        # round-trip through a PNG buffer, then save as RGB at the target path
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        pltimg = Image.open(buf)
        pltimg = pltimg.convert('RGB')
        pltimg.save(output)
        buf.close()
        log.info({ 'palette created': output })

    plt.close()
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == '__main__':
|
94 |
+
parser = argparse.ArgumentParser(description = 'extract image color palette')
|
95 |
+
parser.add_argument('--color', type=int, default=20, help="color tolerance threshdold")
|
96 |
+
parser.add_argument('--output', type=str, required=False, default='', help='folder to store images')
|
97 |
+
parser.add_argument('--suffix', type=str, required=False, default='pallete', help='add suffix to image name')
|
98 |
+
parser.add_argument('--grid', default=False, action='store_true', help = "create grid of images before processing")
|
99 |
+
parser.add_argument('input', type=str, nargs='*')
|
100 |
+
args = parser.parse_args()
|
101 |
+
log.info({ 'palette args': vars(args) })
|
102 |
+
if args.output != '':
|
103 |
+
pathlib.Path(args.output).mkdir(parents = True, exist_ok = True)
|
104 |
+
if not args.grid:
|
105 |
+
for arg in args.input:
|
106 |
+
if os.path.isfile(arg) and filetype.is_image(arg):
|
107 |
+
image = Image.open(arg)
|
108 |
+
fn = os.path.join(args.output, pathlib.Path(arg).stem + '-' + args.suffix + '.jpg')
|
109 |
+
palette(image, args, fn)
|
110 |
+
elif os.path.isdir(arg):
|
111 |
+
for root, _dirs, files in os.walk(arg):
|
112 |
+
for f in files:
|
113 |
+
if filetype.is_image(os.path.join(root, f)):
|
114 |
+
image = Image.open(os.path.join(root, f))
|
115 |
+
fn = os.path.join(args.output, pathlib.Path(f).stem + '-' + args.suffix + '.jpg')
|
116 |
+
palette(image, args, fn)
|
117 |
+
else:
|
118 |
+
images = []
|
119 |
+
for arg in args.input:
|
120 |
+
if os.path.isfile(arg) and filetype.is_image(arg):
|
121 |
+
images.append(Image.open(arg))
|
122 |
+
elif os.path.isdir(arg):
|
123 |
+
for root, _dirs, files in os.walk(arg):
|
124 |
+
for f in files:
|
125 |
+
if filetype.is_image(os.path.join(root, f)):
|
126 |
+
images.append(Image.open(os.path.join(root, f)))
|
127 |
+
image = grid(images)
|
128 |
+
fn = os.path.join(args.output, args.suffix + '.jpg')
|
129 |
+
palette(image, args, fn)
|
cli/image-watermark.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import io
|
4 |
+
import pathlib
|
5 |
+
import argparse
|
6 |
+
import filetype
|
7 |
+
import numpy as np
|
8 |
+
from imwatermark import WatermarkEncoder, WatermarkDecoder
|
9 |
+
from PIL import Image
|
10 |
+
from PIL.ExifTags import TAGS
|
11 |
+
from PIL.TiffImagePlugin import ImageFileDirectory_v2
|
12 |
+
from util import log, Map
|
13 |
+
import piexif
|
14 |
+
import piexif.helper
|
15 |
+
|
16 |
+
|
17 |
+
options = Map({ 'method': 'dwtDctSvd', 'type': 'bytes' })
|
18 |
+
|
19 |
+
|
20 |
+
def get_exif(image):
    """Collect exif tags from an image using both piexif and Pillow.

    Returns a merged {tag-name: value} dict; Pillow-sourced values override
    piexif values on key collision. Either source failing is tolerated.
    """
    # using piexif
    res1 = {}
    try:
        exif = piexif.load(image.info["exif"])
        exif = exif.get("Exif", {})
        for k, v in exif.items():
            # map the numeric tag id back to its piexif constant name
            key = list(vars(piexif.ExifIFD).keys())[list(vars(piexif.ExifIFD).values()).index(k)]
            res1[key] = piexif.helper.UserComment.load(v)
    except Exception:
        pass
    # using pillow
    res2 = {}
    try:
        res2 = { TAGS[k]: v for k, v in image.getexif().items() if k in TAGS }
    except Exception:
        pass
    return {**res1, **res2}
|
38 |
+
|
39 |
+
|
40 |
+
def set_exif(d: dict):
    """Serialize a {tag-name: value} dict into raw exif bytes usable as the
    `exif=` argument of Image.save. Raises KeyError for unknown tag names."""
    ifd = ImageFileDirectory_v2()
    _TAGS = {v: k for k, v in TAGS.items()} # enumerate possible exif tags
    for k, v in d.items():
        ifd[_TAGS[k]] = v
    exif_stream = io.BytesIO()
    ifd.save(exif_stream)
    encoded = b'Exif\x00\x00' + exif_stream.getvalue()  # standard exif APP1 header
    return encoded
|
49 |
+
|
50 |
+
|
51 |
+
def get_watermark(image, params):
    """Decode an ascii watermark of params.length bits from a PIL image."""
    pixels = np.asarray(image)
    decoder = WatermarkDecoder(options.type, params.length)
    payload = decoder.decode(pixels, options.method)
    return payload.decode(encoding='ascii', errors='ignore')
|
57 |
+
|
58 |
+
|
59 |
+
def set_watermark(image, params):
    """Embed params.wm into the image as an invisible watermark and return the
    watermarked image."""
    data = np.asarray(image)
    encoder = WatermarkEncoder()
    length = params.length // 8  # payload length in bytes
    text = f"{params.wm:<{length}}"[:length]  # pad/truncate to the exact byte length
    bytearr = text.encode(encoding='ascii', errors='ignore')
    encoder.set_watermark(options.type, bytearr)
    encoded = encoder.encode(data, options.method)
    image = Image.fromarray(encoded)
    return image
|
69 |
+
|
70 |
+
|
71 |
+
def watermark(params, file):
    """Read or write an invisible watermark in a single image file.

    params: parsed argparse namespace (command, wm, strip, verify, length, output)
    file:   path of the image to process

    Fixes vs. previous revision:
    - read mode used `fn = params.input` (the full nargs='*' argument list)
      instead of the current file, logging the wrong value.
    - the verify/else suite ran for BOTH commands, so in read mode the decoded
      watermark was clobbered by params.wm (or verify crashed re-opening a
      list as a path); verification now only runs after a write.
    """
    if not os.path.exists(file):
        log.error({ 'watermark': 'file not found' })
        return
    if not filetype.is_image(file):
        log.error({ 'watermark': 'file is not an image' })
        return
    image = Image.open(file)
    if image.width * image.height < 256 * 256:
        log.error({ 'watermark': 'image too small' })
        return

    exif = get_exif(image)

    if params.command == 'read':
        fn = file
        wm = get_watermark(image, params)
    elif params.command == 'write':
        metadata = b'' if params.strip else set_exif(exif)
        if params.output != '':
            pathlib.Path(params.output).mkdir(parents = True, exist_ok = True)
        image = set_watermark(image, params)
        fn = os.path.join(params.output, file)
        image.save(fn, exif=metadata)
        if params.verify:
            # re-open the file we just wrote and decode, to confirm the
            # watermark survived the save round-trip
            image = Image.open(fn)
            wm = get_watermark(image, params)
        else:
            wm = params.wm

    log.info({ 'file': fn })
    log.info({ 'resolution': f'{image.width}x{image.height}' })
    log.info({ 'watermark': wm })
    log.info({ 'exif': None if params.strip else exif })
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == '__main__':
|
113 |
+
parser = argparse.ArgumentParser(description = 'image watermarking')
|
114 |
+
parser.add_argument('command', choices = ['read', 'write'])
|
115 |
+
parser.add_argument('--wm', type=str, required=False, default='sdnext', help='watermark string')
|
116 |
+
parser.add_argument('--strip', default=False, action='store_true', help = "strip existing exif data")
|
117 |
+
parser.add_argument('--verify', default=False, action='store_true', help = "verify watermark during write")
|
118 |
+
parser.add_argument('--length', type=int, default=32, help="watermark length in bits")
|
119 |
+
parser.add_argument('--output', type=str, required=False, default='', help='folder to store images, default is overwrite in-place')
|
120 |
+
parser.add_argument('input', type=str, nargs='*')
|
121 |
+
args = parser.parse_args()
|
122 |
+
# log.info({ 'watermark args': vars(args), 'options': options })
|
123 |
+
for arg in args.input:
|
124 |
+
if os.path.isfile(arg):
|
125 |
+
watermark(args, arg)
|
126 |
+
elif os.path.isdir(arg):
|
127 |
+
for root, _dirs, files in os.walk(arg):
|
128 |
+
for f in files:
|
129 |
+
watermark(args, os.path.join(root, f))
|
cli/install-sf.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
|
6 |
+
torch_supported = ['211', '212','220','221']
|
7 |
+
cuda_supported = ['cu118', 'cu121']
|
8 |
+
python_supported = ['39', '310', '311']
|
9 |
+
repo_url = 'https://github.com/chengzeyi/stable-fast'
|
10 |
+
api_url = 'https://api.github.com/repos/chengzeyi/stable-fast/releases/tags/nightly'
|
11 |
+
path_url = '/releases/download/nightly'
|
12 |
+
|
13 |
+
|
14 |
+
def install_pip(arg: str):
    """Run `pip install -U <arg>` with the current interpreter; return True on success."""
    import subprocess
    command = f'"{sys.executable}" -m pip install -U {arg}'
    print(f'Running: {command}')
    completed = subprocess.run(command, shell=True, check=False, env=os.environ)
    return completed.returncode == 0
|
20 |
+
|
21 |
+
|
22 |
+
def get_nightly():
    """Query the GitHub API for the stable-fast nightly release and return its version string, or None on any failure."""
    import requests
    response = requests.get(api_url, timeout=10)
    if response.status_code != 200:
        print('Failed to get nightly version')
        return None
    assets = response.json().get('assets', [])
    if len(assets) == 0:
        print('Failed to get nightly version')
        return None
    # version is embedded in the first asset name between '-' and '+', e.g. stable_fast-1.0.0+torch...
    name = assets[0].get('name', '')
    found = re.search(r"-(.+?)\+", name)
    if not found:
        print('Failed to get nightly version')
        return None
    ver = found.group(1)
    print(f'Nightly version: {ver}')
    return ver
|
43 |
+
|
44 |
+
|
45 |
+
def install_stable_fast():
    """Install stable-fast: use a prebuilt nightly wheel when torch/CUDA/python all match
    a supported combination, otherwise fall back to building from source.

    Raises:
        ValueError: when the python version or the platform is unsupported entirely.
    """
    import torch

    src_url = 'git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast'
    python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
    if python_ver not in python_supported:
        raise ValueError(f'StableFast unsupported python: {python_ver} required {python_supported}')
    if sys.platform == 'linux':
        bin_url = 'manylinux2014_x86_64.whl'
    elif sys.platform == 'win32':
        bin_url = 'win_amd64.whl'
    else:
        raise ValueError(f'StableFast unsupported platform: {sys.platform}')

    # torch builds without a local version segment (e.g. plain '2.1.0' CPU wheels) have no '+';
    # the original `torch_ver, cuda_ver = torch.__version__.split('+')` raised ValueError there.
    # partition() yields cuda_ver == '' instead, which routes to the source install below.
    torch_ver, _, cuda_ver = torch.__version__.partition('+')
    torch_ver = torch_ver.replace('.', '')
    sf_ver = get_nightly()

    if torch_ver not in torch_supported:
        print(f'StableFast unsupported torch: {torch_ver} required {torch_supported}')
        print('Installing from source...')
        url = src_url
    elif cuda_ver not in cuda_supported:
        print(f'StableFast unsupported CUDA: {cuda_ver} required {cuda_supported}')
        print('Installing from source...')
        url = src_url
    elif sf_ver is None:
        print('StableFast cannot determine version')
        print('Installing from source...')
        url = src_url
    else:
        print('Installing wheel...')
        # wheel filename follows the release naming convention: stable_fast-<ver>+torch<tv><cv>-cp<py>-cp<py>-<platform>.whl
        file_url = f'stable_fast-{sf_ver}+torch{torch_ver}{cuda_ver}-cp{python_ver}-cp{python_ver}-{bin_url}'
        url = f'{repo_url}/{path_url}/{file_url}'

    ok = install_pip(url)
    if ok:
        import sfast
        print(f'StableFast installed: {sfast.__version__}')
    else:
        print('StableFast install failed')


if __name__ == '__main__':
    install_stable_fast()
|
cli/latents.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import json
|
6 |
+
import pathlib
|
7 |
+
import argparse
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from PIL import Image
|
14 |
+
from torchvision import transforms
|
15 |
+
from tqdm import tqdm
|
16 |
+
from util import Map
|
17 |
+
|
18 |
+
from rich.pretty import install as pretty_install
|
19 |
+
from rich.traceback import install as traceback_install
|
20 |
+
from rich.console import Console
|
21 |
+
|
22 |
+
console = Console(log_time=True, log_time_format='%H:%M:%S-%f')
|
23 |
+
pretty_install(console=console)
|
24 |
+
traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False)
|
25 |
+
|
26 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'modules', 'lora'))
|
27 |
+
import library.model_util as model_util
|
28 |
+
import library.train_util as train_util
|
29 |
+
|
30 |
+
warnings.filterwarnings('ignore')
|
31 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
32 |
+
options = Map({
|
33 |
+
'batch': 1,
|
34 |
+
'input': '',
|
35 |
+
'json': '',
|
36 |
+
'max': 1024,
|
37 |
+
'min': 256,
|
38 |
+
'noupscale': False,
|
39 |
+
'precision': 'fp32',
|
40 |
+
'resolution': '512,512',
|
41 |
+
'steps': 64,
|
42 |
+
'vae': 'stabilityai/sd-vae-ft-mse'
|
43 |
+
})
|
44 |
+
vae = None
|
45 |
+
|
46 |
+
|
47 |
+
def get_latents(local_vae, images, weight_dtype):
    """Encode a batch of images through the VAE and return (latents as numpy, [h, w] of the first image)."""
    to_tensor = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ])
    batch = torch.stack([to_tensor(image) for image in images]).to(device, weight_dtype)
    with torch.no_grad():
        # sample from the latent distribution, then move to CPU as float32 numpy for np.savez
        latents = local_vae.encode(batch).latent_dist.sample().float().to('cpu').numpy()
    return latents, [images[0].shape[0], images[0].shape[1]]
|
55 |
+
|
56 |
+
|
57 |
+
def get_npz_filename_wo_ext(data_dir, image_key):
    """Return the extensionless npz path for *image_key* placed inside *data_dir*."""
    stem, _ext = os.path.splitext(os.path.basename(image_key))
    return os.path.join(data_dir, stem)
|
59 |
+
|
60 |
+
|
61 |
+
def create_vae_latents(local_params):
    """Bucket training images by aspect ratio, VAE-encode them, save per-image .npz latents
    and update the metadata json with each image's train resolution.

    local_params: dict of overrides merged over the module-level `options` defaults.
    Returns None; side effects are the .npz files next to the inputs and the rewritten json.
    """
    args = Map({**options, **local_params})
    console.log(f'create vae latents args: {args}')
    image_paths = train_util.glob_images(args.input)
    if os.path.exists(args.json):
        with open(args.json, 'rt', encoding='utf-8') as f:
            metadata = json.load(f)
    else:
        # no metadata file means nothing to annotate: bail out silently
        return
    # map the CLI precision string onto a torch dtype (fp32 is the fallback)
    if args.precision == 'fp16':
        weight_dtype = torch.float16
    elif args.precision == 'bf16':
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32
    global vae # pylint: disable=global-statement
    if vae is None:
        # load once and cache at module level so repeated calls reuse the same VAE
        vae = model_util.load_vae(args.vae, weight_dtype)
        vae.eval()
        vae.to(device, dtype=weight_dtype)
    max_reso = tuple([int(t) for t in args.resolution.split(',')])
    assert len(max_reso) == 2, f'illegal resolution: {args.resolution}'
    bucket_manager = train_util.BucketManager(args.noupscale, max_reso, args.min, args.max, args.steps)
    if not args.noupscale:
        bucket_manager.make_buckets()
    img_ar_errors = []
    def process_batch(is_last):
        # flush every bucket that reached batch size (or everything on the final call)
        for bucket in bucket_manager.buckets:
            if (is_last and len(bucket) > 0) or len(bucket) >= args.batch:
                latents, original_size = get_latents(vae, [img for _, img in bucket], weight_dtype)
                # VAE downsamples by a factor of 8 in each spatial dimension
                assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, f'latent shape {latents.shape}, {bucket[0][1].shape}'
                for (image_key, _), latent in zip(bucket, latents):
                    npz_file_name = get_npz_filename_wo_ext(args.input, image_key)
                    # np.savez(npz_file_name, latent)
                    kwargs = {}
                    np.savez(
                        npz_file_name,
                        latents=latent,
                        original_size=np.array(original_size),
                        crop_ltrb=np.array([0, 0]),
                        **kwargs,
                    )
                bucket.clear()
    data = [[(None, ip)] for ip in image_paths]
    bucket_counts = {}
    for data_entry in tqdm(data, smoothing=0.0):
        if data_entry[0] is None:
            continue
        img_tensor, image_path = data_entry[0]
        if img_tensor is not None:
            image = transforms.functional.to_pil_image(img_tensor)
        else:
            image = Image.open(image_path)
        # NOTE(review): the basename assignment below is immediately overwritten by the
        # parent-folder/stem form on the next line — the first assignment is dead code
        image_key = os.path.basename(image_path)
        image_key = os.path.join(os.path.basename(pathlib.Path(image_path).parent), pathlib.Path(image_path).stem)
        if image_key not in metadata:
            metadata[image_key] = {}
        reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
        img_ar_errors.append(abs(ar_error))
        bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
        # record the train resolution rounded down to a multiple of 8 (latent-space alignment)
        metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
        if not args.noupscale:
            assert resized_size[0] == reso[0] or resized_size[1] == reso[1], f'internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}'
            assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], f'internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}'
        assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], f'internal error resized size is small: {resized_size}, {reso}'
        image = np.array(image)
        if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]:
            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
        # center-crop whichever dimension overshoots the bucket resolution
        if resized_size[0] > reso[0]:
            trim_size = resized_size[0] - reso[0]
            image = image[:, trim_size//2:trim_size//2 + reso[0]]
        if resized_size[1] > reso[1]:
            trim_size = resized_size[1] - reso[1]
            image = image[trim_size//2:trim_size//2 + reso[1]]
        assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f'internal error, illegal trimmed size: {image.shape}, {reso}'
        bucket_manager.add_image(reso, (image_key, image))
        process_batch(False)

    process_batch(True)
    vae.to('cpu')

    bucket_manager.sort()
    img_ar_errors = np.array(img_ar_errors)
    for i, reso in enumerate(bucket_manager.resos):
        count = bucket_counts.get(reso, 0)
        if count > 0:
            console.log(f'vae latents bucket: {i+1}/{len(bucket_manager.resos)} resolution: {reso} images: {count} mean-ar-error: {np.mean(img_ar_errors)}')
    with open(args.json, 'wt', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
|
150 |
+
|
151 |
+
|
152 |
+
def unload_vae():
    """Drop the module-level cached VAE so the next create_vae_latents() call reloads it."""
    global vae # pylint: disable=global-statement
    vae = None
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == '__main__':
    # CLI entry point: parse bucketing/VAE options and run the latent caching pass
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('input', type=str, help='directory for train images')
    arg_parser.add_argument('--json', type=str, required=True, help='metadata file to input')
    arg_parser.add_argument('--vae', type=str, required=True, help='model name or path to encode latents')
    arg_parser.add_argument('--batch', type=int, default=1, help='batch size in inference')
    arg_parser.add_argument('--resolution', type=str, default='512,512', help='max resolution in fine tuning (width,height)')
    arg_parser.add_argument('--min', type=int, default=256, help='minimum resolution for buckets')
    arg_parser.add_argument('--max', type=int, default=1024, help='maximum resolution for buckets')
    arg_parser.add_argument('--steps', type=int, default=64, help='steps of resolution for buckets, divisible by 8')
    arg_parser.add_argument('--noupscale', action='store_true', help='make bucket for each image without upscaling')
    arg_parser.add_argument('--precision', type=str, default='fp32', choices=['fp32', 'fp16', 'bf16'], help='use precision')
    cli_args = arg_parser.parse_args()
    create_vae_latents(vars(cli_args))
|
cli/lcm-convert.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import argparse
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoPipelineForText2Image, LCMScheduler

# convert a base SD/SDXL/SSD-1B model into an LCM variant by fusing the matching LCM LoRA
parser = argparse.ArgumentParser("lcm_convert")
parser.add_argument("--name", help="Name of the new LCM model", type=str)
parser.add_argument("--model", help="A model to convert", type=str)
parser.add_argument("--lora-scale", default=1.0, help="Strength of the LCM", type=float)  # fixed typo: "Strenght"
parser.add_argument("--huggingface", action="store_true", help="Use Hugging Face models instead of safetensors models")
parser.add_argument("--upload", action="store_true", help="Upload the new LCM model to Hugging Face")
parser.add_argument("--no-half", action="store_true", help="Convert the new LCM model to FP32")
parser.add_argument("--no-save", action="store_true", help="Don't save the new LCM model to local disk")
parser.add_argument("--sdxl", action="store_true", help="Use SDXL models")
parser.add_argument("--ssd-1b", action="store_true", help="Use SSD-1B models")

args = parser.parse_args()

# load the source pipeline either from the hub or from a local single-file checkpoint
if args.huggingface:
    pipeline = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.float16, variant="fp16")
else:
    if args.sdxl or args.ssd_1b:
        pipeline = StableDiffusionXLPipeline.from_single_file(args.model)
    else:
        pipeline = StableDiffusionPipeline.from_single_file(args.model)

# swap in the LCM scheduler and fuse the LCM LoRA matching the model family
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
if args.sdxl:
    pipeline.load_lora_weights("latent-consistency/lcm-lora-sdxl")
elif args.ssd_1b:
    pipeline.load_lora_weights("latent-consistency/lcm-lora-ssd-1b")
else:
    pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipeline.fuse_lora(lora_scale=args.lora_scale)

#components = pipeline.components
#pipeline = LatentConsistencyModelPipeline(**components)

if args.no_half:
    pipeline = pipeline.to(dtype=torch.float32)
else:
    pipeline = pipeline.to(dtype=torch.float16)
print(pipeline)

if not args.no_save:
    # exist_ok so re-running the conversion does not crash on a pre-existing output folder
    os.makedirs(f"models--local--{args.name}/snapshots", exist_ok=True)
    if args.no_half:
        pipeline.save_pretrained(f"models--local--{args.name}/snapshots/{args.name}")
    else:
        pipeline.save_pretrained(f"models--local--{args.name}/snapshots/{args.name}", variant="fp16")
if args.upload:
    if args.no_half:
        pipeline.push_to_hub(args.name)
    else:
        pipeline.push_to_hub(args.name, variant="fp16")
|
cli/model-jit.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import functools
|
5 |
+
import argparse
|
6 |
+
import logging
|
7 |
+
import warnings
|
8 |
+
from dataclasses import dataclass
|
9 |
+
|
10 |
+
logging.getLogger("DeepSpeed").disabled = True
|
11 |
+
warnings.filterwarnings(action="ignore", category=FutureWarning)
|
12 |
+
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import diffusers
|
16 |
+
|
17 |
+
n_warmup = 5
|
18 |
+
n_traces = 10
|
19 |
+
n_runs = 100
|
20 |
+
args = {}
|
21 |
+
pipe = None
|
22 |
+
log = logging.getLogger("sd")
|
23 |
+
|
24 |
+
|
25 |
+
def setup_logging():
    """Route the module 'sd' logger through a rich console handler and silence noisy third-party loggers."""
    from rich.theme import Theme
    from rich.logging import RichHandler
    from rich.console import Console
    from rich.traceback import install
    log.setLevel(logging.DEBUG)
    theme = Theme({ "traceback.border": "black", "traceback.border.syntax_error": "black", "inspect.value.border": "black" })
    rich_console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=theme)
    logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
    handler = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG, console=rich_console)
    handler.setLevel(logging.DEBUG)
    log.addHandler(handler)
    logging.getLogger("diffusers").setLevel(logging.ERROR)
    logging.getLogger("torch").setLevel(logging.ERROR)
    warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)
    install(console=rich_console, extra_lines=1, max_frames=10, width=rich_console.width, word_wrap=False, indent_guides=False, suppress=[])
|
40 |
+
|
41 |
+
|
42 |
+
def generate_inputs():
    """Create random fp16 CUDA tensors shaped like the unet's forward inputs for the selected model type.

    Returns a 3-tuple (sample, timestep, encoder_hidden_states) for sd15 and a 4-tuple
    (plus text_embeds) for sdxl.
    Raises ValueError for any other type instead of silently returning None, which the
    original did and which made callers fail later with an opaque unpacking TypeError.
    """
    # inputs common to both model families
    sample = torch.randn(2, 4, 64, 64).half().cuda()
    timestep = torch.rand(1).half().cuda() * 999
    encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
    if args.type == 'sd15':
        return sample, timestep, encoder_hidden_states
    if args.type == 'sdxl':
        text_embeds = torch.randn(1, 77, 2048).half().cuda()
        return sample, timestep, encoder_hidden_states, text_embeds
    raise ValueError(f'unsupported model type: {args.type}')  # unreachable via CLI (argparse choices) but explicit
|
54 |
+
|
55 |
+
|
56 |
+
def load_model():
    """Load the checkpoint given by args.model as a diffusers pipeline on CUDA and store it in the module-level `pipe`."""
    log.info(f'versions: torch={torch.__version__} diffusers={diffusers.__version__}')
    load_config = {
        "low_cpu_mem_usage": True,
        "torch_dtype": torch.float16,
        "safety_checker": None,
        "requires_safety_checker": False,
        "load_safety_checker": False,
        "load_connected_pipeline": True,
        "use_safetensors": True,
    }
    # pick the pipeline class matching the model family
    if args.type == 'sd15':
        pipeline_cls = diffusers.StableDiffusionPipeline
    else:
        pipeline_cls = diffusers.StableDiffusionXLPipeline
    global pipe # pylint: disable=global-statement
    t0 = time.time()
    pipe = pipeline_cls.from_single_file(args.model, **load_config).to('cuda')
    size = os.path.getsize(args.model)
    log.info(f'load: model={args.model} type={args.type} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
|
73 |
+
|
74 |
+
|
75 |
+
def load_trace(fn: str):
    """Load a TorchScript-traced unet from *fn* and install it as pipe.unet behind a diffusers-compatible wrapper."""

    @dataclass
    class UNet2DConditionOutput:
        # minimal stand-in for diffusers' UNet2DConditionOutput so callers can keep reading `.sample`
        sample: torch.FloatTensor

    class TracedUNet(torch.nn.Module):
        # exposes the attributes the pipeline reads (in_channels, device) and adapts the traced
        # module's tuple return back into an output object
        def __init__(self):
            super().__init__()
            self.in_channels = pipe.unet.in_channels
            self.device = pipe.unet.device

        def forward(self, latent_model_input, t, encoder_hidden_states):
            # traced module returns a tuple; element 0 is the sample tensor
            sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
            return UNet2DConditionOutput(sample=sample)

    t0 = time.time()
    # `unet_traced` is closed over by TracedUNet.forward above, so it must load before any forward call
    unet_traced = torch.jit.load(fn)
    pipe.unet = TracedUNet()
    size = os.path.getsize(fn)
    log.info(f'load: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
|
96 |
+
|
97 |
+
|
98 |
+
def trace_model():
    """JIT-trace pipe.unet. When args.save is set, write the traced module next to the
    checkpoint and return its path; otherwise swap the traced unet into the live pipeline
    and return None."""
    log.info(f'tracing model: {args.model}')
    torch.set_grad_enabled(False)
    unet = pipe.unet
    unet.eval()
    # unet.to(memory_format=torch.channels_last) # use channels_last memory format
    unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default

    # warmup
    t0 = time.time()
    for _ in range(n_warmup):
        with torch.inference_mode():
            inputs = generate_inputs()
            _output = unet(*inputs)
    log.info(f'warmup: time={time.time() - t0:.3f}s passes={n_warmup}')

    # trace
    # NOTE(review): `inputs` deliberately leaks out of the warmup loop — tracing reuses the last warmup batch
    t0 = time.time()
    unet_traced = torch.jit.trace(unet, inputs, check_trace=True)
    unet_traced.eval()
    log.info(f'trace: time={time.time() - t0:.3f}s')

    # optimize graph
    # extra forward passes let the JIT apply its profiling-based graph optimizations
    t0 = time.time()
    for _ in range(n_traces):
        with torch.inference_mode():
            inputs = generate_inputs()
            _output = unet_traced(*inputs)
    log.info(f'optimize: time={time.time() - t0:.3f}s passes={n_traces}')

    # save the model
    if args.save:
        t0 = time.time()
        basename, _ext = os.path.splitext(args.model)
        fn = f"{basename}.pt"
        unet_traced.save(fn)
        size = os.path.getsize(fn)
        log.info(f'save: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
        return fn

    # not saving: install the traced unet directly into the pipeline
    pipe.unet = unet_traced
    return None
|
140 |
+
|
141 |
+
|
142 |
+
def benchmark_model(msg: str):
    """Run n_runs unet forward passes and return the elapsed seconds measured after a 10% warm-up.

    msg: label included in the log line ('original' or 'traced').
    """
    with torch.inference_mode():
        inputs = generate_inputs()
        torch.cuda.synchronize()
        t0 = time.time()  # fallback start, in case n_runs is too small to reach the warm-up cutoff
        for n in range(n_runs):
            if n == n_runs // 10:
                # start the clock exactly once after the warm-up fraction; the original used
                # `if n > n_runs / 10: t0 = time.time()`, re-arming t0 on every subsequent pass
                # so the reported time covered only the final pass despite claiming passes=n_runs
                t0 = time.time()
            _output = pipe.unet(*inputs)
        torch.cuda.synchronize()
        t1 = time.time()
    log.info(f"benchmark unet: {t1 - t0:.3f}s passes={n_runs} type={msg}")
    return t1 - t0
|
154 |
+
|
155 |
+
|
156 |
+
if __name__ == '__main__':
    # CLI entry point: load a checkpoint, optionally benchmark, trace the unet, re-benchmark
    parser = argparse.ArgumentParser(description = 'SD.Next')
    parser.add_argument('--model', type=str, default='', required=True, help='model path')
    parser.add_argument('--type', type=str, default='sd15', choices=['sd15', 'sdxl'], required=False, help='model type, default: %(default)s')
    parser.add_argument('--benchmark', default = False, action='store_true', help = "run benchmarks, default: %(default)s")
    # NOTE(review): store_true with default=True means --trace is always True and the flag is a no-op
    parser.add_argument('--trace', default = True, action='store_true', help = "run jit tracing, default: %(default)s")
    parser.add_argument('--save', default = False, action='store_true', help = "save optimized unet, default: %(default)s")
    args = parser.parse_args()
    setup_logging()
    log.info('sdnext model jit tracing')
    if not os.path.isfile(args.model):
        log.error(f"invalid model path: {args.model}")
        raise SystemExit(1)  # was exit(1): the exit() builtin is injected by site and not guaranteed in all interpreters
    load_model()
    if args.benchmark:
        time0 = benchmark_model('original')
    unet_saved = trace_model()
    if unet_saved is not None:
        load_trace(unet_saved)
    if args.benchmark:
        time1 = benchmark_model('traced')
        log.info(f'benchmark speedup: {100 * (time0 - time1) / time0:.3f}%')
|
cli/model-metadata.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import json
|
5 |
+
from rich import print # pylint: disable=redefined-builtin
|
6 |
+
|
7 |
+
|
8 |
+
def read_metadata(fn):
    """Read and pretty-print the `__metadata__` section of a .safetensors file header.

    The safetensors format starts with an 8-byte little-endian header length followed by
    a JSON header of that length; user metadata lives under its `__metadata__` key.
    Returns the parsed metadata dict (empty for files that are not valid safetensors).
    """
    res = {}
    with open(fn, mode="rb") as f:
        metadata_len = f.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = f.read(2)
        if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
            print(f"Not a valid safetensors file: {fn}")
            # bail out: the original fell through and crashed in json.loads on garbage bytes
            return res
        json_data = json_start + f.read(metadata_len-2)
        json_obj = json.loads(json_data)
        for k, v in json_obj.get("__metadata__", {}).items():
            res[k] = v
            if isinstance(v, str) and v[0:1] == '{':
                # metadata values are often JSON strings themselves; expand them when parseable
                try:
                    res[k] = json.loads(v)
                except Exception:
                    pass
    print(f"{fn}: {json.dumps(res, indent=4)}")
    return res
|
26 |
+
|
27 |
+
|
28 |
+
def main():
    """Dump safetensors metadata for every file or directory path given on the command line."""
    if not sys.argv:
        print('metadata:', 'no files specified')
    for path in sys.argv:
        if os.path.isfile(path):
            read_metadata(path)
        elif os.path.isdir(path):
            # walk directories recursively and try every file found
            for root, _dirs, names in os.walk(path):
                for name in names:
                    read_metadata(os.path.join(root, name))

if __name__ == '__main__':
    sys.argv.pop(0) # drop the script name so only user-supplied paths remain
    main()
|
cli/nvidia-smi.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import shutil
|
5 |
+
import subprocess
|
6 |
+
import xmltodict
|
7 |
+
from rich import print # pylint: disable=redefined-builtin
|
8 |
+
from util import log, Map
|
9 |
+
|
10 |
+
|
11 |
+
def get_nvidia_smi(output='dict'):
    """Run `nvidia-smi -q -x` and return the parsed report.

    output: 'dict' for a plain dict, 'class'/'map' for a Map, 'json' for a JSON string.
    Returns None when nvidia-smi is not on PATH or the output mode is unknown.
    """
    binary = shutil.which('nvidia-smi')
    if binary is None:
        log.error("nvidia-smi not found")
        return None
    proc = subprocess.run(f'"{binary}" -q -x', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    xml_text = proc.stdout.decode(encoding="utf8", errors="ignore")
    parsed = xmltodict.parse(xml_text)
    if 'nvidia_smi_log' in parsed:
        parsed = parsed['nvidia_smi_log']
    # the supported_clocks table is huge and rarely useful; drop it from the report
    if 'gpu' in parsed and 'supported_clocks' in parsed['gpu']:
        del parsed['gpu']['supported_clocks']
    if output == 'dict':
        return parsed
    if output in ('class', 'map'):
        return Map(parsed)
    if output == 'json':
        return json.dumps(parsed, indent=4)
    return None
|
31 |
+
|
32 |
+
|
33 |
+
if __name__ == "__main__":
    # quick manual check: print the parsed nvidia-smi report
    data = get_nvidia_smi(output='dict')
    print(type(data), data)
|
cli/options.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from util import Map
|
2 |
+
|
3 |
+
# default options for embedding training; keys mirror the training task payload
# consumed elsewhere in the project (presumably the train API — TODO confirm against caller)
embedding = Map({
    "id_task": 0,
    "embedding_name": "",
    "learn_rate": -1,
    "batch_size": 1,
    "steps": 500,
    "data_root": "",
    "log_directory": "train/log",
    "template_filename": "subject_filewords.txt",
    "gradient_step": 20,
    "training_width": 512,
    "training_height": 512,
    "shuffle_tags": False,
    "tag_drop_out": 0,
    "clip_grad_mode": "disabled",
    "clip_grad_value": "0.1",
    "latent_sampling_method": "deterministic",
    "create_image_every": 0,
    "save_embedding_every": 0,
    "save_image_with_stored_embedding": False,
    "preview_from_txt2img": False,
    "preview_prompt": "",
    "preview_negative_prompt": "blurry, duplicate, ugly, deformed, low res, watermark, text",
    "preview_steps": 20,
    "preview_sampler_index": 0,
    "preview_cfg_scale": 6,
    "preview_seed": -1,
    "preview_width": 512,
    "preview_height": 512,
    "varsize": False,
    "use_weight": False,
})

# default options for LoRA training; key names follow the kohya-style trainer arguments
lora = Map({
    "bucket_no_upscale": False,
    "bucket_reso_steps": 64,
    "cache_latents": True,
    "caption_dropout_every_n_epochs": None,
    "caption_dropout_rate": 0.0,
    "caption_extension": ".txt",
    # NOTE(review): both spellings are present on purpose? the misspelled key below
    # duplicates caption_extension — verify the consumer before removing either
    "caption_extention": ".txt",
    "caption_tag_dropout_rate": 0.0,
    "clip_skip": None,
    "color_aug": False,
    "dataset_repeats": 1,
    "debug_dataset": False,
    "enable_bucket": False,
    "face_crop_aug_range": None,
    "flip_aug": False,
    "full_fp16": False,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": False,
    "in_json": "",
    "keep_tokens": None,
    "learning_rate": 5e-05,
    "log_prefix": None,
    "logging_dir": None,
    "lr_scheduler_num_cycles": 1,
    "lr_scheduler_power": 1,
    "lr_scheduler": "cosine",
    "lr_warmup_steps": 0,
    "max_bucket_reso": 1024,
    "max_data_loader_n_workers": 8,
    "max_grad_norm": 0.0,
    "max_token_length": None,
    "max_train_epochs": None,
    "max_train_steps": 2500,
    "mem_eff_attn": False,
    "min_bucket_reso": 256,
    "mixed_precision": "fp16",
    "network_alpha": 1.0,
    "network_args": None,
    "network_dim": 16,
    "network_module": "networks.lora",
    "network_train_text_encoder_only": False,
    "network_train_unet_only": False,
    "network_weights": None,
    "no_metadata": False,
    "output_dir": "",
    "output_name": "",
    "persistent_data_loader_workers": False,
    "pretrained_model_name_or_path": "",
    "prior_loss_weight": 1.0,
    "random_crop": False,
    "reg_data_dir": None,
    "resolution": "512,512",
    "resume": None,
    "save_every_n_epochs": None,
    "save_last_n_epochs_state": None,
    "save_last_n_epochs": None,
    "save_model_as": "ckpt",
    "save_n_epoch_ratio": None,
    "save_precision": "fp16",
    "save_state": False,
    "seed": 42,
    "shuffle_caption": False,
    "text_encoder_lr": 5e-05,
    "train_batch_size": 1,
    "train_data_dir": "",
    "training_comment": "",
    "unet_lr": 1e-04,
    "use_8bit_adam": False,
    "v_parameterization": False,
    "v2": False,
    "vae": None,
    "xformers": False,
})

# default options for the image preprocessing pipeline (face/body detection, segmentation,
# similarity and blur filtering); thresholds are documented inline per key
process = Map({
    # general settings, do not modify
    'format': '.jpg', # image format
    'target_size': 512, # target resolution
    'segmentation_model': 0, # segmentation model 0/general 1/landscape
    'segmentation_background': (192, 192, 192), # segmentation background color
    'blur_score': 1.8, # max score for face blur detection
    'blur_samplesize': 60, # sample size to use for blur detection
    'similarity_score': 0.8, # maximum similarity score before image is discarded
    'similarity_size': 64, # base similarity detection on reduced images
    'range_score': 0.15, # min score for face color dynamicrange detection
    # face processing settings
    'face_score': 0.7, # min face detection score
    'face_pad': 0.1, # pad face image percentage
    'face_model': 1, # which face model to use 0/close-up 1/standard
    # body processing settings
    'body_score': 0.9, # min body detection score
    'body_visibility': 0.5, # min visibility score for each detected body part
    'body_parts': 15, # min number of detected body parts with sufficient visibility
    'body_pad': 0.2, # pad body image percentage
    'body_model': 2, # body model to use 0/low 1/medium 2/high
    # similarity detection settings
    # interrogate settings
    'interrogate': False, # interrogate images
    'interrogate_model': ['clip', 'deepdanbooru'], # interrogate models
    'tag_limit': 5, # number of tags to extract
    # validations
    # tbd
    'face_segmentation': False, # segmentation enabled
    'body_segmentation': False, # segmentation enabled
})
cli/process.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pylint: disable=global-statement
|
2 |
+
import os
|
3 |
+
import io
|
4 |
+
import math
|
5 |
+
import base64
|
6 |
+
import numpy as np
|
7 |
+
import mediapipe as mp
|
8 |
+
from PIL import Image, ImageOps
|
9 |
+
from pi_heif import register_heif_opener
|
10 |
+
from skimage.metrics import structural_similarity as ssim
|
11 |
+
from scipy.stats import beta
|
12 |
+
|
13 |
+
import util
|
14 |
+
import sdapi
|
15 |
+
import options
|
16 |
+
|
17 |
+
face_model = None
|
18 |
+
body_model = None
|
19 |
+
segmentation_model = None
|
20 |
+
all_images = []
|
21 |
+
all_images_by_type = {}
|
22 |
+
|
23 |
+
|
24 |
+
class Result():
    """Container for the outcome of processing a single input image."""
    def __init__(self, typ: str, fn: str, tag: str = None, requested: list = None):
        self.type = typ       # processing type: 'face' / 'body' / 'original' / 'unknown'
        self.input = fn       # input filename
        self.output = ''      # output filename once saved
        self.basename = ''    # basename used to build the output path
        self.message = ''     # error/status message; empty means ok
        self.image = None     # image object; set to None when processing rejects it
        self.caption = ''     # caption produced by interrogate
        self.tag = tag        # user-supplied tag prefix
        self.tags = []        # tags produced by interrogate
        self.ops = []         # list of operations performed
        # bugfix: the previous mutable default ([]) was shared across all
        # instances; use a None sentinel and create a fresh list per instance
        self.steps = requested if requested is not None else []
|
37 |
+
|
38 |
+
|
39 |
+
def detect_blur(image: Image):
    """Estimate image blurriness from high-frequency FFT content.

    Based on <https://github.com/karthik9319/Blur-Detection/>.
    Lower values mean less high-frequency detail, i.e. a blurrier image.
    """
    gray = ImageOps.grayscale(image)
    center_x, center_y = image.size[0] // 2, image.size[1] // 2
    sample = options.process.blur_samplesize
    spectrum = np.fft.fftshift(np.fft.fft2(gray))
    # zero out the low-frequency block around the spectrum center
    spectrum[center_y - sample: center_y + sample, center_x - sample: center_x + sample] = 0
    reconstructed = np.fft.ifft2(np.fft.ifftshift(spectrum))
    magnitude = np.log(np.abs(reconstructed))
    return round(np.mean(magnitude), 2)
|
51 |
+
|
52 |
+
|
53 |
+
def detect_dynamicrange(image: Image):
    """Score how closely the brightness histogram matches an ideal beta(2,2) curve.

    Based on <https://towardsdatascience.com/measuring-enhancing-image-quality-attributes-234b0f250e10>.
    Returns a cosine-similarity score rounded to 2 decimals; higher is better.
    """
    pixels = np.float32(np.asarray(image))
    height, width = pixels.shape[:2]  # pylint: disable=unsubscriptable-object
    # perceptual brightness: sqrt of luma-weighted squared channels
    weights = [0.299, 0.587, 0.114]
    brightness = np.sqrt(pixels[..., 0] ** 2 * weights[0] + pixels[..., 1] ** 2 * weights[1] + pixels[..., 2] ** 2 * weights[2])  # pylint: disable=unsubscriptable-object
    hist, _ = np.histogram(brightness, bins=256, range=(0, 255))
    img_pmf = hist / (height * width)
    # reference pmf: beta(2,2) pdf sampled over [0, 1]
    pdf = beta(2, 2).pdf(np.linspace(0, 1, 256))
    ref_pmf = pdf / np.sum(pdf)
    # cosine similarity between the reference and observed distributions
    cosine = np.dot(ref_pmf, img_pmf) / math.sqrt(np.sum(ref_pmf ** 2) * np.sum(img_pmf ** 2))
    return round(cosine, 2)
|
70 |
+
|
71 |
+
|
72 |
+
def detect_simmilar(image: Image):
    """Return the highest SSIM similarity of this image vs all previously seen ones.

    The image is downscaled and grayscaled before comparison, then remembered
    in the module-level `all_images` list for future calls.
    """
    small = image.resize((options.process.similarity_size, options.process.similarity_size))
    pixels = np.array(ImageOps.grayscale(small))
    best = 0
    for seen in all_images:
        score = ssim(pixels, seen, data_range=255, channel_axis=None, gradient=False, full=False)
        best = max(best, score)
    all_images.append(pixels)
    return best
|
83 |
+
|
84 |
+
|
85 |
+
def segmentation(res: Result):
    """Replace the image background with a flat color using selfie segmentation."""
    global segmentation_model
    # lazy-init the mediapipe model so it loads only when segmentation is requested
    if segmentation_model is None:
        segmentation_model = mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=options.process.segmentation_model)
    data = np.array(res.image)
    results = segmentation_model.process(data)
    # boolean mask broadcast to 3 channels; threshold 0.1 keeps weakly-confident foreground
    condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
    background = np.zeros(data.shape, dtype=np.uint8)
    background[:] = options.process.segmentation_background
    data = np.where(condition, data, background) # consider using a joint bilateral filter instead of pure combine
    segmented = Image.fromarray(data)
    res.image = segmented
    res.ops.append('segmentation')
    return res
|
99 |
+
|
100 |
+
|
101 |
+
def unload():
    """Drop references to all lazily-loaded detection models."""
    global face_model
    global body_model
    global segmentation_model
    if face_model is not None:
        face_model = None
    if body_model is not None:
        body_model = None
    if segmentation_model is not None:
        segmentation_model = None
|
111 |
+
|
112 |
+
|
113 |
+
def encode(img):
    """Serialize an image to a base64-encoded JPEG string."""
    buffer = io.BytesIO()
    img.save(buffer, 'JPEG')
    return base64.b64encode(buffer.getvalue()).decode()
|
119 |
+
|
120 |
+
|
121 |
+
def reset():
    """Unload models and clear the similarity/save bookkeeping state."""
    unload()
    global all_images
    global all_images_by_type
    all_images = []
    all_images_by_type = {}
|
127 |
+
|
128 |
+
|
129 |
+
def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = False):
    # Optionally upscale (SwinIR 2x) and/or face-restore (CodeFormer) via the
    # extra-single-image API endpoint; returns res with res.image replaced.
    kwargs = util.Map({
        'image': encode(res.image),
        'codeformer_visibility': 0.0, # restore disabled by default
        'codeformer_weight': 0.0,
    })
    # skip upscaling when the image already meets the target resolution
    if res.image.width >= options.process.target_size and res.image.height >= options.process.target_size:
        upscale = False
    if upscale:
        kwargs.upscaler_1 = 'SwinIR_4x'
        kwargs.upscaling_resize = 2
        res.ops.append('upscale')
    if restore:
        kwargs.codeformer_visibility = 1.0
        kwargs.codeformer_weight = 0.2
        res.ops.append('restore')
    if upscale or restore:
        result = sdapi.postsync('/sdapi/v1/extra-single-image', kwargs)
        if 'image' not in result:
            res.message = 'failed to upscale/restore image' # res.image is kept unchanged on failure
        else:
            res.image = Image.open(io.BytesIO(base64.b64decode(result['image'])))
    return res
|
152 |
+
|
153 |
+
|
154 |
+
def interrogate_image(res: Result, tag: str = None):
    """Caption and tag an image via the interrogate API.

    Queries each model from options.process.interrogate_model:
    - 'clip' produces a short caption (prefixed with res.tag when a tag is given)
    - 'deepdanbooru' produces a tag list, merged with the user tag
    Results are stored on res.caption / res.tags; res is returned.
    """
    caption = ''
    tags = []
    for model in options.process.interrogate_model:
        json = util.Map({ 'image': encode(res.image), 'model': model })
        result = sdapi.postsync('/sdapi/v1/interrogate', json)
        if model == 'clip':
            caption = result.caption if 'caption' in result else ''
            caption = caption.split(',')[0].replace(' a ', ' ').strip()
            if tag is not None:
                caption = res.tag + ', ' + caption
        if model == 'deepdanbooru':
            # bugfix: keep the danbooru result in its own variable instead of
            # clobbering the `tag` parameter, which made the merge below run
            # unconditionally and crash on res.tag when no tag was supplied
            danbooru = result.caption if 'caption' in result else ''
            tags = danbooru.split(',')
            tags = [t.replace('(', '').replace(')', '').replace('\\', '').split(':')[0].strip() for t in tags]
            if tag is not None:
                for t in res.tag.split(',')[::-1]:
                    tags.insert(0, t.strip())
                words = caption.split(' ')
                if len(words) > 1: # bugfix: guard against single-word captions (IndexError)
                    pos = 0 if len(tags) == 0 else 1
                    tags.insert(pos, words[1])
    # drop trivially short tags and enforce the configured limit
    tags = [t for t in tags if len(t) > 2]
    if len(tags) > options.process.tag_limit:
        tags = tags[:options.process.tag_limit]
    res.caption = caption
    res.tags = tags
    res.ops.append('interrogate')
    return res
|
181 |
+
|
182 |
+
|
183 |
+
def resize_image(res: Result):
    """Downscale res.image in place so it fits within the target size."""
    img = res.image
    img.thumbnail((options.process.target_size, options.process.target_size), Image.Resampling.HAMMING)
    res.image = img
    res.ops.append('resize')
    return res
|
189 |
+
|
190 |
+
|
191 |
+
def square_image(res: Result):
    """Paste res.image centered onto a square canvas (background stays black)."""
    side = max(res.image.width, res.image.height)
    canvas = Image.new('RGB', (side, side))
    offset = ((side - res.image.width) // 2, (side - res.image.height) // 2)
    canvas.paste(res.image, offset)
    res.image = canvas
    res.ops.append('square')
    return res
|
198 |
+
|
199 |
+
|
200 |
+
def process_face(res: Result):
    """Detect the primary face and crop the image around it with padding.

    On failure res.image is set to None and res.message explains why.
    """
    res.ops.append('face')
    global face_model
    if face_model is None: # lazy-init so the model loads only when needed
        face_model = mp.solutions.face_detection.FaceDetection(min_detection_confidence=options.process.face_score, model_selection=options.process.face_model)
    results = face_model.process(np.array(res.image))
    if results.detections is None:
        res.message = 'no face detected'
        res.image = None
        return res
    box = results.detections[0].location_data.relative_bounding_box # relative coords in [0, 1]
    # bugfix: the out-of-frame test computed (width - xmin) / (height - ymin);
    # the far edge of a relative bounding box is xmin + width / ymin + height
    if box.xmin < 0 or box.ymin < 0 or (box.xmin + box.width) > 1 or (box.ymin + box.height) > 1:
        res.message = 'face out of frame'
        res.image = None
        return res
    # pad the detected box by face_pad (half on each side) and clamp to the image
    x = max(0, (box.xmin - options.process.face_pad / 2) * res.image.width)
    y = max(0, (box.ymin - options.process.face_pad / 2) * res.image.height)
    w = min(res.image.width, (box.width + options.process.face_pad) * res.image.width)
    h = min(res.image.height, (box.height + options.process.face_pad) * res.image.height)
    res.image = res.image.crop((x, y, x + w, y + h))
    return res
|
222 |
+
|
223 |
+
|
224 |
+
def process_body(res: Result):
    """Detect pose landmarks and crop the image to the visible body with padding.

    Requires at least options.process.body_parts landmarks above the visibility
    threshold; otherwise res.image is set to None and res.message explains why.
    """
    res.ops.append('body')
    global body_model
    if body_model is None: # lazy-init so the model loads only when needed
        body_model = mp.solutions.pose.Pose(static_image_mode=True, min_detection_confidence=options.process.body_score, model_complexity=options.process.body_model)
    results = body_model.process(np.array(res.image))
    if results.pose_landmarks is None:
        res.message = 'no body detected'
        res.image = None
        return res
    # filter landmarks by visibility once instead of four times
    visible = [i for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
    if len(visible) < options.process.body_parts:
        res.message = f'insufficient body parts detected: {len(visible)}'
        res.image = None
        return res
    # bounding box of visible landmarks, padded by body_pad (half on each side)
    pad = options.process.body_pad / 2
    x0 = [res.image.width * (i.x - pad) for i in visible]
    y0 = [res.image.height * (i.y - pad) for i in visible]
    x1 = [res.image.width * (i.x + pad) for i in visible]
    y1 = [res.image.height * (i.y + pad) for i in visible]
    res.image = res.image.crop((max(0, min(x0)), max(0, min(y0)), min(res.image.width, max(x1)), min(res.image.height, max(y1))))
    return res
|
244 |
+
|
245 |
+
|
246 |
+
def process_original(res: Result):
    """Pass-through step: keep the image unchanged, just record the op."""
    res.ops.append('original')
    return res
|
249 |
+
|
250 |
+
|
251 |
+
def save_image(res: Result, folder: str):
    """Persist res.image to folder with a sequential type-prefixed filename."""
    if res.image is None or folder is None:
        return res
    # per-type running counter used as a zero-padded filename prefix
    count = all_images_by_type.get(res.type, 0) + 1
    all_images_by_type[res.type] = count
    stem = os.path.basename(res.input).split('.')[0]
    prefix = str(count).rjust(3, '0')
    res.basename = os.path.join(folder, prefix + '-' + res.type + '-' + stem)
    res.output = res.basename + options.process.format
    res.image.save(res.output)
    res.image.close()
    res.ops.append('save')
    return res
|
263 |
+
|
264 |
+
|
265 |
+
def file(filename: str, folder: str, tag = None, requested = None):
    """Process a single image through the requested pipeline steps and save it.

    Steps: one primary op ('face'/'body'/'original'), then validations
    ('blur'/'range'/'similarity'), then post-processing ('upscale'/'restore'/
    'interrogate'/'resize'/'square'/'segment'), and finally save to folder.
    Returns a Result; res.image is None and res.message is set on rejection.
    """
    requested = requested if requested is not None else [] # avoid shared mutable default
    # initialize result object
    res = Result(fn = filename, typ='unknown', tag=tag, requested = requested)
    # open image
    try:
        register_heif_opener()
        res.image = Image.open(filename)
        if res.image.mode == 'RGBA':
            res.image = res.image.convert('RGB')
        res.image = ImageOps.exif_transpose(res.image) # rotate image according to EXIF orientation
    except Exception as e:
        res.message = f'error opening: {e}'
        return res
    # primary steps
    if 'face' in requested:
        res.type = 'face'
        res = process_face(res)
    elif 'body' in requested:
        res.type = 'body'
        res = process_body(res)
    elif 'original' in requested:
        res.type = 'original'
        res = process_original(res)
    # validation steps
    if res.image is None:
        return res
    if 'blur' in requested:
        res.ops.append('blur')
        val = detect_blur(res.image)
        if val > options.process.blur_score:
            res.message = f'blur check failed: {val}'
            res.image = None
            return res # bugfix: bail out instead of passing None to the next check
    if 'range' in requested:
        res.ops.append('range')
        val = detect_dynamicrange(res.image)
        if val < options.process.range_score:
            res.message = f'dynamic range check failed: {val}'
            res.image = None
            return res # bugfix: bail out instead of passing None to the next check
    if 'similarity' in requested:
        res.ops.append('similarity')
        val = detect_simmilar(res.image)
        if val > options.process.similarity_score:
            res.message = f'similarity check failed: {val}' # bugfix: message wrongly said 'dynamic range'
            res.image = None
    if res.image is None:
        return res
    # post processing steps
    res = upscale_restore_image(res, 'upscale' in requested, 'restore' in requested)
    if res.image.width < options.process.target_size or res.image.height < options.process.target_size:
        res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
        res.image = None
        return res
    if 'interrogate' in requested:
        res = interrogate_image(res, tag)
    if 'resize' in requested:
        res = resize_image(res)
    if 'square' in requested:
        res = square_image(res)
    if 'segment' in requested:
        res = segmentation(res)
    # finally save image
    res = save_image(res, folder)
    return res
|
cli/random.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"prompts": [
|
3 |
+
"<style> of <embedding> <place>, high detailed, by <artist>, <suffix>"
|
4 |
+
],
|
5 |
+
"negative": [
|
6 |
+
"watermark, fog, clouds, blurry, duplicate, deformed, mutation"
|
7 |
+
],
|
8 |
+
"places": [
|
9 |
+
"standing in the city", "on a spaceship", "in fantasy landscape", "on a shore", "in a forest", "in winter wonderland"
|
10 |
+
],
|
11 |
+
"embeddings": [
|
12 |
+
"man", "man next to a beautiful girl", "man next to a car", "beautiful girl", "sexy naked girl", "cute girl holding a flower", "beautiful robot",
|
13 |
+
"young korean girl with medium-length white hair", "monster", "pin up girl",
|
14 |
+
"man vlado", "beutiful girl ana", "man lee", "beautiful girl abby"
|
15 |
+
],
|
16 |
+
"artists": [
|
17 |
+
"John Salminen", "Greg Rutkowski", "Akihiko Yoshida", "Alejandro Burdisio", "Artgerm", "Patrick Brown", "Walt Disney", "Neal Adams", "Jeremy Chong",
|
18 |
+
"Chris Rallis", "Roy Lichtenstein", "Claude Monet", "Jon Whitcomb", "Pablo Picasso", "Raymond Leech", "Tom Lovell", "Noriyoshi Ohrai", "Shingei",
|
19 |
+
"Helmut Newton", "Maciej Kuciara", "Daniel F. Gerhartz", "Stephan Martiniรจre", "Magali Villeneuve", "Carne Griffiths", "Alberto Seveso",
|
20 |
+
"Vincent Van Gogh", "WLOP", "Frank Xavier Leyendecker", "Peter Lindbergh", "Nick Gentry", "Howard Chandler Christy", "Raphael", "Henri Matisse"
|
21 |
+
],
|
22 |
+
"styles": [
|
23 |
+
"illustration", "painting", "portrait", "photograph", "drawing", "sketch", "pencil sketch", "3d render", "cartoon", "anime", "scribbles", "pop art",
|
24 |
+
"ink painting", "steampunk illustration", "dc comics illustration", "marvel comics", "vray render", "photoillustration", "pixar", "marble sculpture",
|
25 |
+
"bronze sculpture", "christmas theme"
|
26 |
+
],
|
27 |
+
"suffixes": [
|
28 |
+
"cinematic lighting", "artstation", "fineart", "cinematic", "photorealistic", "soft light", "sharp focus", "bokeh", "dreamlike", "semirealism",
|
29 |
+
"colorful", "black and white", "intricate", "elegant"
|
30 |
+
]
|
31 |
+
}
|
cli/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohttp
|
2 |
+
mediapipe
|
3 |
+
extcolors
|
4 |
+
colormap
|
5 |
+
filetype
|
6 |
+
albumentations
|
7 |
+
matplotlib
|
cli/run-benchmark.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
sd api txt2img benchmark
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
import asyncio
|
7 |
+
import base64
|
8 |
+
import io
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
import argparse
|
12 |
+
from PIL import Image
|
13 |
+
import sdapi
|
14 |
+
from util import Map, log
|
15 |
+
|
16 |
+
|
17 |
+
oom = 0
|
18 |
+
args = None
|
19 |
+
options = None
|
20 |
+
|
21 |
+
|
22 |
+
async def txt2img():
    # Run one txt2img request with the module-level `options` payload and
    # return the wall-clock duration in seconds (-1 on API error, 0 on a
    # malformed/short response).
    t0 = time.perf_counter()
    data = {}
    try:
        data = await sdapi.post('/sdapi/v1/txt2img', options)
    except Exception:
        return -1
    if 'error' in data:
        return -1
    if 'info' in data:
        info = Map(json.loads(data['info']))
    else:
        return 0
    log.debug({ 'info': info })
    if options['batch_size'] != len(data['images']):
        log.error({ 'requested': options['batch_size'], 'received': len(data['images']) })
        return 0
    for i in range(len(data['images'])):
        # NOTE(review): split(',', 1)[0] keeps the part *before* the first comma;
        # for 'data:...;base64,xxx' URIs that would be the header, not the payload —
        # confirm the server returns plain base64 without a data-URI prefix
        data['images'][i] = Image.open(io.BytesIO(base64.b64decode(data['images'][i].split(',',1)[0])))
        if args.save:
            fn = os.path.join(args.save, f'benchmark-{i}-{len(data["images"])}.png')
            data["images"][i].save(fn)
            log.debug({ 'save': fn })
    log.debug({ "images": data["images"] })
    t1 = time.perf_counter()
    return t1 - t0
|
48 |
+
|
49 |
+
|
50 |
+
def memstats():
    # Fetch server memory stats; returns (cpu, gpu) and updates the global
    # `oom` counter from the reported CUDA OOM event count.
    mem = sdapi.getsync('/sdapi/v1/memory')
    cpu = mem.get('ram', 'unavailable')
    gpu = mem.get('cuda', 'unavailable')
    # NOTE(review): when stats are missing, gpu is the string 'unavailable' and
    # the `in` checks below degrade to substring tests (all False) — works, but fragile
    if 'active' in gpu:
        gpu['session'] = gpu.pop('active') # rename for reporting
    if 'reserved' in gpu:
        # drop allocator detail, keep only summary fields
        gpu.pop('allocated')
        gpu.pop('reserved')
        gpu.pop('inactive')
    if 'events' in gpu:
        global oom # pylint: disable=global-statement
        oom = gpu['events']['oom']
        gpu.pop('events')
    return cpu, gpu
|
65 |
+
|
66 |
+
|
67 |
+
def gb(val: float):
    """Convert a byte count to gigabytes, rounded to two decimals."""
    gib = 1024 * 1024 * 1024
    return round(val / gib, 2)
|
69 |
+
|
70 |
+
|
71 |
+
async def main():
    # Benchmark txt2img throughput over increasing batch sizes, logging
    # it/s, per-image and wall time plus GPU memory after each run.
    sdapi.quiet = True
    await sdapi.session()
    await sdapi.interrupt() # cancel any job already running on the server
    ver = await sdapi.get("/sdapi/v1/version")
    log.info({ 'version': ver})
    platform = await sdapi.get("/sdapi/v1/platform")
    log.info({ 'platform': platform })
    opts = await sdapi.get('/sdapi/v1/options')
    opts = Map(opts)
    log.info({ 'model': opts.sd_model_checkpoint })
    cpu, gpu = memstats()
    log.info({ 'system': { 'cpu': cpu, 'gpu': gpu }})
    # batch size 1 appears twice: the first run doubles as warmup
    batch = [1, 1, 2, 4, 8, 12, 16, 24, 32, 48, 64, 96, 128, 192, 256]
    batch = [b for b in batch if b <= args.maxbatch]
    log.info({"batch-sizes": batch})
    for i in range(len(batch)):
        if oom > 0: # once the server reported OOM, skip remaining sizes
            continue
        options['batch_size'] = batch[i]
        warmup = await txt2img()
        ts = await txt2img()
        if i == 0:
            ts += warmup # fold warmup time into the first reported run
        if ts > 0.01: # cannot be faster than 10ms per run
            await asyncio.sleep(0)
            cpu, gpu = memstats()
            if i == 0:
                log.info({ 'warmup': round(ts, 2) })
            else:
                peak = gpu['system']['used'] # gpu['session']['peak'] if 'session' in gpu else 0
                log.info({ 'batch': batch[i], 'its': round(options.steps / (ts / batch[i]), 2), 'img': round(ts / batch[i], 2), 'wall': round(ts, 2), 'peak': gb(peak), 'oom': oom > 0 })
        else:
            # implausibly fast result means the request failed; wait and report
            await asyncio.sleep(10)
            cpu, gpu = memstats()
            log.info({ 'batch': batch[i], 'result': 'error', 'gpu': gpu, 'oom': oom > 0 })
            break
    if oom > 0:
        log.info({ 'benchmark': 'ended with oom so you should probably restart your automatic server now' })
    await sdapi.close()
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == '__main__':
    log.info({ 'run-benchmark' }) # NOTE(review): set literal, not a dict/string — presumably intended as a label; confirm
    parser = argparse.ArgumentParser(description = 'run-benchmark')
    parser.add_argument("--steps", type=int, default=50, required=False, help="steps")
    parser.add_argument("--sampler", type=str, default='Euler a', required=False, help="Use specific sampler")
    parser.add_argument("--prompt", type=str, default='photo of two dice on a table', required=False, help="prompt")
    parser.add_argument("--negative", type=str, default='foggy, blurry', required=False, help="prompt")
    parser.add_argument("--maxbatch", type=int, default=16, required=False, help="max batch size")
    parser.add_argument("--width", type=int, default=512, required=False, help="width")
    parser.add_argument("--height", type=int, default=512, required=False, help="height")
    parser.add_argument('--debug', default = False, action='store_true', help = 'debug logging')
    parser.add_argument('--taesd', default = False, action='store_true', help = 'use taesd as vae')
    parser.add_argument("--save", type=str, default='', required=False, help="save images to folder")
    args = parser.parse_args()
    if args.debug:
        log.setLevel('DEBUG')
    # base txt2img payload; batch_size is overridden per iteration in main()
    options = Map(
        {
            "prompt": args.prompt,
            "negative_prompt": args.negative,
            "steps": args.steps,
            "sampler_name": args.sampler,
            "width": args.width,
            "height": args.height,
            "full_quality": not args.taesd, # taesd flag disables the full-quality decode
            "cfg_scale": 0,
            "batch_size": 1,
            "n_iter": 1,
            "seed": -1,
        }
    )
    log.info({"options": options})
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        log.warning({ 'interrupted': 'keyboard request' })
        sdapi.interruptsync() # also cancel the job on the server side
|
cli/sdapi.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
#pylint: disable=redefined-outer-name
|
3 |
+
"""
|
4 |
+
helper methods that creates HTTP session with managed connection pool
|
5 |
+
provides async HTTP get/post methods and several helper methods
|
6 |
+
"""
|
7 |
+
|
8 |
+
import io
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
import ssl
|
12 |
+
import base64
|
13 |
+
import asyncio
|
14 |
+
import logging
|
15 |
+
import aiohttp
|
16 |
+
import requests
|
17 |
+
import urllib3
|
18 |
+
from PIL import Image
|
19 |
+
from util import Map, log
|
20 |
+
from rich import print # pylint: disable=redefined-builtin
|
21 |
+
|
22 |
+
|
23 |
+
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860") # api url root
|
24 |
+
sd_username = os.environ.get('SDAPI_USR', None)
|
25 |
+
sd_password = os.environ.get('SDAPI_PWD', None)
|
26 |
+
|
27 |
+
use_session = True
|
28 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
29 |
+
ssl.create_default_context = ssl._create_unverified_context # pylint: disable=protected-access
|
30 |
+
timeout = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
|
31 |
+
sess = None
|
32 |
+
quiet = False
|
33 |
+
BaseThreadPolicy = asyncio.WindowsSelectorEventLoopPolicy if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy") else asyncio.DefaultEventLoopPolicy
|
34 |
+
|
35 |
+
|
36 |
+
class AnyThreadEventLoopPolicy(BaseThreadPolicy):
    """Event loop policy that creates a loop on demand for non-main threads.

    get_event_loop() normally raises when called from a thread without a
    current loop; this fallback creates and installs a fresh loop instead.
    """
    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        try:
            return super().get_event_loop()
        except (RuntimeError, AssertionError):
            loop = self.new_event_loop()
            self.set_event_loop(loop)
            return loop
|
44 |
+
|
45 |
+
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
46 |
+
|
47 |
+
|
48 |
+
def authsync():
    """Return HTTP basic-auth for `requests`, or None when credentials are unset."""
    have_creds = sd_username is not None and sd_password is not None
    return requests.auth.HTTPBasicAuth(sd_username, sd_password) if have_creds else None
|
52 |
+
|
53 |
+
|
54 |
+
def auth():
    """Return basic-auth for aiohttp, or None when credentials are unset."""
    have_creds = sd_username is not None and sd_password is not None
    return aiohttp.BasicAuth(sd_username, sd_password) if have_creds else None
|
58 |
+
|
59 |
+
|
60 |
+
async def result(req):
    # Convert an aiohttp response into a Map / list / dict; logs and returns
    # an error Map on non-200 (closing the session in one-shot mode).
    if req.status != 200:
        if not quiet:
            log.error({ 'request error': req.status, 'reason': req.reason, 'url': req.url })
        if not use_session and sess is not None:
            await sess.close()
        return Map({ 'error': req.status, 'reason': req.reason, 'url': req.url })
    else:
        json = await req.json()
        if isinstance(json, list):
            res = json # lists pass through unchanged
        elif json is None:
            res = {}
        else:
            res = Map(json) # dicts get attribute-style access
        log.debug({ 'request': req.status, 'url': req.url, 'reason': req.reason })
        return res
|
77 |
+
|
78 |
+
|
79 |
+
def resultsync(req: requests.Response):
    """Convert a `requests` response into a Map / list / dict; error Map on non-200."""
    if req.status_code != 200:
        if not quiet:
            log.error({ 'request error': req.status_code, 'reason': req.reason, 'url': req.url })
        return Map({ 'error': req.status_code, 'reason': req.reason, 'url': req.url })
    payload = req.json()
    if isinstance(payload, list):
        converted = payload # lists pass through unchanged
    elif payload is None:
        converted = {}
    else:
        converted = Map(payload) # dicts get attribute-style access
    log.debug({ 'request': req.status_code, 'url': req.url, 'reason': req.reason })
    return converted
|
94 |
+
|
95 |
+
|
96 |
+
async def get(endpoint: str, json: dict = None):
    # Async GET via the shared session (lazily created); returns {} on error.
    global sess # pylint: disable=global-statement
    sess = sess if sess is not None else await session()
    try:
        async with sess.get(url=endpoint, json=json, verify_ssl=False) as req:
            res = await result(req)
            return res
    except Exception as err:
        log.error({ 'session': err })
        return {}
|
106 |
+
|
107 |
+
|
108 |
+
def getsync(endpoint: str, json: dict = None):
    """Synchronous GET against the API root; returns {} on connection error."""
    try:
        response = requests.get(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
        return resultsync(response)
    except Exception as err:
        log.error({ 'session': err })
        return {}
|
116 |
+
|
117 |
+
|
118 |
+
async def post(endpoint: str, json: dict = None):
    # Async POST; deliberately closes and recreates the session on every call
    # (the commented line shows the previous reuse behavior) — presumably to
    # avoid stale connections on long-running requests; confirm before reverting.
    global sess # pylint: disable=global-statement
    # sess = sess if sess is not None else await session()
    if sess and not sess.closed:
        await sess.close()
    sess = await session()
    try:
        async with sess.post(url=endpoint, json=json, verify_ssl=False) as req:
            res = await result(req)
            return res
    except Exception as err:
        log.error({ 'session': err })
        return {}
|
131 |
+
|
132 |
+
|
133 |
+
def postsync(endpoint: str, json: dict = None):
    """Synchronous POST against the API root."""
    response = requests.post(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
    return resultsync(response)
|
137 |
+
|
138 |
+
|
139 |
+
async def interrupt():
    # Interrupt the server's current job if one is running; returns the API
    # response, or an idle marker when nothing is in progress.
    res = await get('/sdapi/v1/progress?skip_current_image=true')
    if 'state' in res and res.state.job_count > 0:
        log.debug({ 'interrupt': res.state })
        res = await post('/sdapi/v1/interrupt')
        await asyncio.sleep(1) # give the server a moment to settle
        return res
    else:
        log.debug({ 'interrupt': 'idle' })
        return { 'interrupt': 'idle' }
|
149 |
+
|
150 |
+
|
151 |
+
def interruptsync():
    """Interrupt the server's current job if one is running (sync variant)."""
    res = getsync('/sdapi/v1/progress?skip_current_image=true')
    if 'state' in res and res.state.job_count > 0:
        log.debug({ 'interrupt': res.state })
        return postsync('/sdapi/v1/interrupt')
    log.debug({ 'interrupt': 'idle' })
    return { 'interrupt': 'idle' }
|
160 |
+
|
161 |
+
|
162 |
+
async def progress():
    # Fetch generation progress including the preview image; the base64
    # preview is decoded into a PIL Image when present.
    res = await get('/sdapi/v1/progress?skip_current_image=false')
    try:
        if res is not None and res.get('current_image', None) is not None:
            res.current_image = Image.open(io.BytesIO(base64.b64decode(res['current_image'])))
    except Exception:
        pass # preview decode is best-effort; keep the raw payload on failure
    log.debug({ 'progress': res })
    return res
|
171 |
+
|
172 |
+
|
173 |
+
def progresssync():
    """Fetch generation progress without the preview image (sync variant)."""
    state = getsync('/sdapi/v1/progress?skip_current_image=true')
    log.debug({ 'progress': state })
    return state
|
177 |
+
|
178 |
+
|
179 |
+
def get_log():
    """Fetch the server log and echo each entry at debug level."""
    entries = getsync('/sdapi/v1/log')
    for entry in entries:
        log.debug(entry)
    return entries
|
184 |
+
|
185 |
+
|
186 |
+
def get_info():
    """Fetch full system-info status and print it with the request duration in ms."""
    import time
    started = time.time()
    res = getsync('/sdapi/v1/system-info/status?full=true&refresh=true')
    finished = time.time()
    print({ 'duration': 1000 * round(finished - started, 3), **res })
    return res
|
193 |
+
|
194 |
+
|
195 |
+
def options():
    """Return server options and command-line flags combined into one dict."""
    opts = getsync('/sdapi/v1/options')
    flags = getsync('/sdapi/v1/cmd-flags')
    result = { 'options': opts, 'flags': flags }
    return result
|
199 |
+
|
200 |
+
|
201 |
+
def shutdown():
    """Ask the server to shut down; the connection is expected to drop, so failures are only logged."""
    try:
        postsync('/sdapi/v1/shutdown')
    except Exception as err:
        log.info({ 'shutdown': err })
|
206 |
+
|
207 |
+
|
208 |
+
async def session():
    """Create the shared aiohttp client session (cached in module-global `sess`) for sd_url."""
    global sess # pylint: disable=global-statement
    # sock_read is disabled because training calls can run far longer than the 5-minute default
    time = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
    sess = aiohttp.ClientSession(timeout = time, base_url = sd_url, auth=auth())
    log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
    # NOTE(review): the triple-quoted block below is dead code kept for reference — it is never executed
    """
    sess = await aiohttp.ClientSession(timeout = timeout).__aenter__()
    try:
        async with sess.get(url = f'{sd_url}/') as req:
            log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
    except Exception as e:
        log.error({ 'sdapi': e })
        await asyncio.sleep(0)
        await sess.__aexit__(None, None, None)
        sess = None
    return sess
    """
    return sess
|
226 |
+
|
227 |
+
|
228 |
+
async def close():
    """Close the cached aiohttp session, if one was created by session()."""
    if sess is not None:
        await asyncio.sleep(0)
        await sess.close()
        # NOTE(review): __aexit__ after close() is redundant (ClientSession.__aexit__ just closes) — confirm and simplify
        await sess.__aexit__(None, None, None)
        log.debug({ 'sdapi': 'session closed', 'endpoint': sd_url })
|
234 |
+
|
235 |
+
|
236 |
+
if __name__ == "__main__":
    # CLI dispatcher: `sdapi.py <command>` where command is one of the names checked below,
    # or any raw endpoint path which is issued as a GET
    sys.argv.pop(0)  # drop the script name so the checks below see only user arguments
    log.setLevel(logging.DEBUG)
    if 'interrupt' in sys.argv:
        asyncio.run(interrupt())
    elif 'progress' in sys.argv:
        asyncio.run(progress())
    elif 'progresssync' in sys.argv:
        progresssync()
    elif 'options' in sys.argv:
        opt = options()
        log.debug({ 'options' })
        import json
        print(json.dumps(opt['options'], indent = 2))
        log.debug({ 'cmd-flags' })
        print(json.dumps(opt['flags'], indent = 2))
    elif 'log' in sys.argv:
        get_log()
    elif 'info' in sys.argv:
        get_info()
    elif 'shutdown' in sys.argv:
        shutdown()
    else:
        # NOTE(review): raises IndexError when invoked with no arguments at all — confirm intended
        res = getsync(sys.argv[0])
        print(res)
    asyncio.run(close(), debug=True)
    asyncio.run(asyncio.sleep(0.5))  # give aiohttp time to release underlying connections before exit
|
cli/simple-img2img.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import io
|
4 |
+
import time
|
5 |
+
import base64
|
6 |
+
import logging
|
7 |
+
import argparse
|
8 |
+
import requests
|
9 |
+
import urllib3
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
|
13 |
+
sd_username = os.environ.get('SDAPI_USR', None)
|
14 |
+
sd_password = os.environ.get('SDAPI_PWD', None)
|
15 |
+
|
16 |
+
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
|
17 |
+
log = logging.getLogger(__name__)
|
18 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
19 |
+
|
20 |
+
options = {
|
21 |
+
"save_images": False,
|
22 |
+
"send_images": True,
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def auth():
    """Return HTTP basic-auth credentials when both env vars are set, otherwise None."""
    if sd_username is None or sd_password is None:
        return None
    return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
30 |
+
|
31 |
+
|
32 |
+
def post(endpoint: str, dct: dict = None):
    """POST a JSON payload to the server; returns parsed JSON, or an error descriptor on non-200."""
    resp = requests.post(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
    if resp.status_code == 200:
        return resp.json()
    return { 'error': resp.status_code, 'reason': resp.reason, 'url': resp.url }
|
38 |
+
|
39 |
+
|
40 |
+
def encode(f):
    """Load an image file and return its JPEG bytes base64-encoded as a str.

    Generalized: any non-RGB mode (RGBA, P, LA, ...) is converted to RGB first,
    since JPEG cannot encode alpha channels or palette images — previously only
    RGBA was handled and e.g. palette PNGs made Image.save raise.
    """
    image = Image.open(f)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    with io.BytesIO() as stream:
        image.save(stream, 'JPEG')
        image.close()
        values = stream.getvalue()
        encoded = base64.b64encode(values).decode()
        return encoded
|
50 |
+
|
51 |
+
|
52 |
+
def generate(args): # pylint: disable=redefined-outer-name
    """Run an img2img request against the server and optionally save the returned image(s)."""
    t0 = time.time()
    if args.model is not None:
        post('/sdapi/v1/options', { 'sd_model_checkpoint': args.model })
        post('/sdapi/v1/reload-checkpoint') # needed if running in api-only to trigger new model load
    options['prompt'] = args.prompt
    options['negative_prompt'] = args.negative
    options['steps'] = int(args.steps)
    options['seed'] = int(args.seed)
    options['sampler_name'] = args.sampler
    options['init_images'] = [encode(args.init)]
    # keep output dimensions identical to the init image
    image = Image.open(args.init)
    options['width'] = image.width
    options['height'] = image.height
    image.close()
    if args.mask is not None:
        options['mask'] = encode(args.mask)
    data = post('/sdapi/v1/img2img', options)
    t1 = time.time()
    if 'images' in data:
        for i in range(len(data['images'])):
            # bugfix: take the part AFTER the comma — a 'data:image/...;base64,' prefix must be
            # stripped; [-1] also handles plain base64 with no comma (old [0] kept the prefix)
            b64 = data['images'][i].split(',', 1)[-1]
            info = data['info']
            image = Image.open(io.BytesIO(base64.b64decode(b64)))
            log.info(f'received image: size={image.size} time={t1-t0:.2f} info="{info}"')
            if args.output:
                # NOTE: multiple returned images overwrite the same output file
                image.save(args.output)
                log.info(f'image saved: size={image.size} filename={args.output}')
    else:
        log.warning(f'no images received: {data}')
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
    # command-line entry point: parse arguments and run a single img2img request
    parser = argparse.ArgumentParser(description = 'simple-img2img')
    parser.add_argument('--init', required=True, help='init image')
    parser.add_argument('--mask', required=False, help='mask image')
    parser.add_argument('--prompt', required=False, default='', help='prompt text')
    parser.add_argument('--negative', required=False, default='', help='negative prompt text')
    parser.add_argument('--steps', required=False, default=20, help='number of steps')
    parser.add_argument('--seed', required=False, default=-1, help='initial seed')
    parser.add_argument('--sampler', required=False, default='Euler a', help='sampler name')
    parser.add_argument('--output', required=False, default=None, help='output image file')
    parser.add_argument('--model', required=False, help='model name')
    args = parser.parse_args()
    log.info(f'img2img: {args}')
    generate(args)
|
cli/simple-info.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import base64
|
5 |
+
import logging
|
6 |
+
import argparse
|
7 |
+
import requests
|
8 |
+
import urllib3
|
9 |
+
|
10 |
+
|
11 |
+
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
|
12 |
+
sd_username = os.environ.get('SDAPI_USR', None)
|
13 |
+
sd_password = os.environ.get('SDAPI_PWD', None)
|
14 |
+
|
15 |
+
|
16 |
+
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
|
17 |
+
log = logging.getLogger(__name__)
|
18 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
19 |
+
|
20 |
+
|
21 |
+
def auth():
|
22 |
+
if sd_username is not None and sd_password is not None:
|
23 |
+
return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
24 |
+
return None
|
25 |
+
|
26 |
+
|
27 |
+
def get(endpoint: str, dct: dict = None):
    """GET an endpoint with an optional JSON body; returns parsed JSON, or an error descriptor on non-200."""
    resp = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
    if resp.status_code == 200:
        return resp.json()
    return { 'error': resp.status_code, 'reason': resp.reason, 'url': resp.url }
|
33 |
+
|
34 |
+
|
35 |
+
def post(endpoint: str, dct: dict = None):
    """POST a JSON payload to the server; returns parsed JSON, or an error descriptor on non-200."""
    resp = requests.post(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
    if resp.status_code == 200:
        return resp.json()
    return { 'error': resp.status_code, 'reason': resp.reason, 'url': resp.url }
|
41 |
+
|
42 |
+
|
43 |
+
def info(args): # pylint: disable=redefined-outer-name
    """Send an image to /sdapi/v1/png-info, log the parsed metadata, and return it.

    Improvement: previously returned None; returning the payload makes the helper
    reusable from other scripts without changing the CLI behavior.
    """
    t0 = time.time()
    with open(args.input, 'rb') as f:
        content = f.read()
    data = post('/sdapi/v1/png-info', { 'image': base64.b64encode(content).decode() })
    t1 = time.time()
    log.info(f'received: {data} time={t1-t0:.2f}')
    return data
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
parser = argparse.ArgumentParser(description = 'simple-info')
|
54 |
+
parser.add_argument('--input', required=True, help='input image')
|
55 |
+
args = parser.parse_args()
|
56 |
+
log.info(f'info: {args}')
|
57 |
+
info(args)
|
cli/simple-mask.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import base64
|
6 |
+
import logging
|
7 |
+
import argparse
|
8 |
+
import requests
|
9 |
+
import urllib3
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
|
13 |
+
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
|
14 |
+
sd_username = os.environ.get('SDAPI_USR', None)
|
15 |
+
sd_password = os.environ.get('SDAPI_PWD', None)
|
16 |
+
|
17 |
+
|
18 |
+
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
|
19 |
+
log = logging.getLogger(__name__)
|
20 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
21 |
+
|
22 |
+
|
23 |
+
def auth():
|
24 |
+
if sd_username is not None and sd_password is not None:
|
25 |
+
return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
26 |
+
return None
|
27 |
+
|
28 |
+
|
29 |
+
def get(endpoint: str, dct: dict = None):
|
30 |
+
req = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
|
31 |
+
if req.status_code != 200:
|
32 |
+
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
33 |
+
else:
|
34 |
+
return req.json()
|
35 |
+
|
36 |
+
|
37 |
+
def post(endpoint: str, dct: dict = None):
|
38 |
+
req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
|
39 |
+
if req.status_code != 200:
|
40 |
+
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
41 |
+
else:
|
42 |
+
return req.json()
|
43 |
+
|
44 |
+
|
45 |
+
def info(args): # pylint: disable=redefined-outer-name
    """Request a mask from /sdapi/v1/mask for the input image and optionally save it."""
    t0 = time.time()
    with open(args.input, 'rb') as f:
        image = base64.b64encode(f.read()).decode()
    if args.mask:
        with open(args.mask, 'rb') as f:
            mask = base64.b64encode(f.read()).decode()
    else:
        mask = None
    options = get('/sdapi/v1/masking')
    log.info(f'options: {options}')
    req = {
        'image': image,
        'mask': mask,
        'type': args.type or 'Composite',
        # when no explicit mask is supplied, ask the server to auto-generate a grayscale one
        'params': { 'auto_mask': 'Grayscale' if mask is None else None },
    }
    data = post('/sdapi/v1/mask', req)
    t1 = time.time()
    if 'mask' in data:
        # bugfix: strip an optional 'data:...;base64,' prefix — old [0] kept the prefix instead of the payload
        b64 = data['mask'].split(',', 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64)))
        log.info(f'received image: size={image.size} time={t1-t0:.2f}')
        if args.output:
            image.save(args.output)
            log.info(f'saved image: fn={args.output}')
    else:
        log.info(f'received: {data} time={t1-t0:.2f}')
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
parser = argparse.ArgumentParser(description = 'simple-info')
|
77 |
+
parser.add_argument('--input', required=True, help='input image')
|
78 |
+
parser.add_argument('--mask', required=False, help='input mask')
|
79 |
+
parser.add_argument('--type', required=False, help='output mask type')
|
80 |
+
parser.add_argument('--output', required=False, help='output image')
|
81 |
+
args = parser.parse_args()
|
82 |
+
log.info(f'info: {args}')
|
83 |
+
info(args)
|
cli/simple-preprocess.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import base64
|
6 |
+
import logging
|
7 |
+
import argparse
|
8 |
+
import requests
|
9 |
+
import urllib3
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
|
13 |
+
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
|
14 |
+
sd_username = os.environ.get('SDAPI_USR', None)
|
15 |
+
sd_password = os.environ.get('SDAPI_PWD', None)
|
16 |
+
|
17 |
+
|
18 |
+
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
|
19 |
+
log = logging.getLogger(__name__)
|
20 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
21 |
+
|
22 |
+
|
23 |
+
def auth():
|
24 |
+
if sd_username is not None and sd_password is not None:
|
25 |
+
return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
26 |
+
return None
|
27 |
+
|
28 |
+
|
29 |
+
def get(endpoint: str, dct: dict = None):
|
30 |
+
req = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
|
31 |
+
if req.status_code != 200:
|
32 |
+
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
33 |
+
else:
|
34 |
+
return req.json()
|
35 |
+
|
36 |
+
|
37 |
+
def post(endpoint: str, dct: dict = None):
|
38 |
+
req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
|
39 |
+
if req.status_code != 200:
|
40 |
+
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
41 |
+
else:
|
42 |
+
return req.json()
|
43 |
+
|
44 |
+
|
45 |
+
def info(args): # pylint: disable=redefined-outer-name
    """Run a preprocessor model on the input image via /sdapi/v1/preprocess and optionally save the result."""
    t0 = time.time()
    with open(args.input, 'rb') as f:
        content = f.read()
    models = get('/sdapi/v1/preprocessors')
    log.info(f'models: {models}')
    req = {
        'model': args.model or 'Canny',
        'image': base64.b64encode(content).decode(),
        'config': { 'low_threshold': 50 },
    }
    data = post('/sdapi/v1/preprocess', req)
    t1 = time.time()
    if 'image' in data:
        # bugfix: strip an optional 'data:...;base64,' prefix — old [0] kept the prefix instead of the payload
        b64 = data['image'].split(',', 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64)))
        log.info(f'received image: size={image.size} time={t1-t0:.2f}')
        if args.output:
            image.save(args.output)
            log.info(f'saved image: fn={args.output}')
    else:
        log.info(f'received: {data} time={t1-t0:.2f}')
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
parser = argparse.ArgumentParser(description = 'simple-info')
|
71 |
+
parser.add_argument('--input', required=True, help='input image')
|
72 |
+
parser.add_argument('--model', required=True, help='preprocessing model')
|
73 |
+
parser.add_argument('--output', required=False, help='output image')
|
74 |
+
args = parser.parse_args()
|
75 |
+
log.info(f'info: {args}')
|
76 |
+
info(args)
|
cli/simple-txt2img.js
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env node
|
2 |
+
|
3 |
+
// simple nodejs script to test sdnext api
|
4 |
+
|
5 |
+
const fs = require('fs'); // eslint-disable-line no-undef
|
6 |
+
const process = require('process'); // eslint-disable-line no-undef
|
7 |
+
|
8 |
+
const sd_url = process.env.SDAPI_URL || 'http://127.0.0.1:7860';
|
9 |
+
const sd_username = process.env.SDAPI_USR;
|
10 |
+
const sd_password = process.env.SDAPI_PWD;
|
11 |
+
const sd_options = {
|
12 |
+
// first pass
|
13 |
+
prompt: 'city at night',
|
14 |
+
negative_prompt: 'foggy, blurry',
|
15 |
+
sampler_name: 'UniPC',
|
16 |
+
seed: -1,
|
17 |
+
steps: 20,
|
18 |
+
batch_size: 1,
|
19 |
+
n_iter: 1,
|
20 |
+
cfg_scale: 6,
|
21 |
+
width: 512,
|
22 |
+
height: 512,
|
23 |
+
// enable second pass
|
24 |
+
enable_hr: true,
|
25 |
+
// second pass: upscale
|
26 |
+
hr_upscaler: 'SCUNet GAN',
|
27 |
+
hr_scale: 2.0,
|
28 |
+
// second pass: hires
|
29 |
+
hr_force: true,
|
30 |
+
hr_second_pass_steps: 20,
|
31 |
+
hr_sampler_name: 'UniPC',
|
32 |
+
denoising_strength: 0.5,
|
33 |
+
// second pass: refiner
|
34 |
+
refiner_steps: 5,
|
35 |
+
refiner_start: 0.8,
|
36 |
+
refiner_prompt: '',
|
37 |
+
refiner_negative: '',
|
38 |
+
// api return options
|
39 |
+
save_images: false,
|
40 |
+
send_images: true,
|
41 |
+
};
|
42 |
+
|
43 |
+
// run a single txt2img request and save the returned images to /tmp
async function main() {
  const method = 'POST';
  const headers = new Headers();
  const body = JSON.stringify(sd_options);
  headers.set('Content-Type', 'application/json');
  // bugfix: Headers.set takes (name, value) — the old object argument was silently ignored,
  // so the Authorization header was never sent; also the credentials were the literal
  // string 'sd_username:sd_password' instead of the variable values
  if (sd_username && sd_password) headers.set('Authorization', `Basic ${btoa(`${sd_username}:${sd_password}`)}`);
  const res = await fetch(`${sd_url}/sdapi/v1/txt2img`, { method, headers, body });
  if (res.status !== 200) {
    console.log('Error', res.status);
  } else {
    const json = await res.json();
    console.log('result:', json.info);
    for (const i in json.images) { // eslint-disable-line guard-for-in
      const f = `/tmp/test-${i}.jpg`; // bugfix: stray '{' produced filenames like 'test-{0.jpg'
      fs.writeFileSync(f, atob(json.images[i]), 'binary');
      console.log('image saved:', f);
    }
  }
}
|
62 |
+
|
63 |
+
main();
|
cli/simple-txt2img.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import base64
|
6 |
+
import logging
|
7 |
+
import argparse
|
8 |
+
import requests
|
9 |
+
import urllib3
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
|
13 |
+
sd_username = os.environ.get('SDAPI_USR', None)
|
14 |
+
sd_password = os.environ.get('SDAPI_PWD', None)
|
15 |
+
|
16 |
+
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
|
17 |
+
log = logging.getLogger(__name__)
|
18 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
19 |
+
|
20 |
+
options = {
|
21 |
+
"save_images": False,
|
22 |
+
"send_images": True,
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def auth():
|
27 |
+
if sd_username is not None and sd_password is not None:
|
28 |
+
return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
29 |
+
return None
|
30 |
+
|
31 |
+
|
32 |
+
def post(endpoint: str, dct: dict = None):
|
33 |
+
req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
|
34 |
+
if req.status_code != 200:
|
35 |
+
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
36 |
+
else:
|
37 |
+
return req.json()
|
38 |
+
|
39 |
+
|
40 |
+
def generate(args): # pylint: disable=redefined-outer-name
    """Run a txt2img request against the server and optionally save the returned image(s)."""
    t0 = time.time()
    if args.model is not None:
        post('/sdapi/v1/options', { 'sd_model_checkpoint': args.model })
        post('/sdapi/v1/reload-checkpoint') # needed if running in api-only to trigger new model load
    options['prompt'] = args.prompt
    options['negative_prompt'] = args.negative
    options['steps'] = int(args.steps)
    options['seed'] = int(args.seed)
    options['sampler_name'] = args.sampler
    options['width'] = int(args.width)
    options['height'] = int(args.height)
    data = post('/sdapi/v1/txt2img', options)
    t1 = time.time()
    if 'images' in data:
        for i in range(len(data['images'])):
            # bugfix: take the part AFTER the comma — strips an optional 'data:...;base64,'
            # prefix; [-1] also handles plain base64 with no comma (old [0] kept the prefix)
            b64 = data['images'][i].split(',', 1)[-1]
            image = Image.open(io.BytesIO(base64.b64decode(b64)))
            info = data['info']
            log.info(f'image received: size={image.size} time={t1-t0:.2f} info="{info}"')
            if args.output:
                # NOTE: multiple returned images overwrite the same output file
                image.save(args.output)
                log.info(f'image saved: size={image.size} filename={args.output}')
    else:
        log.warning(f'no images received: {data}')
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
parser = argparse.ArgumentParser(description = 'simple-txt2img')
|
69 |
+
parser.add_argument('--prompt', required=False, default='', help='prompt text')
|
70 |
+
parser.add_argument('--negative', required=False, default='', help='negative prompt text')
|
71 |
+
parser.add_argument('--width', required=False, default=512, help='image width')
|
72 |
+
parser.add_argument('--height', required=False, default=512, help='image height')
|
73 |
+
parser.add_argument('--steps', required=False, default=20, help='number of steps')
|
74 |
+
parser.add_argument('--seed', required=False, default=-1, help='initial seed')
|
75 |
+
parser.add_argument('--sampler', required=False, default='Euler a', help='sampler name')
|
76 |
+
parser.add_argument('--output', required=False, default=None, help='output image file')
|
77 |
+
parser.add_argument('--model', required=False, help='model name')
|
78 |
+
args = parser.parse_args()
|
79 |
+
log.info(f'txt2img: {args}')
|
80 |
+
generate(args)
|
cli/simple-upscale.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import io
|
4 |
+
import time
|
5 |
+
import base64
|
6 |
+
import logging
|
7 |
+
import argparse
|
8 |
+
import requests
|
9 |
+
import urllib3
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
|
13 |
+
sd_username = os.environ.get('SDAPI_USR', None)
|
14 |
+
sd_password = os.environ.get('SDAPI_PWD', None)
|
15 |
+
|
16 |
+
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
|
17 |
+
log = logging.getLogger(__name__)
|
18 |
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
19 |
+
|
20 |
+
|
21 |
+
def auth():
|
22 |
+
if sd_username is not None and sd_password is not None:
|
23 |
+
return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
24 |
+
return None
|
25 |
+
|
26 |
+
|
27 |
+
def get(endpoint: str, dct: dict = None):
|
28 |
+
req = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
|
29 |
+
if req.status_code != 200:
|
30 |
+
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
31 |
+
else:
|
32 |
+
return req.json()
|
33 |
+
|
34 |
+
|
35 |
+
def post(endpoint: str, dct: dict = None):
|
36 |
+
req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
|
37 |
+
if req.status_code != 200:
|
38 |
+
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
39 |
+
else:
|
40 |
+
return req.json()
|
41 |
+
|
42 |
+
|
43 |
+
def encode(f):
    """Load an image file and return its JPEG bytes base64-encoded as a str.

    Generalized: any non-RGB mode (RGBA, P, LA, ...) is converted to RGB first,
    since JPEG cannot encode alpha channels or palette images — previously only
    RGBA was handled and e.g. palette PNGs made Image.save raise.
    """
    image = Image.open(f)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    log.info(f'encoding image: {image}')
    with io.BytesIO() as stream:
        image.save(stream, 'JPEG')
        image.close()
        values = stream.getvalue()
        encoded = base64.b64encode(values).decode()
        return encoded
|
54 |
+
|
55 |
+
|
56 |
+
def upscale(args): # pylint: disable=redefined-outer-name
    """Upscale the input image via /sdapi/v1/extra-single-image and save the result."""
    t0 = time.time()
    upscalers = get('/sdapi/v1/upscalers')
    upscalers = [u['name'] for u in upscalers]
    log.info(f'upscalers: {upscalers}')
    options = {
        "save_images": False,
        "send_images": True,
        'image': encode(args.input),
        'upscaler_1': args.upscaler,
        'resize_mode': 0, # rescale_by
        # bugfix: argparse delivers CLI values as str; the API expects a number
        'upscaling_resize': float(args.scale),
    }
    data = post('/sdapi/v1/extra-single-image', options)
    t1 = time.time()
    if 'image' in data:
        # bugfix: strip an optional 'data:...;base64,' prefix — old [0] kept the prefix instead of the payload
        b64 = data['image'].split(',', 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64)))
        image.save(args.output)
        log.info(f'received: image={image} file={args.output} time={t1-t0:.2f}')
    else:
        log.warning(f'no images received: {data}')
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
parser = argparse.ArgumentParser(description = 'simple-upscale')
|
84 |
+
parser.add_argument('--input', required=True, help='input image')
|
85 |
+
parser.add_argument('--output', required=True, help='output image')
|
86 |
+
parser.add_argument('--upscaler', required=False, default='Nearest', help='upscaler name')
|
87 |
+
parser.add_argument('--scale', required=False, default=2, help='upscaler scale')
|
88 |
+
args = parser.parse_args()
|
89 |
+
log.info(f'upscale: {args}')
|
90 |
+
upscale(args)
|
cli/torch-compile.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# pylint: disable=cell-var-from-loop
|
3 |
+
"""
|
4 |
+
Test Torch Dynamo functionality and backends
|
5 |
+
"""
|
6 |
+
import json
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torchvision.models import resnet18
|
12 |
+
|
13 |
+
|
14 |
+
print('torch:', torch.__version__)
|
15 |
+
try:
|
16 |
+
# must be imported explicitly or namespace is not found
|
17 |
+
import torch._dynamo as dynamo # pylint: disable=ungrouped-imports
|
18 |
+
except Exception as err:
|
19 |
+
print('torch without dynamo support', err)
|
20 |
+
|
21 |
+
|
22 |
+
N_ITERS = 20
|
23 |
+
torch._dynamo.config.verbose=True # pylint: disable=protected-access
|
24 |
+
warnings.filterwarnings('ignore', category=UserWarning) # disable those for now as many backends reports tons
|
25 |
+
# torch.set_float32_matmul_precision('high') # enable to test in fp32
|
26 |
+
|
27 |
+
|
28 |
+
def timed(fn): # returns the result of running `fn()` and the time it took for `fn()` to run in ms using CUDA events
    """Execute `fn` and return (result, elapsed_ms) measured with CUDA events; requires a CUDA device."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # events are recorded asynchronously on the stream; synchronize before reading elapsed time
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)
|
36 |
+
|
37 |
+
|
38 |
+
def generate_data(b):
    """Return a random GPU batch: (float32 images of shape (b, 3, 128, 128), int labels in [0, 1000))."""
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )
|
43 |
+
|
44 |
+
|
45 |
+
def init_model():
    """Construct a fresh ResNet-18 in float32 on the GPU."""
    return resnet18().to(torch.float32).cuda()
|
47 |
+
|
48 |
+
|
49 |
+
def evaluate(mod, val):
    """Run a single forward pass of `mod` on `val` and return the output."""
    output = mod(val)
    return output
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
    # benchmark harness: measure eager eval, then each available dynamo backend
    # first pass, dynamo is going to be slower as it compiles
    model = init_model()
    inp = generate_data(16)[0]

    # repeat test
    results = {}  # backend name -> median eval time in ms, or 'error'
    times = []
    print('eager initial eval:', timed(lambda: evaluate(model, inp))[1])
    for _i in range(N_ITERS):
        inp = generate_data(16)[0]
        _res, time = timed(lambda: evaluate(model, inp)) # noqa: B023
        times.append(time)
    results['default'] = np.median(times)

    print('dynamo available backends:', dynamo.list_backends())
    for backend in dynamo.list_backends():
        try:
            # required before changing backends
            torch._dynamo.reset() # pylint: disable=protected-access
            eval_dyn = dynamo.optimize(backend)(evaluate)
            # first call triggers compilation, so it is timed separately from the steady-state loop
            print('dynamo initial eval:', backend, timed(lambda: eval_dyn(model, inp))[1]) # noqa: B023
            times = []
            for _i in range(N_ITERS):
                inp = generate_data(16)[0]
                _res, time = timed(lambda: eval_dyn(model, inp)) # noqa: B023
                times.append(time)
            results[backend] = np.median(times)
        except Exception as err:
            lines = str(err).split('\n')
            print('dyanmo backend failed:', backend, lines[0]) # print just first error line as backtraces can be quite long
            results[backend] = 'error'

    # print stats
    print(json.dumps(results, indent = 4))

"""
Reference: <https://github.com/pytorch/pytorch/blob/4f4b62e4a255708e928445b6502139d5962974fa/docs/source/dynamo/get-started.rst>
Training & Inference backends:
 dynamo.optimize("inductor") - Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton kernels
 dynamo.optimize("aot_nvfuser") - nvFuser with AotAutograd
 dynamo.optimize("aot_cudagraphs") - cudagraphs with AotAutograd
Inference-only backends:
 dynamo.optimize("ofi") - Uses Torchscript optimize_for_inference
 dynamo.optimize("fx2trt") - Uses Nvidia TensorRT for inference optimizations
 dynamo.optimize("onnxrt") - Uses ONNXRT for inference on CPU/GPU
"""
|
cli/train.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
"""
|
4 |
+
Examples:
|
5 |
+
- sd15: train.py --type lora --tag girl --comments sdnext --input ~/generative/Input/mia --process original,interrogate,resize --name mia
|
6 |
+
- sdxl: train.py --type lora --tag girl --comments sdnext --input ~/generative/Input/mia --process original,interrogate,resize --precision fp32 --optimizer Adafactor --sdxl --name miaxl
|
7 |
+
- offline: train.py --type lora --tag girl --comments sdnext --input ~/generative/Input/mia --model /home/vlado/dev/sdnext/models/Stable-diffusion/sdxl/miaanimeSFWNSFWSDXL_v40.safetensors --dir /home/vlado/dev/sdnext/models/Lora/ --precision fp32 --optimizer Adafactor --sdxl --name miaxl
|
8 |
+
"""
|
9 |
+
|
10 |
+
# system imports
|
11 |
+
import os
|
12 |
+
import re
|
13 |
+
import gc
|
14 |
+
import sys
|
15 |
+
import json
|
16 |
+
import shutil
|
17 |
+
import pathlib
|
18 |
+
import asyncio
|
19 |
+
import logging
|
20 |
+
import tempfile
|
21 |
+
import argparse
|
22 |
+
|
23 |
+
# local imports
|
24 |
+
import util
|
25 |
+
import sdapi
|
26 |
+
import options
|
27 |
+
|
28 |
+
|
29 |
+
# globals
|
30 |
+
args = None
|
31 |
+
log = logging.getLogger('train')
|
32 |
+
valid_steps = ['original', 'face', 'body', 'blur', 'range', 'upscale', 'restore', 'interrogate', 'resize', 'square', 'segment']
|
33 |
+
log_file = os.path.join(os.path.dirname(__file__), 'train.log')
|
34 |
+
server_ok = False
|
35 |
+
|
36 |
+
# methods
|
37 |
+
|
38 |
+
def setup_logging():
    """Configure logging: rich console handler at INFO/DEBUG plus always-DEBUG file log.

    Uses module-level `args` (for --debug) and `log_file`; replaces any
    handlers already attached to the module logger.
    """
    from rich.theme import Theme
    from rich.logging import RichHandler
    from rich.console import Console
    from rich.pretty import install as pretty_install
    from rich.traceback import install as traceback_install
    # dark theme for tracebacks/inspection borders
    console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
        "traceback.border": "black",
        "traceback.border.syntax_error": "black",
        "inspect.value.border": "black",
    }))
    # logging.getLogger("urllib3").setLevel(logging.ERROR)
    # logging.getLogger("httpx").setLevel(logging.ERROR)
    level = logging.DEBUG if args.debug else logging.INFO
    # root logger writes everything to the log file; force=True replaces prior basicConfig
    logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', filename=log_file, filemode='a', encoding='utf-8', force=True)
    log.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
    pretty_install(console=console)
    traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
    rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
    rh.set_name(level)
    # drop any pre-existing handlers so console output is not duplicated
    while log.hasHandlers() and len(log.handlers) > 0:
        log.removeHandler(log.handlers[0])
    log.addHandler(rh)
|
61 |
+
|
62 |
+
|
63 |
+
def mem_stats():
    """Collect garbage, flush CUDA caches, and log CPU/GPU memory usage.

    Safe on CPU-only hosts: GPU peak stats are reported only when
    `util.get_memory()` was able to collect them.
    """
    gc.collect()
    import torch
    if torch.cuda.is_available():
        with torch.no_grad():
            torch.cuda.empty_cache()
        with torch.cuda.device('cuda'):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    mem = util.get_memory()
    # get_memory() only adds gpu-* keys when CUDA stats were collected successfully;
    # indexing unconditionally raised KeyError on CPU-only hosts
    if 'gpu-active' in mem:
        peak = { 'active': mem['gpu-active']['peak'], 'allocated': mem['gpu-allocated']['peak'], 'reserved': mem['gpu-reserved']['peak'] }
    else:
        peak = {}
    log.debug(f"memory cpu: {mem.ram} gpu current: {mem.gpu} gpu peak: {peak}")
|
75 |
+
|
76 |
+
|
77 |
+
def parse_args():
    """Parse command-line arguments into the module-level `args` namespace."""
    global args # pylint: disable=global-statement
    parser = argparse.ArgumentParser(description = 'SD.Next Train')

    group_server = parser.add_argument_group('Server')
    group_server.add_argument('--server', type=str, default='http://127.0.0.1:7860', required=False, help='server url, default: %(default)s')
    # help texts below were copy-pasted as 'server url' in the original
    group_server.add_argument('--user', type=str, default=None, required=False, help='api username, default: %(default)s')
    group_server.add_argument('--password', type=str, default=None, required=False, help='api password, default: %(default)s')
    group_server.add_argument('--dir', type=str, default=None, required=False, help='folder with trained networks, default: use server setting')

    group_main = parser.add_argument_group('Main')
    group_main.add_argument('--type', type=str, choices=['embedding', 'ti', 'lora', 'lyco', 'dreambooth', 'hypernetwork'], default=None, required=True, help='training type')
    group_main.add_argument('--model', type=str, default='', required=False, help='base model to use for training, default: current loaded model')
    group_main.add_argument('--name', type=str, default=None, required=True, help='output filename')
    group_main.add_argument('--tag', type=str, default='person', required=False, help='primary tags, default: %(default)s')
    group_main.add_argument('--comments', type=str, default='', required=False, help='comments to be added to trained model metadata, default: %(default)s')

    group_data = parser.add_argument_group('Dataset')
    group_data.add_argument('--input', type=str, default=None, required=True, help='input folder with training images')
    group_data.add_argument('--interim', type=str, default='', required=False, help='where to store processed images, default is system temp/train')
    group_data.add_argument('--process', type=str, default='original,interrogate,resize,square', required=False, help=f'list of possible processing steps: {valid_steps}, default: %(default)s')

    group_train = parser.add_argument_group('Train')
    group_train.add_argument('--gradient', type=int, default=1, required=False, help='gradient accumulation steps, default: %(default)s')
    group_train.add_argument('--steps', type=int, default=2500, required=False, help='training steps, default: %(default)s')
    group_train.add_argument('--batch', type=int, default=1, required=False, help='batch size, default: %(default)s')
    group_train.add_argument('--lr', type=float, default=1e-04, required=False, help='model learning rate, default: %(default)s')
    group_train.add_argument('--dim', type=int, default=32, required=False, help='network dimension or number of vectors, default: %(default)s')

    # lora params
    group_train.add_argument('--repeats', type=int, default=1, required=False, help='number of repeats per image, default: %(default)s')
    group_train.add_argument('--alpha', type=float, default=0, required=False, help='lora/lyco alpha for weights scaling, default: dim/2')
    group_train.add_argument('--algo', type=str, default=None, choices=['locon', 'loha', 'lokr', 'ia3'], required=False, help='alternative lyco algoritm, default: %(default)s')
    group_train.add_argument('--args', type=str, default=None, required=False, help='lora/lyco additional network arguments, default: %(default)s')
    group_train.add_argument('--optimizer', type=str, default='AdamW', required=False, help='optimizer type, default: %(default)s')
    group_train.add_argument('--precision', type=str, choices=['fp16', 'fp32'], default='fp16', required=False, help='training precision, default: %(default)s')
    group_train.add_argument('--sdxl', default = False, action='store_true', help = "run sdxl training, default: %(default)s")
    # AdamW (default), AdamW8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor

    group_other = parser.add_argument_group('Other')
    group_other.add_argument('--overwrite', default = False, action='store_true', help = "overwrite existing training, default: %(default)s")
    group_other.add_argument('--experimental', default = False, action='store_true', help = "enable experimental options, default: %(default)s")
    group_other.add_argument('--debug', default = False, action='store_true', help = "enable debug level logging, default: %(default)s")

    args = parser.parse_args()
|
122 |
+
|
123 |
+
|
124 |
+
def prepare_server():
    """Check server status and push training-related option overrides.

    Sets module-level `server_ok`; exits the process when the server is
    busy with another job.
    """
    global server_ok # pylint: disable=global-statement
    server_state = {}
    try:
        server_status = util.Map(sdapi.progresssync())
        server_state = server_status['state']
        server_ok = True
    except Exception as e:
        # original logged `server_status`, which is unbound when progresssync()
        # itself raises; log the exception instead
        log.warning(f'sdnext server error: {e}')
        server_ok = False
    if server_ok and server_state['job_count'] > 0:
        log.error(f'sdnext server not idle: {server_state}')
        exit(1)
    if server_ok:
        # disable server-side training extras and apply repeat count
        server_options = util.Map(sdapi.options())
        server_options.options.save_training_settings_to_txt = False
        server_options.options.training_enable_tensorboard = False
        server_options.options.training_tensorboard_save_images = False
        server_options.options.pin_memory = True
        server_options.options.save_optimizer_state = False
        server_options.options.training_image_repeats_per_epoch = args.repeats
        server_options.options.training_write_csv_every = 0
        sdapi.postsync('/sdapi/v1/options', server_options.options)
        log.info('updated server options')
|
147 |
+
|
148 |
+
|
149 |
+
def verify_args():
    """Validate and normalize parsed CLI arguments against server options.

    Resolves the base model path and the lora/lyco/embedding output folders,
    picks the interim processing folder, and exits the process on any
    unrecoverable problem.
    """
    server_options = util.Map(sdapi.options())
    if args.model != '':
        if not os.path.isfile(args.model):
            log.error(f'cannot find loaded model: {args.model}')
            exit(1)
        if server_ok:
            # tell the server to load the explicitly requested checkpoint
            server_options.options.sd_model_checkpoint = args.model
            sdapi.postsync('/sdapi/v1/options', server_options.options)
    elif server_ok:
        # no explicit model: use whatever the server has loaded (strip ' [hash]' suffix)
        args.model = server_options.options.sd_model_checkpoint.split(' [')[0]
        if args.sdxl and (server_options.sd_backend != 'diffusers' or server_options.diffusers_pipeline != 'Stable Diffusion XL'):
            log.warning('server checkpoint is not sdxl')
    else:
        log.error('no model specified')
        exit(1)
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # repo root: parent of the cli/ folder
    if args.type == 'lora' and not server_ok and not args.dir:
        log.error('offline lora training requires --dir <lora folder>')
        exit(1)
    if args.type == 'lora':
        import transformers
        # kohya-based lora training is pinned to a specific transformers release
        if transformers.__version__ != '4.30.2':
            log.error(f'lora training requires specific transformers version: current {transformers.__version__} required transformers==4.30.2')
            exit(1)
    # output folders: prefer server settings, fall back to --dir; make absolute relative to repo root
    args.lora_dir = server_options.options.lora_dir or args.dir
    if not os.path.isabs(args.lora_dir):
        args.lora_dir = os.path.join(base_dir, args.lora_dir)
    args.lyco_dir = server_options.options.lyco_dir or args.dir
    if not os.path.isabs(args.lyco_dir):
        args.lyco_dir = os.path.join(base_dir, args.lyco_dir)
    args.embeddings_dir = server_options.options.embeddings_dir or args.dir
    if not os.path.isfile(args.model):
        # model is not a direct path: try resolving inside the server checkpoint
        # folder, first as-is, then with a .safetensors extension
        args.ckpt_dir = server_options.options.ckpt_dir
        if not os.path.isabs(args.ckpt_dir):
            args.ckpt_dir = os.path.join(base_dir, args.ckpt_dir)
        attempt = os.path.abspath(os.path.join(args.ckpt_dir, args.model))
        args.model = attempt if os.path.isfile(attempt) else args.model
        if not os.path.isfile(args.model):
            attempt = os.path.abspath(os.path.join(args.ckpt_dir, args.model + '.safetensors'))
            args.model = attempt if os.path.isfile(attempt) else args.model
        if not os.path.isfile(args.model):
            log.error(f'cannot find loaded model: {args.model}')
            exit(1)
    if not os.path.exists(args.input) or not os.path.isdir(args.input):
        log.error(f'cannot find training folder: {args.input}')
        exit(1)
    if not os.path.exists(args.lora_dir) or not os.path.isdir(args.lora_dir):
        log.error(f'cannot find lora folder: {args.lora_dir}')
        exit(1)
    if not os.path.exists(args.lyco_dir) or not os.path.isdir(args.lyco_dir):
        log.error(f'cannot find lyco folder: {args.lyco_dir}')
        exit(1)
    # interim processing folder: explicit --interim or per-name system temp dir
    if args.interim != '':
        args.process_dir = args.interim
    else:
        args.process_dir = os.path.join(tempfile.gettempdir(), 'train', args.name)
    log.debug(f'args: {vars(args)}')
    log.debug(f'server flags: {server_options.flags}')
    log.debug(f'server options: {server_options.options}')
|
209 |
+
|
210 |
+
|
211 |
+
async def training_loop():
    """Start embedding training via the API and monitor its progress concurrently."""
    async def async_train():
        # fire the (long-running) training request
        res = await sdapi.post('/sdapi/v1/train/embedding', options.embedding)
        log.info(f'train embedding result: {res}')

    async def async_monitor():
        # poll server progress and render a progress bar until the job completes
        from tqdm.rich import tqdm
        await asyncio.sleep(3)
        res = util.Map(sdapi.progress())
        with tqdm(desc='train embedding', total=res.state.job_count) as pbar:
            while res.state.job_no < res.state.job_count and not res.state.interrupted and not res.state.skipped:
                await asyncio.sleep(2)
                prev_job = res.state.job_no
                res = util.Map(sdapi.progress())
                loss = re.search(r"Loss: (.*?)(?=\<)", res.textinfo)
                if loss:
                    # group(1) is the numeric value only; group(0) included the 'Loss: ' prefix
                    pbar.set_postfix({ 'loss': loss.group(1) })
                pbar.update(res.state.job_no - prev_job)

    a = asyncio.create_task(async_train())
    b = asyncio.create_task(async_monitor())
    await asyncio.gather(a, b) # wait for both pipeline and monitor to finish
|
233 |
+
|
234 |
+
|
235 |
+
def train_embedding():
    """Create a new embedding on the server and run the training loop."""
    log.info(f'{args.type} options: {options.embedding}')
    create_options = util.Map({
        "name": args.name,
        "num_vectors_per_token": args.dim,
        "overwrite_old": False,
        "init_text": args.tag,
    })
    fn = os.path.join(args.embeddings_dir, args.name) + '.pt'
    # original logic errored out with 'embedding exists' even when the file did
    # NOT exist; only an existing file without --overwrite should abort
    if os.path.exists(fn):
        if args.overwrite:
            log.warning(f'delete existing embedding {fn}')
            os.remove(fn)
        else:
            log.error(f'embedding exists {fn}')
            return
    log.info(f'create embedding {create_options}')
    res = sdapi.postsync('/sdapi/v1/create/embedding', create_options)
    if 'info' in res and 'error' in res['info']: # formatted error
        log.error(res.info)
    elif 'info' in res: # no error
        asyncio.run(training_loop())
    else: # unknown error
        log.error(f'create embedding error {res}')
|
258 |
+
|
259 |
+
|
260 |
+
def train_lora():
    """Run kohya-based lora/lyco training in-process using prepared `options.lora`."""
    fn = os.path.join(options.lora.output_dir, args.name)
    # refuse to clobber an existing network unless --overwrite was given
    for ext in ['.ckpt', '.pt', '.safetensors']:
        if os.path.exists(fn + ext):
            if args.overwrite:
                log.warning(f'delete existing lora: {fn + ext}')
                os.remove(fn + ext)
            else:
                log.error(f'lora exists: {fn + ext}')
                return
    log.info(f'{args.type} options: {options.lora}')
    # lora imports
    lora_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'modules', 'lora'))
    lycoris_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'modules', 'lycoris'))
    sys.path.append(lora_path)
    if args.type == 'lyco':
        sys.path.append(lycoris_path)
    log.debug('importing lora lib')
    if not args.sdxl:
        import train_network
        trainer = train_network.NetworkTrainer()
        trainer.train(options.lora)
    else:
        import sdxl_train_network
        trainer = sdxl_train_network.SdxlNetworkTrainer()
        trainer.train(options.lora)
    # NOTE(review): this lycoris module import happens after training has already
    # run; looks like it may have been intended to precede trainer.train — confirm
    if args.type == 'lyco':
        log.debug('importing lycoris lib')
        import importlib
        _network_module = importlib.import_module(options.lora.network_module)
|
290 |
+
|
291 |
+
|
292 |
+
def prepare_options():
    """Populate the global `options.lora` / `options.embedding` maps from parsed args."""
    # per-type setup: in_json is only used for lora/lyco caption metadata
    if args.type == 'embedding':
        log.info('train embedding')
        options.lora.in_json = None
    if args.type == 'dreambooth':
        log.info('train using dreambooth style training')
        options.lora.vae_batch_size = args.batch
        options.lora.in_json = None
    if args.type == 'lora':
        log.info('train using lora style training')
        options.lora.output_dir = args.lora_dir
        options.lora.in_json = os.path.join(args.process_dir, args.name + '.json')
    if args.type == 'lyco':
        log.info('train using lycoris network')
        options.lora.output_dir = args.lora_dir
        options.lora.network_module = 'lycoris.kohya'
        options.lora.in_json = os.path.join(args.process_dir, args.name + '.json')
    # lora specific
    options.lora.save_model_as = 'safetensors'
    options.lora.pretrained_model_name_or_path = args.model
    options.lora.output_name = args.name
    options.lora.max_train_steps = args.steps
    options.lora.network_dim = args.dim
    options.lora.network_alpha = args.dim // 2 if args.alpha == 0 else args.alpha  # default alpha is dim/2
    options.lora.network_args = []
    options.lora.training_comment = args.comments
    options.lora.sdpa = True
    options.lora.optimizer_type = args.optimizer
    if args.algo is not None:
        options.lora.network_args.append(f'algo={args.algo}')
    # NOTE(review): args.args is parsed as a plain string (argparse type=str), so this
    # loop iterates individual characters, not comma-separated arguments — confirm intent
    if args.args is not None:
        for net_arg in args.args:
            options.lora.network_args.append(net_arg)
    options.lora.gradient_accumulation_steps = args.gradient
    options.lora.learning_rate = args.lr
    options.lora.train_batch_size = args.batch
    options.lora.train_data_dir = args.process_dir
    options.lora.no_half_vae = args.precision == 'fp16'
    # embedding specific
    options.embedding.embedding_name = args.name
    options.embedding.learn_rate = str(args.lr)
    options.embedding.batch_size = args.batch
    options.embedding.steps = args.steps
    options.embedding.data_root = args.process_dir
    options.embedding.log_directory = os.path.join(args.process_dir, 'log')
    options.embedding.gradient_step = args.gradient
|
338 |
+
|
339 |
+
|
340 |
+
def process_inputs():
    """Pre-process input images into the per-concept training dataset folders.

    Validates the requested processing steps, collects all images under
    --input, runs each enabled step (face/body/original) through the
    `process` helper, writes caption metadata, and creates VAE latents.
    """
    import process
    import filetype
    pathlib.Path(args.process_dir).mkdir(parents=True, exist_ok=True)
    # split on commas and/or spaces; the original also had a dead preceding
    # assignment (`args.process.split(',')`) that was immediately overwritten
    processing_options = [opt.strip() for opt in re.split(',| ', args.process)]
    log.info(f'processing steps: {processing_options}')
    for step in processing_options:
        if step not in valid_steps:
            log.error(f'invalid processing step: {[step]}')
            exit(1)
    # accumulate images from ALL walked directories; the original reassigned
    # `files` per iteration, keeping only the last directory's images
    files = []
    for root, _sub_dirs, folder in os.walk(args.input):
        files += [os.path.join(root, f) for f in folder if filetype.is_image(os.path.join(root, f))]
    log.info(f'processing input images: {len(files)}')
    if os.path.exists(args.process_dir):
        if args.overwrite:
            log.warning(f'removing existing processed folder: {args.process_dir}')
            shutil.rmtree(args.process_dir, ignore_errors=True)
        else:
            log.info(f'processed folder exists: {args.process_dir}')
    steps = [step for step in processing_options if step in ['face', 'body', 'original']]
    process.reset()
    options.process.target_size = 1024 if args.sdxl else 512
    metadata = {}
    for step in steps:
        # build the per-step option subset (renamed comprehension vars to avoid shadowing `step`)
        if step == 'face':
            opts = [s for s in processing_options if s not in ['body', 'original']]
        if step == 'body':
            opts = [s for s in processing_options if s not in ['face', 'original', 'upscale', 'restore']] # body does not perform upscale or restore
        if step == 'original':
            opts = [s for s in processing_options if s not in ['face', 'body', 'upscale', 'restore', 'blur', 'range', 'segment']] # original does not perform most steps
        log.info(f'processing current step: {opts}')
        tag = step
        if tag == 'original' and args.tag is not None:
            concept = args.tag.split(',')[0].strip()
        else:
            concept = step
        if args.type in ['lora', 'lyco', 'dreambooth']:
            folder = os.path.join(args.process_dir, str(args.repeats) + '_' + concept) # separate concepts per folder
        if args.type in ['embedding']:
            folder = os.path.join(args.process_dir) # everything into same folder
        log.info(f'processing concept: {concept}')
        log.info(f'processing output folder: {folder}')
        pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
        results = {}
        if server_ok:
            for f in files:
                res = process.file(filename = f, folder = folder, tag = args.tag, requested = opts)
                if res.image: # valid result
                    results[res.type] = results.get(res.type, 0) + 1
                    results['total'] = results.get('total', 0) + 1
                    rel_path = res.basename.replace(os.path.commonpath([res.basename, args.process_dir]), '')
                    if rel_path.startswith(os.path.sep):
                        rel_path = rel_path[1:]
                    metadata[rel_path] = { 'caption': res.caption, 'tags': ','.join(res.tags) }
                    if options.lora.in_json is None:
                        with open(res.output.replace(options.process.format, '.txt'), "w", encoding='utf-8') as outfile:
                            outfile.write(res.caption)
                log.info(f"processing {'saved' if res.image is not None else 'skipped'}: {f} => {res.output} {res.ops} {res.message}")
        else:
            log.info('processing skipped: offline')
    folders = [os.path.join(args.process_dir, folder) for folder in os.listdir(args.process_dir) if os.path.isdir(os.path.join(args.process_dir, folder))]
    log.info(f'input datasets {folders}')
    if options.lora.in_json is not None:
        with open(options.lora.in_json, "w", encoding='utf-8') as outfile: # write json at the end only
            outfile.write(json.dumps(metadata, indent=2))
        for folder in folders: # create latents
            import latents
            latents.create_vae_latents(util.Map({ 'input': folder, 'json': options.lora.in_json }))
            latents.unload_vae()
    r = { 'inputs': len(files), 'outputs': results, 'metadata': options.lora.in_json }
    log.info(f'processing steps result: {r}')
    if args.gradient < 0:
        log.info(f"setting gradient accumulation to number of images: {results['total']}")
        options.lora.gradient_accumulation_steps = results['total']
        options.embedding.gradient_step = results['total']
    process.unload()
|
417 |
+
|
418 |
+
|
419 |
+
if __name__ == '__main__':
    # parse CLI first so --debug can control log verbosity
    parse_args()
    setup_logging()
    log.info('SD.Next Train')
    # configure api client endpoint and optional credentials
    sdapi.sd_url = args.server
    if args.user is not None:
        sdapi.sd_username = args.user
    if args.password is not None:
        sdapi.sd_password = args.password
    prepare_server()
    verify_args()
    prepare_options()
    mem_stats()
    process_inputs()
    mem_stats()
    try:
        if args.type == 'embedding':
            train_embedding()
        if args.type == 'lora' or args.type == 'lyco' or args.type == 'dreambooth':
            train_lora()
    except KeyboardInterrupt:
        # ctrl-c: ask the server to stop the running job before exiting
        log.error('interrupt requested')
        sdapi.interrupt()
    mem_stats()
    log.info('done')
|
cli/util.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
generic helper methods
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import string
|
8 |
+
import logging
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
log_format = '%(asctime)s %(levelname)s: %(message)s'
|
12 |
+
logging.basicConfig(level = logging.INFO, format = log_format)
|
13 |
+
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
|
14 |
+
warnings.filterwarnings(action="ignore", category=FutureWarning)
|
15 |
+
warnings.filterwarnings(action="ignore", category=UserWarning)
|
16 |
+
log = logging.getLogger("sd")
|
17 |
+
|
18 |
+
|
19 |
+
def set_logfile(logfile):
    """Attach a file handler to the shared logger and announce its location."""
    handler = logging.FileHandler(logfile)
    handler.setLevel(log.getEffectiveLevel())
    handler.setFormatter(logging.Formatter(log_format))
    log.addHandler(handler)
    log.info({ 'log file': logfile })
|
26 |
+
|
27 |
+
|
28 |
+
def safestring(text: str):
    """Strip punctuation from each line of *text*, join lines with ', ', cap at 1000 chars."""
    strip_punct = str.maketrans('', '', string.punctuation)
    cleaned = [line.translate(strip_punct).strip() for line in text.splitlines()]
    return ', '.join(cleaned)[:1000]
|
34 |
+
|
35 |
+
|
36 |
+
def get_memory():
    """Return a Map with RAM and (when available) CUDA GPU memory statistics in GB.

    Best-effort: when psutil is missing the 'ram' entry holds the exception;
    GPU entries are omitted entirely when CUDA stats cannot be collected.
    """
    def gb(val: float):
        # bytes -> gigabytes, two decimals
        return round(val / 1024 / 1024 / 1024, 2)
    mem = {}
    try:
        import psutil
        process = psutil.Process(os.getpid())
        res = process.memory_info()
        # derive total system RAM from this process's rss and its percentage share
        ram_total = 100 * res.rss / process.memory_percent()
        ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) }
        mem.update({ 'ram': ram })
    except Exception as e:
        # deliberate: expose the failure in place of the stats
        mem.update({ 'ram': e })
    try:
        import torch
        if torch.cuda.is_available():
            s = torch.cuda.mem_get_info() # (free, total) in bytes
            gpu = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) }
            s = dict(torch.cuda.memory_stats('cuda'))
            allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) }
            reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) }
            active = { 'current': gb(s['active_bytes.all.current']), 'peak': gb(s['active_bytes.all.peak']) }
            inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) }
            events = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
            mem.update({
                'gpu': gpu,
                'gpu-active': active,
                'gpu-allocated': allocated,
                'gpu-reserved': reserved,
                'gpu-inactive': inactive,
                'events': events,
            })
    except Exception:
        # no torch / no cuda / stats unavailable: silently skip gpu entries
        pass
    return Map(mem)
|
71 |
+
|
72 |
+
|
73 |
+
class Map(dict): # pylint: disable=C0205
    """Dict subclass with attribute-style access; nested dicts/lists are converted recursively.

    Missing attributes return None (via dict.get) instead of raising AttributeError.
    """
    __slots__ = ('__dict__') # pylint: disable=superfluous-parens
    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs) # pylint: disable=super-with-arguments
        # convert positional dict arguments, wrapping nested dicts/lists as Maps
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    if isinstance(v, dict):
                        v = Map(v)
                    if isinstance(v, list):
                        self.__convert(v)
                    self[k] = v
        # same conversion for keyword arguments
        if kwargs:
            for k, v in kwargs.items():
                if isinstance(v, dict):
                    v = Map(v)
                elif isinstance(v, list):
                    self.__convert(v)
                self[k] = v
    def __convert(self, v):
        # in-place: wrap dict elements of a (possibly nested) list as Maps
        for elem in range(0, len(v)): # pylint: disable=consider-using-enumerate
            if isinstance(v[elem], dict):
                v[elem] = Map(v[elem])
            elif isinstance(v[elem], list):
                self.__convert(v[elem])
    def __getattr__(self, attr):
        # attribute access falls back to dict lookup; returns None when missing
        return self.get(attr)
    def __setattr__(self, key, value):
        self.__setitem__(key, value)
    def __setitem__(self, key, value):
        # keep the instance __dict__ mirror in sync so attribute access stays consistent
        super(Map, self).__setitem__(key, value) # pylint: disable=super-with-arguments
        self.__dict__.update({key: value})
    def __delattr__(self, item):
        self.__delitem__(item)
    def __delitem__(self, key):
        super(Map, self).__delitem__(key) # pylint: disable=super-with-arguments
        del self.__dict__[key]
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == "__main__":
    # helper library: nothing to run standalone
    pass
|
cli/validate-locale.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import json
|
6 |
+
from rich import print # pylint: disable=redefined-builtin
|
7 |
+
|
8 |
+
if __name__ == "__main__":
    # validate a locale json file: report duplicates, over-long labels,
    # and per-section/total counts of localized entries and hints
    sys.argv.pop(0)
    fn = sys.argv[0] if len(sys.argv) > 0 else 'locale_en.json'
    if not os.path.isfile(fn):
        print(f'File not found: {fn}')
        sys.exit(1)
    with open(fn, 'r', encoding="utf-8") as f:
        data = json.load(f)
    keys = []
    t_names = 0
    t_hints = 0
    t_localized = 0
    t_long = 0
    for k in data.keys():
        names = len(data[k])
        t_names += names
        # renamed comprehension vars: the originals shadowed the section key `k`
        hints = len([v for v in data[k] if v["hint"] != ""])
        t_hints += hints
        localized = len([v for v in data[k] if v["localized"] != ""])
        t_localized += localized
        missing = names - hints
        long = 0
        for v in data[k]:
            if v['label'] in keys:
                print(f' Duplicate: {k}.{v["label"]}')
            else:
                if len(v['label']) > 63:
                    long += 1
                    print(f' Long label: {k}.{v["label"]}')
                keys.append(v['label'])
        t_long += long
        print(f'Section: [bold magenta]{k.ljust(20)}[/bold magenta] entries={names} localized={"[bold green]" + str(localized) + "[/bold green]" if localized > 0 else "0"} long={"[bold red]" + str(long) + "[/bold red]" if long > 0 else "0"} hints={hints} missing={"[bold red]" + str(missing) + "[/bold red]" if missing > 0 else "[bold green]0[/bold green]"}')
    # bug fix: totals previously printed the LAST section's `localized` count
    print(f'Totals: entries={t_names} localized={t_localized} long={t_long} hints={t_hints} missing={t_names - t_hints}')
|
cli/video-extract.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
use ffmpeg for animation processing
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
import subprocess
|
8 |
+
import pathlib
|
9 |
+
import argparse
|
10 |
+
import filetype
|
11 |
+
from util import log, Map
|
12 |
+
|
13 |
+
|
14 |
+
def probe(src: str):
    """Run ffprobe on *src* and return a Map summarizing the first video stream.

    Returns codec, resolution, duration, frame count and bitrate (kbit/s);
    missing fields default to zero/'unknown'.
    """
    # argv-list with shell=False avoids shell quoting/injection issues with
    # arbitrary filenames (original interpolated src into a shell string)
    cmd = ['ffprobe', '-hide_banner', '-loglevel', '0', '-print_format', 'json', '-show_format', '-show_streams', src]
    result = subprocess.run(cmd, shell = False, capture_output = True, text = True, check = True)
    data = json.loads(result.stdout)
    stream = [x for x in data['streams'] if x["codec_type"] == "video"][0]
    fmt = data['format'] if 'format' in data else {}
    res = {**stream, **fmt} # format-level fields (duration, bit_rate) override stream fields
    video = Map({
        'codec': res.get('codec_name', 'unknown') + '/' + res.get('codec_tag_string', ''),
        'resolution': [int(res.get('width', 0)), int(res.get('height', 0))],
        'duration': float(res.get('duration', 0)),
        'frames': int(res.get('nb_frames', 0)),
        'bitrate': round(float(res.get('bit_rate', 0)) / 1024),
    })
    return video
|
29 |
+
|
30 |
+
|
31 |
+
def extract(src: str, dst: str, rate: float = 0.015, fps: float = 0, start = 0, end = 0):
    """Extract keyframes from video *src* into folder *dst* using ffmpeg.

    Exactly one of *rate* (scene-change threshold for ffmpeg's select filter)
    or *fps* (fixed frames-per-second sampling) must be > 0; *rate* takes
    precedence when both are set. *start*/*end* trim that many seconds from
    the beginning/end of the video. Frames are written as
    ``<dst>/<source-stem>-%05d.jpg``.

    Returns the number of files present in *dst* after extraction
    (NOTE(review): this counts pre-existing files in *dst* too), or 0 on
    invalid input/arguments.
    """
    images = []
    if not os.path.isfile(src) or not filetype.is_video(src):
        log.error({ 'extract': 'input is not movie file' })
        return 0
    dst = dst if dst.endswith('/') else dst + '/'

    video = probe(src)
    log.info({ 'extract': { 'source': src, **video } })

    ssstart = f' -ss {start}' if start > 0 else ''
    # fix: the end-trim was gated on `start > 0`, so --skipend was ignored
    # unless --skipstart was also set; gate it on `end > 0` instead
    ssend = f' -to {video.duration - end}' if end > 0 else ''
    filename = pathlib.Path(src).stem
    if rate > 0:
        # fix: output pattern contained a literal placeholder instead of the
        # computed source stem, which was otherwise unused
        cmd = f"ffmpeg -hide_banner -y -loglevel info {ssstart} {ssend} -i \"{src}\" -filter:v \"select='gt(scene,{rate})',metadata=print\" -vsync vfr -frame_pts 1 \"{dst}{filename}-%05d.jpg\""
    elif fps > 0:
        cmd = f"ffmpeg -hide_banner -y -loglevel info {ssstart} {ssend} -i \"{src}\" -r {fps} -vsync vfr -frame_pts 1 \"{dst}{filename}-%05d.jpg\""
    else:
        log.error({ 'extract': 'requires either rate or fps' })
        return 0
    log.debug({ 'extract': cmd })
    pathlib.Path(dst).mkdir(parents = True, exist_ok = True)
    result = subprocess.run(cmd, shell = True, capture_output = True, text = True, check = True)
    # ffmpeg's metadata=print filter reports selected frames on stderr
    for line in result.stderr.split('\n'):
        if 'pts_time' in line:
            log.debug({ 'extract': { 'keyframe': line.strip().split(' ')[-1].split(':')[-1] } })
    images = next(os.walk(dst))[2]
    log.info({ 'extract': { 'destination': dst, 'keyframes': len(images), 'rate': rate, 'fps': fps } })
    return len(images)
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
    # CLI entry point: parse extraction options and hand off to extract()
    arg_parser = argparse.ArgumentParser(description="ffmpeg pipeline")
    arg_parser.add_argument("--input", type = str, required = True, help="input")
    arg_parser.add_argument("--output", type = str, required = True, help="output")
    arg_parser.add_argument("--rate", type = float, default = 0, required = False, help="extraction change rate threshold")
    arg_parser.add_argument("--fps", type = float, default = 0, required = False, help="extraction frames per second")
    arg_parser.add_argument("--skipstart", type = float, default = 1, required = False, help="skip time from start of video")
    arg_parser.add_argument("--skipend", type = float, default = 1, required = False, help="skip time to end of video")
    args = arg_parser.parse_args()
    extract(src = args.input, dst = args.output, rate = args.rate, fps = args.fps, start = args.skipstart, end = args.skipend)
|
configs/alt-diffusion-inference.yaml
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: modules.xlmr.BertSeriesModelWithTransformation
|
71 |
+
params:
|
72 |
+
name: "XLMR-Large"
|
configs/instruct-pix2pix.yaml
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
model:
|
5 |
+
base_learning_rate: 1.0e-04
|
6 |
+
target: modules.hijack.ddpm_edit.LatentDiffusion
|
7 |
+
params:
|
8 |
+
linear_start: 0.00085
|
9 |
+
linear_end: 0.0120
|
10 |
+
num_timesteps_cond: 1
|
11 |
+
log_every_t: 200
|
12 |
+
timesteps: 1000
|
13 |
+
first_stage_key: edited
|
14 |
+
cond_stage_key: edit
|
15 |
+
# image_size: 64
|
16 |
+
# image_size: 32
|
17 |
+
image_size: 16
|
18 |
+
channels: 4
|
19 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
20 |
+
conditioning_key: hybrid
|
21 |
+
monitor: val/loss_simple_ema
|
22 |
+
scale_factor: 0.18215
|
23 |
+
use_ema: false
|
24 |
+
|
25 |
+
scheduler_config: # 10000 warmup steps
|
26 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
27 |
+
params:
|
28 |
+
warm_up_steps: [ 0 ]
|
29 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
30 |
+
f_start: [ 1.e-6 ]
|
31 |
+
f_max: [ 1. ]
|
32 |
+
f_min: [ 1. ]
|
33 |
+
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32 # unused
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions: [ 4, 2, 1 ]
|
42 |
+
num_res_blocks: 2
|
43 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
44 |
+
num_heads: 8
|
45 |
+
use_spatial_transformer: True
|
46 |
+
transformer_depth: 1
|
47 |
+
context_dim: 768
|
48 |
+
use_checkpoint: True
|
49 |
+
legacy: False
|
50 |
+
|
51 |
+
first_stage_config:
|
52 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
53 |
+
params:
|
54 |
+
embed_dim: 4
|
55 |
+
monitor: val/rec_loss
|
56 |
+
ddconfig:
|
57 |
+
double_z: true
|
58 |
+
z_channels: 4
|
59 |
+
resolution: 256
|
60 |
+
in_channels: 3
|
61 |
+
out_ch: 3
|
62 |
+
ch: 128
|
63 |
+
ch_mult:
|
64 |
+
- 1
|
65 |
+
- 2
|
66 |
+
- 4
|
67 |
+
- 4
|
68 |
+
num_res_blocks: 2
|
69 |
+
attn_resolutions: []
|
70 |
+
dropout: 0.0
|
71 |
+
lossconfig:
|
72 |
+
target: torch.nn.Identity
|
73 |
+
|
74 |
+
cond_stage_config:
|
75 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
76 |
+
|
77 |
+
data:
|
78 |
+
target: main.DataModuleFromConfig
|
79 |
+
params:
|
80 |
+
batch_size: 128
|
81 |
+
num_workers: 1
|
82 |
+
wrap: false
|
83 |
+
validation:
|
84 |
+
target: edit_dataset.EditDataset
|
85 |
+
params:
|
86 |
+
path: data/clip-filtered-dataset
|
87 |
+
cache_dir: data/
|
88 |
+
cache_name: data_10k
|
89 |
+
split: val
|
90 |
+
min_text_sim: 0.2
|
91 |
+
min_image_sim: 0.75
|
92 |
+
min_direction_sim: 0.2
|
93 |
+
max_samples_per_prompt: 1
|
94 |
+
min_resize_res: 512
|
95 |
+
max_resize_res: 512
|
96 |
+
crop_res: 512
|
97 |
+
output_as_edit: False
|
98 |
+
real_input: True
|