FrankZxShen
commited on
Commit
•
aa69275
1
Parent(s):
e724d71
Upload 55 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .gitignore +133 -0
- LICENSE +201 -0
- README.md +91 -3
- assets/demo.gif +3 -0
- assets/demo_short.gif +3 -0
- assets/figure.jpg +3 -0
- attentions.py +300 -0
- commons.py +96 -0
- hf_models/download.md +33 -0
- hubert_model.py +221 -0
- image/01cc9083.png +0 -0
- image/1d988a81.png +0 -0
- image/307ade76.png +0 -0
- image/5ebacb6a.png +0 -0
- image/cdb4d5e5.png +0 -0
- mel_processing.py +101 -0
- models.py +404 -0
- modules/__init__.py +0 -0
- modules/controlnet_canny.py +60 -0
- modules/controlnet_depth.py +59 -0
- modules/controlnet_hed.py +58 -0
- modules/controlnet_line.py +58 -0
- modules/controlnet_normal.py +71 -0
- modules/controlnet_pose.py +58 -0
- modules/controlnet_scibble.py +56 -0
- modules/controlnet_seg.py +104 -0
- modules/image_captioning.py +21 -0
- modules/image_editing.py +40 -0
- modules/instruct_px2pix.py +28 -0
- modules/mask_former.py +30 -0
- modules/text2img.py +26 -0
- modules/utils.py +75 -0
- modules/visual_question_answering.py +24 -0
- requirement.txt +48 -0
- text/__init__.py +32 -0
- text/cantonese.py +59 -0
- text/cleaners.py +145 -0
- text/english.py +188 -0
- text/japanese.py +153 -0
- text/korean.py +210 -0
- text/mandarin.py +330 -0
- text/ngu_dialect.py +30 -0
- text/sanskrit.py +62 -0
- text/shanghainese.py +64 -0
- text/thai.py +44 -0
- transforms.py +193 -0
- utils_vits.py +75 -0
- visual_chatgpt.py +908 -0
- visual_chatgpt_zh.py +171 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
assets/demo_short.gif filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/demo.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/figure.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
image/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
pip-wheel-metadata/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.nox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
|
55 |
+
# Translations
|
56 |
+
*.mo
|
57 |
+
*.pot
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
.python-version
|
87 |
+
|
88 |
+
# pipenv
|
89 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
90 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
91 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
92 |
+
# install all needed dependencies.
|
93 |
+
#Pipfile.lock
|
94 |
+
|
95 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
96 |
+
__pypackages__/
|
97 |
+
|
98 |
+
# Celery stuff
|
99 |
+
celerybeat-schedule
|
100 |
+
celerybeat.pid
|
101 |
+
|
102 |
+
# SageMath parsed files
|
103 |
+
*.sage.py
|
104 |
+
|
105 |
+
# Environments
|
106 |
+
.env
|
107 |
+
.venv
|
108 |
+
env/
|
109 |
+
venv/
|
110 |
+
ENV/
|
111 |
+
env.bak/
|
112 |
+
venv.bak/
|
113 |
+
|
114 |
+
# Spyder project settings
|
115 |
+
.spyderproject
|
116 |
+
.spyproject
|
117 |
+
|
118 |
+
# Rope project settings
|
119 |
+
.ropeproject
|
120 |
+
|
121 |
+
# mkdocs documentation
|
122 |
+
/site
|
123 |
+
|
124 |
+
# mypy
|
125 |
+
.mypy_cache/
|
126 |
+
.dmypy.json
|
127 |
+
dmypy.json
|
128 |
+
|
129 |
+
# Pyre type checker
|
130 |
+
.pyre/
|
131 |
+
annotator
|
132 |
+
cldm
|
133 |
+
ldm
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,91 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# visual-chatgpt-zh-vits
|
2 |
+
visual-chatgpt支持中文的windows版本
|
3 |
+
|
4 |
+
融合vits推断模块
|
5 |
+
|
6 |
+
|
7 |
+
官方论文: [<font size=5>Visual ChatGPT: Talking, Drawing and Editing with Visual Foundation Models</font>](https://arxiv.org/abs/2303.04671)
|
8 |
+
|
9 |
+
官方仓库:[visual-chatgpt](https://github.com/microsoft/visual-chatgpt)
|
10 |
+
|
11 |
+
fork from:[visual-chatgpt-zh](https://github.com/wxj630/visual-chatgpt-zh)
|
12 |
+
|
13 |
+
|
14 |
+
## Demo
|
15 |
+
<img src="./assets/demo_short.gif" width="750">
|
16 |
+
|
17 |
+
## System Architecture
|
18 |
+
|
19 |
+
|
20 |
+
<p align="center"><img src="./assets/figure.jpg" alt="Logo"></p>
|
21 |
+
|
22 |
+
|
23 |
+
## Quick Start
|
24 |
+
|
25 |
+
```
|
26 |
+
# 1、下载代码
|
27 |
+
git clone https://github.com/FrankZxShen/visual-chatgpt-zh-vits.git
|
28 |
+
|
29 |
+
# 2、进入项目目录
|
30 |
+
cd visual-chatgpt-zh-vits
|
31 |
+
|
32 |
+
# 3、创建python环境并激活环境
|
33 |
+
conda create -n visgpt python=3.8
|
34 |
+
activate visgpt
|
35 |
+
|
36 |
+
# 4、安装环境依赖
|
37 |
+
pip install -r requirement.txt
|
38 |
+
|
39 |
+
# 5、确认api key
|
40 |
+
export OPENAI_API_KEY={Your_Private_Openai_Key}
|
41 |
+
# windows系统用set命令而不是export
|
42 |
+
set OPENAI_API_KEY={Your_Private_Openai_Key}
|
43 |
+
|
44 |
+
# 6、下载hf模型到指定目录
|
45 |
+
# 具体模型文件地址于hf_models
|
46 |
+
# 若需要vits推断功能将G.pth config.json放于vits_models下(目前仅支持日语?)
|
47 |
+
# Windows:下载pyopenjtalk Windows于text下
|
48 |
+
|
49 |
+
# 7、启动系统,这个例子我们加载了ImageCaptioning和Text2Image两个模型,
|
50 |
+
python visual_chatgpt_zh_vits.py
|
51 |
+
# 想要用哪个功能就可增加一些模型加载
|
52 |
+
python visual_chatgpt_zh_vits.py
|
53 |
+
--load ImageCaptioning_cuda:0,Text2Image_cuda:0 \
|
54 |
+
--pretrained_model_dir {your_hf_models_path} \
|
55 |
+
|
56 |
+
# 8、可以直接在visual_chatgpt_zh_vits.py 38行修改key 若需要vits 39行设定True
|
57 |
+
```
|
58 |
+
|
59 |
+
原作者:根据官方建议,不同显卡可以指定不同“--load”参数,显存不够的就可以时间换空间,把不重要的模型加载到cpu上,虽然推理慢但是好歹能跑不是?(手动狗头):
|
60 |
+
```
|
61 |
+
# Advice for CPU Users
|
62 |
+
python visual_chatgpt.py --load ImageCaptioning_cpu,Text2Image_cpu
|
63 |
+
|
64 |
+
# Advice for 1 Tesla T4 15GB (Google Colab)
|
65 |
+
python visual_chatgpt.py --load "ImageCaptioning_cuda:0,Text2Image_cuda:0"
|
66 |
+
|
67 |
+
# Advice for 4 Tesla V100 32GB
|
68 |
+
python visual_chatgpt.py --load "ImageCaptioning_cuda:0,ImageEditing_cuda:0,
|
69 |
+
Text2Image_cuda:1,Image2Canny_cpu,CannyText2Image_cuda:1,
|
70 |
+
Image2Depth_cpu,DepthText2Image_cuda:1,VisualQuestionAnswering_cuda:2,
|
71 |
+
InstructPix2Pix_cuda:2,Image2Scribble_cpu,ScribbleText2Image_cuda:2,
|
72 |
+
Image2Seg_cpu,SegText2Image_cuda:2,Image2Pose_cpu,PoseText2Image_cuda:2,
|
73 |
+
Image2Hed_cpu,HedText2Image_cuda:3,Image2Normal_cpu,
|
74 |
+
NormalText2Image_cuda:3,Image2Line_cpu,LineText2Image_cuda:3"
|
75 |
+
```
|
76 |
+
|
77 |
+
实测环境 Windows RTX3070 8G:若只需要ImageCaptioning和Text2Image两个模型的功能,对显存要求极低,理论上能跑AI绘图均可以(>4G,但速度很慢)?
|
78 |
+
|
79 |
+
## limitations
|
80 |
+
|
81 |
+
img无法显示在gradio上?
|
82 |
+
|
83 |
+
## Acknowledgement
|
84 |
+
|
85 |
+
We appreciate the open source of the following projects:
|
86 |
+
|
87 |
+
- HuggingFace [[Project]](https://github.com/huggingface/transformers)
|
88 |
+
|
89 |
+
- ControlNet [[Paper]](https://arxiv.org/abs/2302.05543) [[Project]](https://github.com/lllyasviel/ControlNet)
|
90 |
+
|
91 |
+
- Stable Diffusion [[Paper]](https://arxiv.org/abs/2112.10752) [[Project]](https://github.com/CompVis/stable-diffusion)
|
assets/demo.gif
ADDED
Git LFS Details
|
assets/demo_short.gif
ADDED
Git LFS Details
|
assets/figure.jpg
ADDED
Git LFS Details
|
attentions.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
import commons
|
7 |
+
from vits_modules import LayerNorm
|
8 |
+
|
9 |
+
|
10 |
+
class Encoder(nn.Module):
|
11 |
+
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
|
12 |
+
super().__init__()
|
13 |
+
self.hidden_channels = hidden_channels
|
14 |
+
self.filter_channels = filter_channels
|
15 |
+
self.n_heads = n_heads
|
16 |
+
self.n_layers = n_layers
|
17 |
+
self.kernel_size = kernel_size
|
18 |
+
self.p_dropout = p_dropout
|
19 |
+
self.window_size = window_size
|
20 |
+
|
21 |
+
self.drop = nn.Dropout(p_dropout)
|
22 |
+
self.attn_layers = nn.ModuleList()
|
23 |
+
self.norm_layers_1 = nn.ModuleList()
|
24 |
+
self.ffn_layers = nn.ModuleList()
|
25 |
+
self.norm_layers_2 = nn.ModuleList()
|
26 |
+
for i in range(self.n_layers):
|
27 |
+
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
|
28 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
29 |
+
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
|
30 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
31 |
+
|
32 |
+
def forward(self, x, x_mask):
|
33 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
34 |
+
x = x * x_mask
|
35 |
+
for i in range(self.n_layers):
|
36 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
37 |
+
y = self.drop(y)
|
38 |
+
x = self.norm_layers_1[i](x + y)
|
39 |
+
|
40 |
+
y = self.ffn_layers[i](x, x_mask)
|
41 |
+
y = self.drop(y)
|
42 |
+
x = self.norm_layers_2[i](x + y)
|
43 |
+
x = x * x_mask
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class Decoder(nn.Module):
|
48 |
+
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
|
49 |
+
super().__init__()
|
50 |
+
self.hidden_channels = hidden_channels
|
51 |
+
self.filter_channels = filter_channels
|
52 |
+
self.n_heads = n_heads
|
53 |
+
self.n_layers = n_layers
|
54 |
+
self.kernel_size = kernel_size
|
55 |
+
self.p_dropout = p_dropout
|
56 |
+
self.proximal_bias = proximal_bias
|
57 |
+
self.proximal_init = proximal_init
|
58 |
+
|
59 |
+
self.drop = nn.Dropout(p_dropout)
|
60 |
+
self.self_attn_layers = nn.ModuleList()
|
61 |
+
self.norm_layers_0 = nn.ModuleList()
|
62 |
+
self.encdec_attn_layers = nn.ModuleList()
|
63 |
+
self.norm_layers_1 = nn.ModuleList()
|
64 |
+
self.ffn_layers = nn.ModuleList()
|
65 |
+
self.norm_layers_2 = nn.ModuleList()
|
66 |
+
for i in range(self.n_layers):
|
67 |
+
self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
|
68 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
69 |
+
self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
|
70 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
71 |
+
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
|
72 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
73 |
+
|
74 |
+
def forward(self, x, x_mask, h, h_mask):
|
75 |
+
"""
|
76 |
+
x: decoder input
|
77 |
+
h: encoder output
|
78 |
+
"""
|
79 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
80 |
+
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
81 |
+
x = x * x_mask
|
82 |
+
for i in range(self.n_layers):
|
83 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
84 |
+
y = self.drop(y)
|
85 |
+
x = self.norm_layers_0[i](x + y)
|
86 |
+
|
87 |
+
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
|
88 |
+
y = self.drop(y)
|
89 |
+
x = self.norm_layers_1[i](x + y)
|
90 |
+
|
91 |
+
y = self.ffn_layers[i](x, x_mask)
|
92 |
+
y = self.drop(y)
|
93 |
+
x = self.norm_layers_2[i](x + y)
|
94 |
+
x = x * x_mask
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class MultiHeadAttention(nn.Module):
|
99 |
+
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
|
100 |
+
super().__init__()
|
101 |
+
assert channels % n_heads == 0
|
102 |
+
|
103 |
+
self.channels = channels
|
104 |
+
self.out_channels = out_channels
|
105 |
+
self.n_heads = n_heads
|
106 |
+
self.p_dropout = p_dropout
|
107 |
+
self.window_size = window_size
|
108 |
+
self.heads_share = heads_share
|
109 |
+
self.block_length = block_length
|
110 |
+
self.proximal_bias = proximal_bias
|
111 |
+
self.proximal_init = proximal_init
|
112 |
+
self.attn = None
|
113 |
+
|
114 |
+
self.k_channels = channels // n_heads
|
115 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
116 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
117 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
118 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
119 |
+
self.drop = nn.Dropout(p_dropout)
|
120 |
+
|
121 |
+
if window_size is not None:
|
122 |
+
n_heads_rel = 1 if heads_share else n_heads
|
123 |
+
rel_stddev = self.k_channels**-0.5
|
124 |
+
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
125 |
+
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
126 |
+
|
127 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
128 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
129 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
130 |
+
if proximal_init:
|
131 |
+
with torch.no_grad():
|
132 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
133 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
134 |
+
|
135 |
+
def forward(self, x, c, attn_mask=None):
|
136 |
+
q = self.conv_q(x)
|
137 |
+
k = self.conv_k(c)
|
138 |
+
v = self.conv_v(c)
|
139 |
+
|
140 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
141 |
+
|
142 |
+
x = self.conv_o(x)
|
143 |
+
return x
|
144 |
+
|
145 |
+
def attention(self, query, key, value, mask=None):
|
146 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
147 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
148 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
149 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
150 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
151 |
+
|
152 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
153 |
+
if self.window_size is not None:
|
154 |
+
assert t_s == t_t, "Relative attention is only available for self-attention."
|
155 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
156 |
+
rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
|
157 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
158 |
+
scores = scores + scores_local
|
159 |
+
if self.proximal_bias:
|
160 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
161 |
+
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
162 |
+
if mask is not None:
|
163 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
164 |
+
if self.block_length is not None:
|
165 |
+
assert t_s == t_t, "Local attention is only available for self-attention."
|
166 |
+
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
|
167 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
168 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
169 |
+
p_attn = self.drop(p_attn)
|
170 |
+
output = torch.matmul(p_attn, value)
|
171 |
+
if self.window_size is not None:
|
172 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
173 |
+
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
174 |
+
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
175 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
176 |
+
return output, p_attn
|
177 |
+
|
178 |
+
def _matmul_with_relative_values(self, x, y):
|
179 |
+
"""
|
180 |
+
x: [b, h, l, m]
|
181 |
+
y: [h or 1, m, d]
|
182 |
+
ret: [b, h, l, d]
|
183 |
+
"""
|
184 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
185 |
+
return ret
|
186 |
+
|
187 |
+
def _matmul_with_relative_keys(self, x, y):
|
188 |
+
"""
|
189 |
+
x: [b, h, l, d]
|
190 |
+
y: [h or 1, m, d]
|
191 |
+
ret: [b, h, l, m]
|
192 |
+
"""
|
193 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
194 |
+
return ret
|
195 |
+
|
196 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
197 |
+
max_relative_position = 2 * self.window_size + 1
|
198 |
+
# Pad first before slice to avoid using cond ops.
|
199 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
200 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
201 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
202 |
+
if pad_length > 0:
|
203 |
+
padded_relative_embeddings = F.pad(
|
204 |
+
relative_embeddings,
|
205 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
|
206 |
+
else:
|
207 |
+
padded_relative_embeddings = relative_embeddings
|
208 |
+
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
|
209 |
+
return used_relative_embeddings
|
210 |
+
|
211 |
+
def _relative_position_to_absolute_position(self, x):
|
212 |
+
"""
|
213 |
+
x: [b, h, l, 2*l-1]
|
214 |
+
ret: [b, h, l, l]
|
215 |
+
"""
|
216 |
+
batch, heads, length, _ = x.size()
|
217 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
218 |
+
x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
|
219 |
+
|
220 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
221 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
222 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
|
223 |
+
|
224 |
+
# Reshape and slice out the padded elements.
|
225 |
+
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
|
226 |
+
return x_final
|
227 |
+
|
228 |
+
def _absolute_position_to_relative_position(self, x):
|
229 |
+
"""
|
230 |
+
x: [b, h, l, l]
|
231 |
+
ret: [b, h, l, 2*l-1]
|
232 |
+
"""
|
233 |
+
batch, heads, length, _ = x.size()
|
234 |
+
# padd along column
|
235 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
|
236 |
+
x_flat = x.view([batch, heads, length**2 + length*(length -1)])
|
237 |
+
# add 0's in the beginning that will skew the elements after reshape
|
238 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
239 |
+
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
|
240 |
+
return x_final
|
241 |
+
|
242 |
+
def _attention_bias_proximal(self, length):
|
243 |
+
"""Bias for self-attention to encourage attention to close positions.
|
244 |
+
Args:
|
245 |
+
length: an integer scalar.
|
246 |
+
Returns:
|
247 |
+
a Tensor with shape [1, 1, length, length]
|
248 |
+
"""
|
249 |
+
r = torch.arange(length, dtype=torch.float32)
|
250 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
251 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
252 |
+
|
253 |
+
|
254 |
+
class FFN(nn.Module):
|
255 |
+
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
|
256 |
+
super().__init__()
|
257 |
+
self.in_channels = in_channels
|
258 |
+
self.out_channels = out_channels
|
259 |
+
self.filter_channels = filter_channels
|
260 |
+
self.kernel_size = kernel_size
|
261 |
+
self.p_dropout = p_dropout
|
262 |
+
self.activation = activation
|
263 |
+
self.causal = causal
|
264 |
+
|
265 |
+
if causal:
|
266 |
+
self.padding = self._causal_padding
|
267 |
+
else:
|
268 |
+
self.padding = self._same_padding
|
269 |
+
|
270 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
271 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
272 |
+
self.drop = nn.Dropout(p_dropout)
|
273 |
+
|
274 |
+
def forward(self, x, x_mask):
|
275 |
+
x = self.conv_1(self.padding(x * x_mask))
|
276 |
+
if self.activation == "gelu":
|
277 |
+
x = x * torch.sigmoid(1.702 * x)
|
278 |
+
else:
|
279 |
+
x = torch.relu(x)
|
280 |
+
x = self.drop(x)
|
281 |
+
x = self.conv_2(self.padding(x * x_mask))
|
282 |
+
return x * x_mask
|
283 |
+
|
284 |
+
def _causal_padding(self, x):
|
285 |
+
if self.kernel_size == 1:
|
286 |
+
return x
|
287 |
+
pad_l = self.kernel_size - 1
|
288 |
+
pad_r = 0
|
289 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
290 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
291 |
+
return x
|
292 |
+
|
293 |
+
def _same_padding(self, x):
|
294 |
+
if self.kernel_size == 1:
|
295 |
+
return x
|
296 |
+
pad_l = (self.kernel_size - 1) // 2
|
297 |
+
pad_r = self.kernel_size // 2
|
298 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
299 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
300 |
+
return x
|
commons.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F
|
3 |
+
import torch.jit
|
4 |
+
|
5 |
+
|
6 |
+
def script_method(fn, _rcb=None):
|
7 |
+
return fn
|
8 |
+
|
9 |
+
|
10 |
+
def script(obj, optimize=True, _frames_up=0, _rcb=None):
|
11 |
+
return obj
|
12 |
+
|
13 |
+
|
14 |
+
torch.jit.script_method = script_method
|
15 |
+
torch.jit.script = script
|
16 |
+
|
17 |
+
|
18 |
+
def init_weights(m, mean=0.0, std=0.01):
|
19 |
+
classname = m.__class__.__name__
|
20 |
+
if classname.find("Conv") != -1:
|
21 |
+
m.weight.data.normal_(mean, std)
|
22 |
+
|
23 |
+
|
24 |
+
def get_padding(kernel_size, dilation=1):
|
25 |
+
return int((kernel_size*dilation - dilation)/2)
|
26 |
+
|
27 |
+
|
28 |
+
def intersperse(lst, item):
|
29 |
+
result = [item] * (len(lst) * 2 + 1)
|
30 |
+
result[1::2] = lst
|
31 |
+
return result
|
32 |
+
|
33 |
+
|
34 |
+
def slice_segments(x, ids_str, segment_size=4):
|
35 |
+
ret = torch.zeros_like(x[:, :, :segment_size])
|
36 |
+
for i in range(x.size(0)):
|
37 |
+
idx_str = ids_str[i]
|
38 |
+
idx_end = idx_str + segment_size
|
39 |
+
ret[i] = x[i, :, idx_str:idx_end]
|
40 |
+
return ret
|
41 |
+
|
42 |
+
|
43 |
+
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
44 |
+
b, d, t = x.size()
|
45 |
+
if x_lengths is None:
|
46 |
+
x_lengths = t
|
47 |
+
ids_str_max = x_lengths - segment_size + 1
|
48 |
+
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
49 |
+
ret = slice_segments(x, ids_str, segment_size)
|
50 |
+
return ret, ids_str
|
51 |
+
|
52 |
+
|
53 |
+
def subsequent_mask(length):
|
54 |
+
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
55 |
+
return mask
|
56 |
+
|
57 |
+
|
58 |
+
@torch.jit.script
|
59 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
60 |
+
n_channels_int = n_channels[0]
|
61 |
+
in_act = input_a + input_b
|
62 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
63 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
64 |
+
acts = t_act * s_act
|
65 |
+
return acts
|
66 |
+
|
67 |
+
|
68 |
+
def convert_pad_shape(pad_shape):
|
69 |
+
l = pad_shape[::-1]
|
70 |
+
pad_shape = [item for sublist in l for item in sublist]
|
71 |
+
return pad_shape
|
72 |
+
|
73 |
+
|
74 |
+
def sequence_mask(length, max_length=None):
|
75 |
+
if max_length is None:
|
76 |
+
max_length = length.max()
|
77 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
78 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
79 |
+
|
80 |
+
|
81 |
+
def generate_path(duration, mask):
|
82 |
+
"""
|
83 |
+
duration: [b, 1, t_x]
|
84 |
+
mask: [b, 1, t_y, t_x]
|
85 |
+
"""
|
86 |
+
device = duration.device
|
87 |
+
|
88 |
+
b, _, t_y, t_x = mask.shape
|
89 |
+
cum_duration = torch.cumsum(duration, -1)
|
90 |
+
|
91 |
+
cum_duration_flat = cum_duration.view(b * t_x)
|
92 |
+
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
93 |
+
path = path.view(b, t_x, t_y)
|
94 |
+
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
95 |
+
path = path.unsqueeze(1).transpose(2,3) * mask
|
96 |
+
return path
|
hf_models/download.md
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git clone https://huggingface.co/Salesforce/blip-image-captioning-base
|
2 |
+
|
3 |
+
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
|
4 |
+
|
5 |
+
git clone https://huggingface.co/runwayml/stable-diffusion-inpainting
|
6 |
+
|
7 |
+
git clone https://huggingface.co/CIDAS/clipseg-rd64-refined
|
8 |
+
|
9 |
+
git clone https://huggingface.co/timbrooks/instruct-pix2pix
|
10 |
+
|
11 |
+
git clone https://huggingface.co/Salesforce/blip-vqa-base
|
12 |
+
|
13 |
+
git clone https://huggingface.co/lllyasviel/ControlNet
|
14 |
+
|
15 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-canny
|
16 |
+
|
17 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-seg
|
18 |
+
|
19 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-scribble
|
20 |
+
|
21 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-normal
|
22 |
+
|
23 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-mlsd
|
24 |
+
|
25 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-depth
|
26 |
+
|
27 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-hed
|
28 |
+
|
29 |
+
git clone https://huggingface.co/lllyasviel/sd-controlnet-openpose
|
30 |
+
|
31 |
+
git clone https://huggingface.co/openmmlab/upernet-convnext-small
|
32 |
+
|
33 |
+
git clone https://huggingface.co/Intel/dpt-hybrid-midas
|
hubert_model.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
import random
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
9 |
+
|
10 |
+
class Hubert(nn.Module):
|
11 |
+
def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
|
12 |
+
super().__init__()
|
13 |
+
self._mask = mask
|
14 |
+
self.feature_extractor = FeatureExtractor()
|
15 |
+
self.feature_projection = FeatureProjection()
|
16 |
+
self.positional_embedding = PositionalConvEmbedding()
|
17 |
+
self.norm = nn.LayerNorm(768)
|
18 |
+
self.dropout = nn.Dropout(0.1)
|
19 |
+
self.encoder = TransformerEncoder(
|
20 |
+
nn.TransformerEncoderLayer(
|
21 |
+
768, 12, 3072, activation="gelu", batch_first=True
|
22 |
+
),
|
23 |
+
12,
|
24 |
+
)
|
25 |
+
self.proj = nn.Linear(768, 256)
|
26 |
+
|
27 |
+
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
|
28 |
+
self.label_embedding = nn.Embedding(num_label_embeddings, 256)
|
29 |
+
|
30 |
+
def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
31 |
+
mask = None
|
32 |
+
if self.training and self._mask:
|
33 |
+
mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
|
34 |
+
x[mask] = self.masked_spec_embed.to(x.dtype)
|
35 |
+
return x, mask
|
36 |
+
|
37 |
+
def encode(
|
38 |
+
self, x: torch.Tensor, layer: Optional[int] = None
|
39 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
40 |
+
x = self.feature_extractor(x)
|
41 |
+
x = self.feature_projection(x.transpose(1, 2))
|
42 |
+
x, mask = self.mask(x)
|
43 |
+
x = x + self.positional_embedding(x)
|
44 |
+
x = self.dropout(self.norm(x))
|
45 |
+
x = self.encoder(x, output_layer=layer)
|
46 |
+
return x, mask
|
47 |
+
|
48 |
+
def logits(self, x: torch.Tensor) -> torch.Tensor:
|
49 |
+
logits = torch.cosine_similarity(
|
50 |
+
x.unsqueeze(2),
|
51 |
+
self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
|
52 |
+
dim=-1,
|
53 |
+
)
|
54 |
+
return logits / 0.1
|
55 |
+
|
56 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
57 |
+
x, mask = self.encode(x)
|
58 |
+
x = self.proj(x)
|
59 |
+
logits = self.logits(x)
|
60 |
+
return logits, mask
|
61 |
+
|
62 |
+
|
63 |
+
class HubertSoft(Hubert):
|
64 |
+
def __init__(self):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
@torch.inference_mode()
|
68 |
+
def units(self, wav: torch.Tensor) -> torch.Tensor:
|
69 |
+
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
|
70 |
+
x, _ = self.encode(wav)
|
71 |
+
return self.proj(x)
|
72 |
+
|
73 |
+
|
74 |
+
class FeatureExtractor(nn.Module):
|
75 |
+
def __init__(self):
|
76 |
+
super().__init__()
|
77 |
+
self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
|
78 |
+
self.norm0 = nn.GroupNorm(512, 512)
|
79 |
+
self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
80 |
+
self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
81 |
+
self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
82 |
+
self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
83 |
+
self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
|
84 |
+
self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
|
85 |
+
|
86 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
87 |
+
x = F.gelu(self.norm0(self.conv0(x)))
|
88 |
+
x = F.gelu(self.conv1(x))
|
89 |
+
x = F.gelu(self.conv2(x))
|
90 |
+
x = F.gelu(self.conv3(x))
|
91 |
+
x = F.gelu(self.conv4(x))
|
92 |
+
x = F.gelu(self.conv5(x))
|
93 |
+
x = F.gelu(self.conv6(x))
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
class FeatureProjection(nn.Module):
|
98 |
+
def __init__(self):
|
99 |
+
super().__init__()
|
100 |
+
self.norm = nn.LayerNorm(512)
|
101 |
+
self.projection = nn.Linear(512, 768)
|
102 |
+
self.dropout = nn.Dropout(0.1)
|
103 |
+
|
104 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
105 |
+
x = self.norm(x)
|
106 |
+
x = self.projection(x)
|
107 |
+
x = self.dropout(x)
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
class PositionalConvEmbedding(nn.Module):
|
112 |
+
def __init__(self):
|
113 |
+
super().__init__()
|
114 |
+
self.conv = nn.Conv1d(
|
115 |
+
768,
|
116 |
+
768,
|
117 |
+
kernel_size=128,
|
118 |
+
padding=128 // 2,
|
119 |
+
groups=16,
|
120 |
+
)
|
121 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
122 |
+
|
123 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
124 |
+
x = self.conv(x.transpose(1, 2))
|
125 |
+
x = F.gelu(x[:, :, :-1])
|
126 |
+
return x.transpose(1, 2)
|
127 |
+
|
128 |
+
|
129 |
+
class TransformerEncoder(nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
|
132 |
+
) -> None:
|
133 |
+
super(TransformerEncoder, self).__init__()
|
134 |
+
self.layers = nn.ModuleList(
|
135 |
+
[copy.deepcopy(encoder_layer) for _ in range(num_layers)]
|
136 |
+
)
|
137 |
+
self.num_layers = num_layers
|
138 |
+
|
139 |
+
def forward(
|
140 |
+
self,
|
141 |
+
src: torch.Tensor,
|
142 |
+
mask: torch.Tensor = None,
|
143 |
+
src_key_padding_mask: torch.Tensor = None,
|
144 |
+
output_layer: Optional[int] = None,
|
145 |
+
) -> torch.Tensor:
|
146 |
+
output = src
|
147 |
+
for layer in self.layers[:output_layer]:
|
148 |
+
output = layer(
|
149 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
150 |
+
)
|
151 |
+
return output
|
152 |
+
|
153 |
+
|
154 |
+
def _compute_mask(
|
155 |
+
shape: Tuple[int, int],
|
156 |
+
mask_prob: float,
|
157 |
+
mask_length: int,
|
158 |
+
device: torch.device,
|
159 |
+
min_masks: int = 0,
|
160 |
+
) -> torch.Tensor:
|
161 |
+
batch_size, sequence_length = shape
|
162 |
+
|
163 |
+
if mask_length < 1:
|
164 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
165 |
+
|
166 |
+
if mask_length > sequence_length:
|
167 |
+
raise ValueError(
|
168 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
|
169 |
+
)
|
170 |
+
|
171 |
+
# compute number of masked spans in batch
|
172 |
+
num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
|
173 |
+
num_masked_spans = max(num_masked_spans, min_masks)
|
174 |
+
|
175 |
+
# make sure num masked indices <= sequence_length
|
176 |
+
if num_masked_spans * mask_length > sequence_length:
|
177 |
+
num_masked_spans = sequence_length // mask_length
|
178 |
+
|
179 |
+
# SpecAugment mask to fill
|
180 |
+
mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
|
181 |
+
|
182 |
+
# uniform distribution to sample from, make sure that offset samples are < sequence_length
|
183 |
+
uniform_dist = torch.ones(
|
184 |
+
(batch_size, sequence_length - (mask_length - 1)), device=device
|
185 |
+
)
|
186 |
+
|
187 |
+
# get random indices to mask
|
188 |
+
mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
|
189 |
+
|
190 |
+
# expand masked indices to masked spans
|
191 |
+
mask_indices = (
|
192 |
+
mask_indices.unsqueeze(dim=-1)
|
193 |
+
.expand((batch_size, num_masked_spans, mask_length))
|
194 |
+
.reshape(batch_size, num_masked_spans * mask_length)
|
195 |
+
)
|
196 |
+
offsets = (
|
197 |
+
torch.arange(mask_length, device=device)[None, None, :]
|
198 |
+
.expand((batch_size, num_masked_spans, mask_length))
|
199 |
+
.reshape(batch_size, num_masked_spans * mask_length)
|
200 |
+
)
|
201 |
+
mask_idxs = mask_indices + offsets
|
202 |
+
|
203 |
+
# scatter indices to mask
|
204 |
+
mask = mask.scatter(1, mask_idxs, True)
|
205 |
+
|
206 |
+
return mask
|
207 |
+
|
208 |
+
|
209 |
+
def hubert_soft(
|
210 |
+
path: str
|
211 |
+
) -> HubertSoft:
|
212 |
+
r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
|
213 |
+
Args:
|
214 |
+
path (str): path of a pretrained model
|
215 |
+
"""
|
216 |
+
hubert = HubertSoft()
|
217 |
+
checkpoint = torch.load(path)
|
218 |
+
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
|
219 |
+
hubert.load_state_dict(checkpoint)
|
220 |
+
hubert.eval()
|
221 |
+
return hubert
|
image/01cc9083.png
ADDED
image/1d988a81.png
ADDED
image/307ade76.png
ADDED
image/5ebacb6a.png
ADDED
image/cdb4d5e5.png
ADDED
mel_processing.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.utils.data
|
3 |
+
from librosa.filters import mel as librosa_mel_fn
|
4 |
+
|
5 |
+
MAX_WAV_VALUE = 32768.0
|
6 |
+
|
7 |
+
|
8 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
9 |
+
"""
|
10 |
+
PARAMS
|
11 |
+
------
|
12 |
+
C: compression factor
|
13 |
+
"""
|
14 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
15 |
+
|
16 |
+
|
17 |
+
def dynamic_range_decompression_torch(x, C=1):
|
18 |
+
"""
|
19 |
+
PARAMS
|
20 |
+
------
|
21 |
+
C: compression factor used to compress
|
22 |
+
"""
|
23 |
+
return torch.exp(x) / C
|
24 |
+
|
25 |
+
|
26 |
+
def spectral_normalize_torch(magnitudes):
|
27 |
+
output = dynamic_range_compression_torch(magnitudes)
|
28 |
+
return output
|
29 |
+
|
30 |
+
|
31 |
+
def spectral_de_normalize_torch(magnitudes):
|
32 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
33 |
+
return output
|
34 |
+
|
35 |
+
|
36 |
+
mel_basis = {}
|
37 |
+
hann_window = {}
|
38 |
+
|
39 |
+
|
40 |
+
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
41 |
+
if torch.min(y) < -1.:
|
42 |
+
print('min value is ', torch.min(y))
|
43 |
+
if torch.max(y) > 1.:
|
44 |
+
print('max value is ', torch.max(y))
|
45 |
+
|
46 |
+
global hann_window
|
47 |
+
dtype_device = str(y.dtype) + '_' + str(y.device)
|
48 |
+
wnsize_dtype_device = str(win_size) + '_' + dtype_device
|
49 |
+
if wnsize_dtype_device not in hann_window:
|
50 |
+
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
51 |
+
|
52 |
+
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
53 |
+
y = y.squeeze(1)
|
54 |
+
|
55 |
+
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
|
56 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
|
57 |
+
|
58 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
59 |
+
return spec
|
60 |
+
|
61 |
+
|
62 |
+
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
63 |
+
global mel_basis
|
64 |
+
dtype_device = str(spec.dtype) + '_' + str(spec.device)
|
65 |
+
fmax_dtype_device = str(fmax) + '_' + dtype_device
|
66 |
+
if fmax_dtype_device not in mel_basis:
|
67 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
68 |
+
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
69 |
+
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
70 |
+
spec = spectral_normalize_torch(spec)
|
71 |
+
return spec
|
72 |
+
|
73 |
+
|
74 |
+
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
75 |
+
if torch.min(y) < -1.:
|
76 |
+
print('min value is ', torch.min(y))
|
77 |
+
if torch.max(y) > 1.:
|
78 |
+
print('max value is ', torch.max(y))
|
79 |
+
|
80 |
+
global mel_basis, hann_window
|
81 |
+
dtype_device = str(y.dtype) + '_' + str(y.device)
|
82 |
+
fmax_dtype_device = str(fmax) + '_' + dtype_device
|
83 |
+
wnsize_dtype_device = str(win_size) + '_' + dtype_device
|
84 |
+
if fmax_dtype_device not in mel_basis:
|
85 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
86 |
+
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
87 |
+
if wnsize_dtype_device not in hann_window:
|
88 |
+
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
89 |
+
|
90 |
+
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
91 |
+
y = y.squeeze(1)
|
92 |
+
|
93 |
+
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
|
94 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True)
|
95 |
+
|
96 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
97 |
+
|
98 |
+
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
99 |
+
spec = spectral_normalize_torch(spec)
|
100 |
+
|
101 |
+
return spec
|
models.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
import commons
|
7 |
+
import vits_modules as modules
|
8 |
+
import attentions
|
9 |
+
|
10 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
11 |
+
from torch.nn.utils import weight_norm
|
12 |
+
from commons import init_weights
|
13 |
+
|
14 |
+
|
15 |
+
class StochasticDurationPredictor(nn.Module):
|
16 |
+
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
17 |
+
super().__init__()
|
18 |
+
filter_channels = in_channels # it needs to be removed from future version.
|
19 |
+
self.in_channels = in_channels
|
20 |
+
self.filter_channels = filter_channels
|
21 |
+
self.kernel_size = kernel_size
|
22 |
+
self.p_dropout = p_dropout
|
23 |
+
self.n_flows = n_flows
|
24 |
+
self.gin_channels = gin_channels
|
25 |
+
|
26 |
+
self.log_flow = modules.Log()
|
27 |
+
self.flows = nn.ModuleList()
|
28 |
+
self.flows.append(modules.ElementwiseAffine(2))
|
29 |
+
for i in range(n_flows):
|
30 |
+
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
31 |
+
self.flows.append(modules.Flip())
|
32 |
+
|
33 |
+
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
34 |
+
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
35 |
+
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
36 |
+
self.post_flows = nn.ModuleList()
|
37 |
+
self.post_flows.append(modules.ElementwiseAffine(2))
|
38 |
+
for i in range(4):
|
39 |
+
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
40 |
+
self.post_flows.append(modules.Flip())
|
41 |
+
|
42 |
+
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
43 |
+
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
44 |
+
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
45 |
+
if gin_channels != 0:
|
46 |
+
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
47 |
+
|
48 |
+
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
49 |
+
x = torch.detach(x)
|
50 |
+
x = self.pre(x)
|
51 |
+
if g is not None:
|
52 |
+
g = torch.detach(g)
|
53 |
+
x = x + self.cond(g)
|
54 |
+
x = self.convs(x, x_mask)
|
55 |
+
x = self.proj(x) * x_mask
|
56 |
+
|
57 |
+
if not reverse:
|
58 |
+
flows = self.flows
|
59 |
+
assert w is not None
|
60 |
+
|
61 |
+
logdet_tot_q = 0
|
62 |
+
h_w = self.post_pre(w)
|
63 |
+
h_w = self.post_convs(h_w, x_mask)
|
64 |
+
h_w = self.post_proj(h_w) * x_mask
|
65 |
+
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
66 |
+
z_q = e_q
|
67 |
+
for flow in self.post_flows:
|
68 |
+
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
69 |
+
logdet_tot_q += logdet_q
|
70 |
+
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
71 |
+
u = torch.sigmoid(z_u) * x_mask
|
72 |
+
z0 = (w - u) * x_mask
|
73 |
+
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
|
74 |
+
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
|
75 |
+
|
76 |
+
logdet_tot = 0
|
77 |
+
z0, logdet = self.log_flow(z0, x_mask)
|
78 |
+
logdet_tot += logdet
|
79 |
+
z = torch.cat([z0, z1], 1)
|
80 |
+
for flow in flows:
|
81 |
+
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
82 |
+
logdet_tot = logdet_tot + logdet
|
83 |
+
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
|
84 |
+
return nll + logq # [b]
|
85 |
+
else:
|
86 |
+
flows = list(reversed(self.flows))
|
87 |
+
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
88 |
+
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
89 |
+
for flow in flows:
|
90 |
+
z = flow(z, x_mask, g=x, reverse=reverse)
|
91 |
+
z0, z1 = torch.split(z, [1, 1], 1)
|
92 |
+
logw = z0
|
93 |
+
return logw
|
94 |
+
|
95 |
+
|
96 |
+
class DurationPredictor(nn.Module):
|
97 |
+
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.in_channels = in_channels
|
101 |
+
self.filter_channels = filter_channels
|
102 |
+
self.kernel_size = kernel_size
|
103 |
+
self.p_dropout = p_dropout
|
104 |
+
self.gin_channels = gin_channels
|
105 |
+
|
106 |
+
self.drop = nn.Dropout(p_dropout)
|
107 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
108 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
109 |
+
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
110 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
111 |
+
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
112 |
+
|
113 |
+
if gin_channels != 0:
|
114 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
115 |
+
|
116 |
+
def forward(self, x, x_mask, g=None):
|
117 |
+
x = torch.detach(x)
|
118 |
+
if g is not None:
|
119 |
+
g = torch.detach(g)
|
120 |
+
x = x + self.cond(g)
|
121 |
+
x = self.conv_1(x * x_mask)
|
122 |
+
x = torch.relu(x)
|
123 |
+
x = self.norm_1(x)
|
124 |
+
x = self.drop(x)
|
125 |
+
x = self.conv_2(x * x_mask)
|
126 |
+
x = torch.relu(x)
|
127 |
+
x = self.norm_2(x)
|
128 |
+
x = self.drop(x)
|
129 |
+
x = self.proj(x * x_mask)
|
130 |
+
return x * x_mask
|
131 |
+
|
132 |
+
|
133 |
+
class TextEncoder(nn.Module):
|
134 |
+
def __init__(self,
|
135 |
+
n_vocab,
|
136 |
+
out_channels,
|
137 |
+
hidden_channels,
|
138 |
+
filter_channels,
|
139 |
+
n_heads,
|
140 |
+
n_layers,
|
141 |
+
kernel_size,
|
142 |
+
p_dropout,
|
143 |
+
emotion_embedding):
|
144 |
+
super().__init__()
|
145 |
+
self.n_vocab = n_vocab
|
146 |
+
self.out_channels = out_channels
|
147 |
+
self.hidden_channels = hidden_channels
|
148 |
+
self.filter_channels = filter_channels
|
149 |
+
self.n_heads = n_heads
|
150 |
+
self.n_layers = n_layers
|
151 |
+
self.kernel_size = kernel_size
|
152 |
+
self.p_dropout = p_dropout
|
153 |
+
self.emotion_embedding = emotion_embedding
|
154 |
+
|
155 |
+
if self.n_vocab!=0:
|
156 |
+
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
157 |
+
if emotion_embedding:
|
158 |
+
self.emo_proj = nn.Linear(1024, hidden_channels)
|
159 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
160 |
+
|
161 |
+
self.encoder = attentions.Encoder(
|
162 |
+
hidden_channels,
|
163 |
+
filter_channels,
|
164 |
+
n_heads,
|
165 |
+
n_layers,
|
166 |
+
kernel_size,
|
167 |
+
p_dropout)
|
168 |
+
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
169 |
+
|
170 |
+
def forward(self, x, x_lengths, emotion_embedding=None):
|
171 |
+
if self.n_vocab!=0:
|
172 |
+
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
173 |
+
if emotion_embedding is not None:
|
174 |
+
x = x + self.emo_proj(emotion_embedding.unsqueeze(1))
|
175 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
176 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
177 |
+
|
178 |
+
x = self.encoder(x * x_mask, x_mask)
|
179 |
+
stats = self.proj(x) * x_mask
|
180 |
+
|
181 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
182 |
+
return x, m, logs, x_mask
|
183 |
+
|
184 |
+
|
185 |
+
class ResidualCouplingBlock(nn.Module):
|
186 |
+
def __init__(self,
|
187 |
+
channels,
|
188 |
+
hidden_channels,
|
189 |
+
kernel_size,
|
190 |
+
dilation_rate,
|
191 |
+
n_layers,
|
192 |
+
n_flows=4,
|
193 |
+
gin_channels=0):
|
194 |
+
super().__init__()
|
195 |
+
self.channels = channels
|
196 |
+
self.hidden_channels = hidden_channels
|
197 |
+
self.kernel_size = kernel_size
|
198 |
+
self.dilation_rate = dilation_rate
|
199 |
+
self.n_layers = n_layers
|
200 |
+
self.n_flows = n_flows
|
201 |
+
self.gin_channels = gin_channels
|
202 |
+
|
203 |
+
self.flows = nn.ModuleList()
|
204 |
+
for i in range(n_flows):
|
205 |
+
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
206 |
+
self.flows.append(modules.Flip())
|
207 |
+
|
208 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
209 |
+
if not reverse:
|
210 |
+
for flow in self.flows:
|
211 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
212 |
+
else:
|
213 |
+
for flow in reversed(self.flows):
|
214 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
215 |
+
return x
|
216 |
+
|
217 |
+
|
218 |
+
class PosteriorEncoder(nn.Module):
|
219 |
+
def __init__(self,
|
220 |
+
in_channels,
|
221 |
+
out_channels,
|
222 |
+
hidden_channels,
|
223 |
+
kernel_size,
|
224 |
+
dilation_rate,
|
225 |
+
n_layers,
|
226 |
+
gin_channels=0):
|
227 |
+
super().__init__()
|
228 |
+
self.in_channels = in_channels
|
229 |
+
self.out_channels = out_channels
|
230 |
+
self.hidden_channels = hidden_channels
|
231 |
+
self.kernel_size = kernel_size
|
232 |
+
self.dilation_rate = dilation_rate
|
233 |
+
self.n_layers = n_layers
|
234 |
+
self.gin_channels = gin_channels
|
235 |
+
|
236 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
237 |
+
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
238 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
239 |
+
|
240 |
+
def forward(self, x, x_lengths, g=None):
|
241 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
242 |
+
x = self.pre(x) * x_mask
|
243 |
+
x = self.enc(x, x_mask, g=g)
|
244 |
+
stats = self.proj(x) * x_mask
|
245 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
246 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
247 |
+
return z, m, logs, x_mask
|
248 |
+
|
249 |
+
|
250 |
+
class Generator(torch.nn.Module):
|
251 |
+
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
252 |
+
super(Generator, self).__init__()
|
253 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
254 |
+
self.num_upsamples = len(upsample_rates)
|
255 |
+
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
256 |
+
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
|
257 |
+
|
258 |
+
self.ups = nn.ModuleList()
|
259 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
260 |
+
self.ups.append(weight_norm(
|
261 |
+
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
262 |
+
k, u, padding=(k-u)//2)))
|
263 |
+
|
264 |
+
self.resblocks = nn.ModuleList()
|
265 |
+
for i in range(len(self.ups)):
|
266 |
+
ch = upsample_initial_channel//(2**(i+1))
|
267 |
+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
268 |
+
self.resblocks.append(resblock(ch, k, d))
|
269 |
+
|
270 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
271 |
+
self.ups.apply(init_weights)
|
272 |
+
|
273 |
+
if gin_channels != 0:
|
274 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
275 |
+
|
276 |
+
def forward(self, x, g=None):
|
277 |
+
x = self.conv_pre(x)
|
278 |
+
if g is not None:
|
279 |
+
x = x + self.cond(g)
|
280 |
+
|
281 |
+
for i in range(self.num_upsamples):
|
282 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
283 |
+
x = self.ups[i](x)
|
284 |
+
xs = None
|
285 |
+
for j in range(self.num_kernels):
|
286 |
+
if xs is None:
|
287 |
+
xs = self.resblocks[i*self.num_kernels+j](x)
|
288 |
+
else:
|
289 |
+
xs += self.resblocks[i*self.num_kernels+j](x)
|
290 |
+
x = xs / self.num_kernels
|
291 |
+
x = F.leaky_relu(x)
|
292 |
+
x = self.conv_post(x)
|
293 |
+
x = torch.tanh(x)
|
294 |
+
|
295 |
+
return x
|
296 |
+
|
297 |
+
|
298 |
+
class SynthesizerTrn(nn.Module):
|
299 |
+
"""
|
300 |
+
Synthesizer for Training
|
301 |
+
"""
|
302 |
+
|
303 |
+
def __init__(self,
|
304 |
+
n_vocab,
|
305 |
+
spec_channels,
|
306 |
+
segment_size,
|
307 |
+
inter_channels,
|
308 |
+
hidden_channels,
|
309 |
+
filter_channels,
|
310 |
+
n_heads,
|
311 |
+
n_layers,
|
312 |
+
kernel_size,
|
313 |
+
p_dropout,
|
314 |
+
resblock,
|
315 |
+
resblock_kernel_sizes,
|
316 |
+
resblock_dilation_sizes,
|
317 |
+
upsample_rates,
|
318 |
+
upsample_initial_channel,
|
319 |
+
upsample_kernel_sizes,
|
320 |
+
n_speakers=0,
|
321 |
+
gin_channels=0,
|
322 |
+
use_sdp=True,
|
323 |
+
emotion_embedding=False,
|
324 |
+
**kwargs):
|
325 |
+
|
326 |
+
super().__init__()
|
327 |
+
self.n_vocab = n_vocab
|
328 |
+
self.spec_channels = spec_channels
|
329 |
+
self.inter_channels = inter_channels
|
330 |
+
self.hidden_channels = hidden_channels
|
331 |
+
self.filter_channels = filter_channels
|
332 |
+
self.n_heads = n_heads
|
333 |
+
self.n_layers = n_layers
|
334 |
+
self.kernel_size = kernel_size
|
335 |
+
self.p_dropout = p_dropout
|
336 |
+
self.resblock = resblock
|
337 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
338 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
339 |
+
self.upsample_rates = upsample_rates
|
340 |
+
self.upsample_initial_channel = upsample_initial_channel
|
341 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
342 |
+
self.segment_size = segment_size
|
343 |
+
self.n_speakers = n_speakers
|
344 |
+
self.gin_channels = gin_channels
|
345 |
+
|
346 |
+
self.use_sdp = use_sdp
|
347 |
+
|
348 |
+
self.enc_p = TextEncoder(n_vocab,
|
349 |
+
inter_channels,
|
350 |
+
hidden_channels,
|
351 |
+
filter_channels,
|
352 |
+
n_heads,
|
353 |
+
n_layers,
|
354 |
+
kernel_size,
|
355 |
+
p_dropout,
|
356 |
+
emotion_embedding)
|
357 |
+
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
358 |
+
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
|
359 |
+
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
360 |
+
|
361 |
+
if use_sdp:
|
362 |
+
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
|
363 |
+
else:
|
364 |
+
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
|
365 |
+
|
366 |
+
if n_speakers > 1:
|
367 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
368 |
+
|
369 |
+
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None):
|
370 |
+
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
|
371 |
+
if self.n_speakers > 0:
|
372 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
373 |
+
else:
|
374 |
+
g = None
|
375 |
+
|
376 |
+
if self.use_sdp:
|
377 |
+
logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
|
378 |
+
else:
|
379 |
+
logw = self.dp(x, x_mask, g=g)
|
380 |
+
w = torch.exp(logw) * x_mask * length_scale
|
381 |
+
w_ceil = torch.ceil(w)
|
382 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
383 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
|
384 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
385 |
+
attn = commons.generate_path(w_ceil, attn_mask)
|
386 |
+
|
387 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
388 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
389 |
+
|
390 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
391 |
+
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
392 |
+
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
|
393 |
+
return o, attn, y_mask, (z, z_p, m_p, logs_p)
|
394 |
+
|
395 |
+
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
|
396 |
+
assert self.n_speakers > 0, "n_speakers have to be larger than 0."
|
397 |
+
g_src = self.emb_g(sid_src).unsqueeze(-1)
|
398 |
+
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
|
399 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
|
400 |
+
z_p = self.flow(z, y_mask, g=g_src)
|
401 |
+
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
402 |
+
o_hat = self.dec(z_hat * y_mask, g=g_tgt)
|
403 |
+
return o_hat, y_mask, (z, z_p, z_hat)
|
404 |
+
|
modules/__init__.py
ADDED
File without changes
|
modules/controlnet_canny.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Canny:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Canny")
|
6 |
+
self.low_threshold = 100
|
7 |
+
self.high_threshold = 200
|
8 |
+
|
9 |
+
@prompts(name="Edge Detection On Image",
|
10 |
+
description="useful when you want to detect the edge of the image. "
|
11 |
+
"like: detect the edges of this image, or canny detection on image, "
|
12 |
+
"or perform edge detection on this image, or detect the canny image of this image. "
|
13 |
+
"The input to this tool should be a string, representing the image_path")
|
14 |
+
def inference(self, inputs):
|
15 |
+
image = Image.open(inputs)
|
16 |
+
image = np.array(image)
|
17 |
+
canny = cv2.Canny(image, self.low_threshold, self.high_threshold)
|
18 |
+
canny = canny[:, :, None]
|
19 |
+
canny = np.concatenate([canny, canny, canny], axis=2)
|
20 |
+
canny = Image.fromarray(canny)
|
21 |
+
updated_image_path = get_new_image_name(inputs, func_name="edge")
|
22 |
+
canny.save(updated_image_path)
|
23 |
+
print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text: {updated_image_path}")
|
24 |
+
return updated_image_path
|
25 |
+
|
26 |
+
class CannyText2Image:
|
27 |
+
def __init__(self, device, pretrained_model_dir):
|
28 |
+
print("Initializing CannyText2Image to %s" % device)
|
29 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
30 |
+
self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-canny",
|
31 |
+
torch_dtype=self.torch_dtype)
|
32 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
33 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
34 |
+
torch_dtype=self.torch_dtype)
|
35 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
36 |
+
self.pipe.to(device)
|
37 |
+
self.seed = -1
|
38 |
+
self.a_prompt = 'best quality, extremely detailed'
|
39 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
40 |
+
'fewer digits, cropped, worst quality, low quality'
|
41 |
+
|
42 |
+
@prompts(name="Generate Image Condition On Canny Image",
|
43 |
+
description="useful when you want to generate a new real image from both the user desciption and a canny image."
|
44 |
+
" like: generate a real image of a object or something from this canny image,"
|
45 |
+
" or generate a new real image of a object or something from this edge image. "
|
46 |
+
"The input to this tool should be a comma seperated string of two, "
|
47 |
+
"representing the image_path and the user description. ")
|
48 |
+
def inference(self, inputs):
|
49 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
50 |
+
image = Image.open(image_path)
|
51 |
+
self.seed = random.randint(0, 65535)
|
52 |
+
seed_everything(self.seed)
|
53 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
54 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
55 |
+
guidance_scale=9.0).images[0]
|
56 |
+
updated_image_path = get_new_image_name(image_path, func_name="canny2image")
|
57 |
+
image.save(updated_image_path)
|
58 |
+
print(f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text: {instruct_text}, "
|
59 |
+
f"Output Text: {updated_image_path}")
|
60 |
+
return updated_image_path
|
modules/controlnet_depth.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Depth:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Depth")
|
6 |
+
self.depth_estimator = pipeline('depth-estimation')
|
7 |
+
|
8 |
+
@prompts(name="Predict Depth On Image",
|
9 |
+
description="useful when you want to detect depth of the image. like: generate the depth from this image, "
|
10 |
+
"or detect the depth map on this image, or predict the depth for this image. "
|
11 |
+
"The input to this tool should be a string, representing the image_path")
|
12 |
+
def inference(self, inputs):
|
13 |
+
image = Image.open(inputs)
|
14 |
+
depth = self.depth_estimator(image)['depth']
|
15 |
+
depth = np.array(depth)
|
16 |
+
depth = depth[:, :, None]
|
17 |
+
depth = np.concatenate([depth, depth, depth], axis=2)
|
18 |
+
depth = Image.fromarray(depth)
|
19 |
+
updated_image_path = get_new_image_name(inputs, func_name="depth")
|
20 |
+
depth.save(updated_image_path)
|
21 |
+
print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth: {updated_image_path}")
|
22 |
+
return updated_image_path
|
23 |
+
|
24 |
+
|
25 |
+
class DepthText2Image:
|
26 |
+
def __init__(self, device, pretrained_model_dir):
|
27 |
+
print("Initializing DepthText2Image to %s" % device)
|
28 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
29 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
30 |
+
f"{pretrained_model_dir}/sd-controlnet-depth", torch_dtype=self.torch_dtype)
|
31 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
32 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
33 |
+
torch_dtype=self.torch_dtype)
|
34 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
35 |
+
self.pipe.to(device)
|
36 |
+
self.seed = -1
|
37 |
+
self.a_prompt = 'best quality, extremely detailed'
|
38 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
39 |
+
' fewer digits, cropped, worst quality, low quality'
|
40 |
+
|
41 |
+
@prompts(name="Generate Image Condition On Depth",
|
42 |
+
description="useful when you want to generate a new real image from both the user desciption and depth image. "
|
43 |
+
"like: generate a real image of a object or something from this depth image, "
|
44 |
+
"or generate a new real image of a object or something from the depth map. "
|
45 |
+
"The input to this tool should be a comma seperated string of two, "
|
46 |
+
"representing the image_path and the user description")
|
47 |
+
def inference(self, inputs):
|
48 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
49 |
+
image = Image.open(image_path)
|
50 |
+
self.seed = random.randint(0, 65535)
|
51 |
+
seed_everything(self.seed)
|
52 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
53 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
54 |
+
guidance_scale=9.0).images[0]
|
55 |
+
updated_image_path = get_new_image_name(image_path, func_name="depth2image")
|
56 |
+
image.save(updated_image_path)
|
57 |
+
print(f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text: {instruct_text}, "
|
58 |
+
f"Output Image: {updated_image_path}")
|
59 |
+
return updated_image_path
|
modules/controlnet_hed.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Hed:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Hed")
|
6 |
+
self.detector = HEDdetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
|
7 |
+
|
8 |
+
@prompts(name="Hed Detection On Image",
|
9 |
+
description="useful when you want to detect the soft hed boundary of the image. "
|
10 |
+
"like: detect the soft hed boundary of this image, or hed boundary detection on image, "
|
11 |
+
"or peform hed boundary detection on this image, or detect soft hed boundary image of this image. "
|
12 |
+
"The input to this tool should be a string, representing the image_path")
|
13 |
+
def inference(self, inputs):
|
14 |
+
image = Image.open(inputs)
|
15 |
+
hed = self.detector(image)
|
16 |
+
updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
|
17 |
+
hed.save(updated_image_path)
|
18 |
+
print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {updated_image_path}")
|
19 |
+
return updated_image_path
|
20 |
+
|
21 |
+
|
22 |
+
class HedText2Image:
|
23 |
+
def __init__(self, device, pretrained_model_dir):
|
24 |
+
print("Initializing HedText2Image to %s" % device)
|
25 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
26 |
+
self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-hed",
|
27 |
+
torch_dtype=self.torch_dtype)
|
28 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
29 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
30 |
+
torch_dtype=self.torch_dtype
|
31 |
+
)
|
32 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
33 |
+
self.pipe.to(device)
|
34 |
+
self.seed = -1
|
35 |
+
self.a_prompt = 'best quality, extremely detailed'
|
36 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
37 |
+
'fewer digits, cropped, worst quality, low quality'
|
38 |
+
|
39 |
+
@prompts(name="Generate Image Condition On Soft Hed Boundary Image",
|
40 |
+
description="useful when you want to generate a new real image from both the user desciption "
|
41 |
+
"and a soft hed boundary image. "
|
42 |
+
"like: generate a real image of a object or something from this soft hed boundary image, "
|
43 |
+
"or generate a new real image of a object or something from this hed boundary. "
|
44 |
+
"The input to this tool should be a comma seperated string of two, "
|
45 |
+
"representing the image_path and the user description")
|
46 |
+
def inference(self, inputs):
|
47 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
48 |
+
image = Image.open(image_path)
|
49 |
+
self.seed = random.randint(0, 65535)
|
50 |
+
seed_everything(self.seed)
|
51 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
52 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
53 |
+
guidance_scale=9.0).images[0]
|
54 |
+
updated_image_path = get_new_image_name(image_path, func_name="hed2image")
|
55 |
+
image.save(updated_image_path)
|
56 |
+
print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text: {instruct_text}, "
|
57 |
+
f"Output Image: {updated_image_path}")
|
58 |
+
return updated_image_path
|
modules/controlnet_line.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Line:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Line")
|
6 |
+
self.detector = MLSDdetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
|
7 |
+
|
8 |
+
@prompts(name="Line Detection On Image",
|
9 |
+
description="useful when you want to detect the straight line of the image. "
|
10 |
+
"like: detect the straight lines of this image, or straight line detection on image, "
|
11 |
+
"or peform straight line detection on this image, or detect the straight line image of this image. "
|
12 |
+
"The input to this tool should be a string, representing the image_path")
|
13 |
+
def inference(self, inputs):
|
14 |
+
image = Image.open(inputs)
|
15 |
+
mlsd = self.detector(image)
|
16 |
+
updated_image_path = get_new_image_name(inputs, func_name="line-of")
|
17 |
+
mlsd.save(updated_image_path)
|
18 |
+
print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line: {updated_image_path}")
|
19 |
+
return updated_image_path
|
20 |
+
|
21 |
+
|
22 |
+
class LineText2Image:
|
23 |
+
def __init__(self, device, pretrained_model_dir):
|
24 |
+
print("Initializing LineText2Image to %s" % device)
|
25 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
26 |
+
self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-mlsd",
|
27 |
+
torch_dtype=self.torch_dtype)
|
28 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
29 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
30 |
+
torch_dtype=self.torch_dtype
|
31 |
+
)
|
32 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
33 |
+
self.pipe.to(device)
|
34 |
+
self.seed = -1
|
35 |
+
self.a_prompt = 'best quality, extremely detailed'
|
36 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
37 |
+
'fewer digits, cropped, worst quality, low quality'
|
38 |
+
|
39 |
+
@prompts(name="Generate Image Condition On Line Image",
|
40 |
+
description="useful when you want to generate a new real image from both the user desciption "
|
41 |
+
"and a straight line image. "
|
42 |
+
"like: generate a real image of a object or something from this straight line image, "
|
43 |
+
"or generate a new real image of a object or something from this straight lines. "
|
44 |
+
"The input to this tool should be a comma seperated string of two, "
|
45 |
+
"representing the image_path and the user description. ")
|
46 |
+
def inference(self, inputs):
|
47 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
48 |
+
image = Image.open(image_path)
|
49 |
+
self.seed = random.randint(0, 65535)
|
50 |
+
seed_everything(self.seed)
|
51 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
52 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
53 |
+
guidance_scale=9.0).images[0]
|
54 |
+
updated_image_path = get_new_image_name(image_path, func_name="line2image")
|
55 |
+
image.save(updated_image_path)
|
56 |
+
print(f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text: {instruct_text}, "
|
57 |
+
f"Output Text: {updated_image_path}")
|
58 |
+
return updated_image_path
|
modules/controlnet_normal.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Normal:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Normal")
|
6 |
+
self.depth_estimator = pipeline("depth-estimation", model=f"{pretrained_model_dir}/dpt-hybrid-midas")
|
7 |
+
self.bg_threhold = 0.4
|
8 |
+
|
9 |
+
@prompts(name="Predict Normal Map On Image",
|
10 |
+
description="useful when you want to detect norm map of the image. "
|
11 |
+
"like: generate normal map from this image, or predict normal map of this image. "
|
12 |
+
"The input to this tool should be a string, representing the image_path")
|
13 |
+
def inference(self, inputs):
|
14 |
+
image = Image.open(inputs)
|
15 |
+
original_size = image.size
|
16 |
+
image = self.depth_estimator(image)['predicted_depth'][0]
|
17 |
+
image = image.numpy()
|
18 |
+
image_depth = image.copy()
|
19 |
+
image_depth -= np.min(image_depth)
|
20 |
+
image_depth /= np.max(image_depth)
|
21 |
+
x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
|
22 |
+
x[image_depth < self.bg_threhold] = 0
|
23 |
+
y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
|
24 |
+
y[image_depth < self.bg_threhold] = 0
|
25 |
+
z = np.ones_like(x) * np.pi * 2.0
|
26 |
+
image = np.stack([x, y, z], axis=2)
|
27 |
+
image /= np.sum(image ** 2.0, axis=2, keepdims=True) ** 0.5
|
28 |
+
image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
29 |
+
image = Image.fromarray(image)
|
30 |
+
image = image.resize(original_size)
|
31 |
+
updated_image_path = get_new_image_name(inputs, func_name="normal-map")
|
32 |
+
image.save(updated_image_path)
|
33 |
+
print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth: {updated_image_path}")
|
34 |
+
return updated_image_path
|
35 |
+
|
36 |
+
|
37 |
+
class NormalText2Image:
|
38 |
+
def __init__(self, device, pretrained_model_dir):
|
39 |
+
print("Initializing NormalText2Image to %s" % device)
|
40 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
41 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
42 |
+
f"{pretrained_model_dir}/sd-controlnet-normal", torch_dtype=self.torch_dtype)
|
43 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
44 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
45 |
+
torch_dtype=self.torch_dtype)
|
46 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
47 |
+
self.pipe.to(device)
|
48 |
+
self.seed = -1
|
49 |
+
self.a_prompt = 'best quality, extremely detailed'
|
50 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
51 |
+
' fewer digits, cropped, worst quality, low quality'
|
52 |
+
|
53 |
+
@prompts(name="Generate Image Condition On Normal Map",
|
54 |
+
description="useful when you want to generate a new real image from both the user desciption and normal map. "
|
55 |
+
"like: generate a real image of a object or something from this normal map, "
|
56 |
+
"or generate a new real image of a object or something from the normal map. "
|
57 |
+
"The input to this tool should be a comma seperated string of two, "
|
58 |
+
"representing the image_path and the user description")
|
59 |
+
def inference(self, inputs):
|
60 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
61 |
+
image = Image.open(image_path)
|
62 |
+
self.seed = random.randint(0, 65535)
|
63 |
+
seed_everything(self.seed)
|
64 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
65 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
66 |
+
guidance_scale=9.0).images[0]
|
67 |
+
updated_image_path = get_new_image_name(image_path, func_name="normal2image")
|
68 |
+
image.save(updated_image_path)
|
69 |
+
print(f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text: {instruct_text}, "
|
70 |
+
f"Output Image: {updated_image_path}")
|
71 |
+
return updated_image_path
|
modules/controlnet_pose.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Pose:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Pose")
|
6 |
+
self.detector = OpenposeDetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
|
7 |
+
|
8 |
+
@prompts(name="Pose Detection On Image",
|
9 |
+
description="useful when you want to detect the human pose of the image. "
|
10 |
+
"like: generate human poses of this image, or generate a pose image from this image. "
|
11 |
+
"The input to this tool should be a string, representing the image_path")
|
12 |
+
def inference(self, inputs):
|
13 |
+
image = Image.open(inputs)
|
14 |
+
pose = self.detector(image)
|
15 |
+
updated_image_path = get_new_image_name(inputs, func_name="human-pose")
|
16 |
+
pose.save(updated_image_path)
|
17 |
+
print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
|
18 |
+
return updated_image_path
|
19 |
+
|
20 |
+
|
21 |
+
class PoseText2Image:
|
22 |
+
def __init__(self, device, pretrained_model_dir):
|
23 |
+
print("Initializing PoseText2Image to %s" % device)
|
24 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
25 |
+
self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-openpose",
|
26 |
+
torch_dtype=self.torch_dtype)
|
27 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
28 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
29 |
+
torch_dtype=self.torch_dtype)
|
30 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
31 |
+
self.pipe.to(device)
|
32 |
+
self.num_inference_steps = 20
|
33 |
+
self.seed = -1
|
34 |
+
self.unconditional_guidance_scale = 9.0
|
35 |
+
self.a_prompt = 'best quality, extremely detailed'
|
36 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
37 |
+
' fewer digits, cropped, worst quality, low quality'
|
38 |
+
|
39 |
+
@prompts(name="Generate Image Condition On Pose Image",
|
40 |
+
description="useful when you want to generate a new real image from both the user desciption "
|
41 |
+
"and a human pose image. "
|
42 |
+
"like: generate a real image of a human from this human pose image, "
|
43 |
+
"or generate a new real image of a human from this pose. "
|
44 |
+
"The input to this tool should be a comma seperated string of two, "
|
45 |
+
"representing the image_path and the user description")
|
46 |
+
def inference(self, inputs):
|
47 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
48 |
+
image = Image.open(image_path)
|
49 |
+
self.seed = random.randint(0, 65535)
|
50 |
+
seed_everything(self.seed)
|
51 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
52 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
53 |
+
guidance_scale=9.0).images[0]
|
54 |
+
updated_image_path = get_new_image_name(image_path, func_name="pose2image")
|
55 |
+
image.save(updated_image_path)
|
56 |
+
print(f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text: {instruct_text}, "
|
57 |
+
f"Output Image: {updated_image_path}")
|
58 |
+
return updated_image_path
|
modules/controlnet_scibble.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Scribble:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Scribble")
|
6 |
+
self.detector = HEDdetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
|
7 |
+
|
8 |
+
@prompts(name="Sketch Detection On Image",
|
9 |
+
description="useful when you want to generate a scribble of the image. "
|
10 |
+
"like: generate a scribble of this image, or generate a sketch from this image, "
|
11 |
+
"detect the sketch from this image. "
|
12 |
+
"The input to this tool should be a string, representing the image_path")
|
13 |
+
def inference(self, inputs):
|
14 |
+
image = Image.open(inputs)
|
15 |
+
scribble = self.detector(image, scribble=True)
|
16 |
+
updated_image_path = get_new_image_name(inputs, func_name="scribble")
|
17 |
+
scribble.save(updated_image_path)
|
18 |
+
print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
|
19 |
+
return updated_image_path
|
20 |
+
|
21 |
+
|
22 |
+
class ScribbleText2Image:
|
23 |
+
def __init__(self, device, pretrained_model_dir):
|
24 |
+
print("Initializing ScribbleText2Image to %s" % device)
|
25 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
26 |
+
self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-scribble",
|
27 |
+
torch_dtype=self.torch_dtype)
|
28 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
29 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
30 |
+
torch_dtype=self.torch_dtype
|
31 |
+
)
|
32 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
33 |
+
self.pipe.to(device)
|
34 |
+
self.seed = -1
|
35 |
+
self.a_prompt = 'best quality, extremely detailed'
|
36 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
37 |
+
'fewer digits, cropped, worst quality, low quality'
|
38 |
+
|
39 |
+
@prompts(name="Generate Image Condition On Sketch Image",
|
40 |
+
description="useful when you want to generate a new real image from both the user desciption and "
|
41 |
+
"a scribble image or a sketch image. "
|
42 |
+
"The input to this tool should be a comma seperated string of two, "
|
43 |
+
"representing the image_path and the user description")
|
44 |
+
def inference(self, inputs):
|
45 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
46 |
+
image = Image.open(image_path)
|
47 |
+
self.seed = random.randint(0, 65535)
|
48 |
+
seed_everything(self.seed)
|
49 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
50 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
51 |
+
guidance_scale=9.0).images[0]
|
52 |
+
updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
|
53 |
+
image.save(updated_image_path)
|
54 |
+
print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
|
55 |
+
f"Output Image: {updated_image_path}")
|
56 |
+
return updated_image_path
|
modules/controlnet_seg.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Image2Seg:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Image2Seg")
|
6 |
+
self.image_processor = AutoImageProcessor.from_pretrained(f"{pretrained_model_dir}/upernet-convnext-small")
|
7 |
+
self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained(f"{pretrained_model_dir}/upernet-convnext-small")
|
8 |
+
self.ade_palette = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
9 |
+
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
10 |
+
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
11 |
+
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
12 |
+
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
13 |
+
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
14 |
+
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
15 |
+
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
16 |
+
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
17 |
+
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
18 |
+
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
19 |
+
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
20 |
+
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
21 |
+
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
22 |
+
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
23 |
+
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
24 |
+
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
25 |
+
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
26 |
+
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
27 |
+
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
28 |
+
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
29 |
+
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
30 |
+
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
31 |
+
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
32 |
+
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
33 |
+
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
34 |
+
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
35 |
+
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
36 |
+
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
37 |
+
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
38 |
+
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
39 |
+
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
40 |
+
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
41 |
+
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
42 |
+
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
43 |
+
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
44 |
+
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
45 |
+
[102, 255, 0], [92, 0, 255]]
|
46 |
+
|
47 |
+
@prompts(name="Segmentation On Image",
|
48 |
+
description="useful when you want to detect segmentations of the image. "
|
49 |
+
"like: segment this image, or generate segmentations on this image, "
|
50 |
+
"or peform segmentation on this image. "
|
51 |
+
"The input to this tool should be a string, representing the image_path")
|
52 |
+
def inference(self, inputs):
|
53 |
+
image = Image.open(inputs)
|
54 |
+
pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
|
55 |
+
with torch.no_grad():
|
56 |
+
outputs = self.image_segmentor(pixel_values)
|
57 |
+
seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
58 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
|
59 |
+
palette = np.array(self.ade_palette)
|
60 |
+
for label, color in enumerate(palette):
|
61 |
+
color_seg[seg == label, :] = color
|
62 |
+
color_seg = color_seg.astype(np.uint8)
|
63 |
+
segmentation = Image.fromarray(color_seg)
|
64 |
+
updated_image_path = get_new_image_name(inputs, func_name="segmentation")
|
65 |
+
segmentation.save(updated_image_path)
|
66 |
+
print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
|
67 |
+
return updated_image_path
|
68 |
+
|
69 |
+
|
70 |
+
class SegText2Image:
|
71 |
+
def __init__(self, device, pretrained_model_dir):
|
72 |
+
print("Initializing SegText2Image to %s" % device)
|
73 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
74 |
+
self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-seg",
|
75 |
+
torch_dtype=self.torch_dtype)
|
76 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
77 |
+
f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
78 |
+
torch_dtype=self.torch_dtype)
|
79 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
80 |
+
self.pipe.to(device)
|
81 |
+
self.seed = -1
|
82 |
+
self.a_prompt = 'best quality, extremely detailed'
|
83 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
84 |
+
' fewer digits, cropped, worst quality, low quality'
|
85 |
+
|
86 |
+
@prompts(name="Generate Image Condition On Segmentations",
|
87 |
+
description="useful when you want to generate a new real image from both the user desciption and segmentations. "
|
88 |
+
"like: generate a real image of a object or something from this segmentation image, "
|
89 |
+
"or generate a new real image of a object or something from these segmentations. "
|
90 |
+
"The input to this tool should be a comma seperated string of two, "
|
91 |
+
"representing the image_path and the user description")
|
92 |
+
def inference(self, inputs):
|
93 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
94 |
+
image = Image.open(image_path)
|
95 |
+
self.seed = random.randint(0, 65535)
|
96 |
+
seed_everything(self.seed)
|
97 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
98 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
99 |
+
guidance_scale=9.0).images[0]
|
100 |
+
updated_image_path = get_new_image_name(image_path, func_name="segment2image")
|
101 |
+
image.save(updated_image_path)
|
102 |
+
print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text: {instruct_text}, "
|
103 |
+
f"Output Image: {updated_image_path}")
|
104 |
+
return updated_image_path
|
modules/image_captioning.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from modules.utils import *
|
3 |
+
|
4 |
+
class ImageCaptioning:
|
5 |
+
def __init__(self, device, pretrained_model_dir):
|
6 |
+
print("Initializing ImageCaptioning to %s" % device)
|
7 |
+
self.device = device
|
8 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
9 |
+
self.processor = BlipProcessor.from_pretrained(f"{pretrained_model_dir}/blip-image-captioning-base")
|
10 |
+
self.model = BlipForConditionalGeneration.from_pretrained(
|
11 |
+
f"{pretrained_model_dir}/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
|
12 |
+
|
13 |
+
@prompts(name="Get Photo Description",
|
14 |
+
description="useful when you want to know what is inside the photo. receives image_path as input. "
|
15 |
+
"The input to this tool should be a string, representing the image_path. ")
|
16 |
+
def inference(self, image_path):
|
17 |
+
inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device, self.torch_dtype)
|
18 |
+
out = self.model.generate(**inputs)
|
19 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True)
|
20 |
+
print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
|
21 |
+
return captions
|
modules/image_editing.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class ImageEditing:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing ImageEditing to %s" % device)
|
6 |
+
self.device = device
|
7 |
+
self.mask_former = MaskFormer(device=self.device, pretrained_model_dir=pretrained_model_dir)
|
8 |
+
self.revision = 'fp16' if 'cuda' in device else None
|
9 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
10 |
+
self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
11 |
+
f"{pretrained_model_dir}/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
|
12 |
+
|
13 |
+
@prompts(name="Remove Something From The Photo",
|
14 |
+
description="useful when you want to remove and object or something from the photo "
|
15 |
+
"from its description or location. "
|
16 |
+
"The input to this tool should be a comma seperated string of two, "
|
17 |
+
"representing the image_path and the object need to be removed. ")
|
18 |
+
def inference_remove(self, inputs):
|
19 |
+
image_path, to_be_removed_txt = inputs.split(",")
|
20 |
+
return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
|
21 |
+
|
22 |
+
@prompts(name="Replace Something From The Photo",
|
23 |
+
description="useful when you want to replace an object from the object description or "
|
24 |
+
"location with another object from its description. "
|
25 |
+
"The input to this tool should be a comma seperated string of three, "
|
26 |
+
"representing the image_path, the object to be replaced, the object to be replaced with ")
|
27 |
+
def inference_replace(self, inputs):
|
28 |
+
image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
|
29 |
+
original_image = Image.open(image_path)
|
30 |
+
original_size = original_image.size
|
31 |
+
mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
|
32 |
+
updated_image = self.inpaint(prompt=replace_with_txt, image=original_image.resize((512, 512)),
|
33 |
+
mask_image=mask_image.resize((512, 512))).images[0]
|
34 |
+
updated_image_path = get_new_image_name(image_path, func_name="replace-something")
|
35 |
+
updated_image = updated_image.resize(original_size)
|
36 |
+
updated_image.save(updated_image_path)
|
37 |
+
print(
|
38 |
+
f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
|
39 |
+
f"Output Image: {updated_image_path}")
|
40 |
+
return updated_image_path
|
modules/instruct_px2pix.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class InstructPix2Pix:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing InstructPix2Pix to %s" % device)
|
6 |
+
self.device = device
|
7 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
8 |
+
self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(f"{pretrained_model_dir}/instruct-pix2pix",
|
9 |
+
safety_checker=None,
|
10 |
+
torch_dtype=self.torch_dtype).to(device)
|
11 |
+
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
|
12 |
+
|
13 |
+
@prompts(name="Instruct Image Using Text",
|
14 |
+
description="useful when you want to the style of the image to be like the text. "
|
15 |
+
"like: make it look like a painting. or make it like a robot. "
|
16 |
+
"The input to this tool should be a comma seperated string of two, "
|
17 |
+
"representing the image_path and the text. ")
|
18 |
+
def inference(self, inputs):
|
19 |
+
"""Change style of image."""
|
20 |
+
print("===>Starting InstructPix2Pix Inference")
|
21 |
+
image_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
22 |
+
original_image = Image.open(image_path)
|
23 |
+
image = self.pipe(text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2).images[0]
|
24 |
+
updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
|
25 |
+
image.save(updated_image_path)
|
26 |
+
print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
|
27 |
+
f"Output Image: {updated_image_path}")
|
28 |
+
return updated_image_path
|
modules/mask_former.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class MaskFormer:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing MaskFormer to %s" % device)
|
6 |
+
self.device = device
|
7 |
+
self.processor = CLIPSegProcessor.from_pretrained(f"{pretrained_model_dir}/clipseg-rd64-refined")
|
8 |
+
self.model = CLIPSegForImageSegmentation.from_pretrained(f"{pretrained_model_dir}/clipseg-rd64-refined").to(device)
|
9 |
+
|
10 |
+
def inference(self, image_path, text):
|
11 |
+
threshold = 0.5
|
12 |
+
min_area = 0.02
|
13 |
+
padding = 20
|
14 |
+
original_image = Image.open(image_path)
|
15 |
+
image = original_image.resize((512, 512))
|
16 |
+
inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
|
17 |
+
with torch.no_grad():
|
18 |
+
outputs = self.model(**inputs)
|
19 |
+
mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
|
20 |
+
area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
|
21 |
+
if area_ratio < min_area:
|
22 |
+
return None
|
23 |
+
true_indices = np.argwhere(mask)
|
24 |
+
mask_array = np.zeros_like(mask, dtype=bool)
|
25 |
+
for idx in true_indices:
|
26 |
+
padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
|
27 |
+
mask_array[padded_slice] = True
|
28 |
+
visual_mask = (mask_array * 255).astype(np.uint8)
|
29 |
+
image_mask = Image.fromarray(visual_mask)
|
30 |
+
return image_mask.resize(original_image.size)
|
modules/text2img.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class Text2Image:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing Text2Image to %s" % device)
|
6 |
+
self.device = device
|
7 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
8 |
+
self.pipe = StableDiffusionPipeline.from_pretrained(f"{pretrained_model_dir}/stable-diffusion-v1-5",
|
9 |
+
torch_dtype=self.torch_dtype)
|
10 |
+
self.pipe.to(device)
|
11 |
+
self.a_prompt = 'best quality, extremely detailed'
|
12 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
13 |
+
'fewer digits, cropped, worst quality, low quality'
|
14 |
+
|
15 |
+
@prompts(name="Generate Image From User Input Text",
|
16 |
+
description="useful when you want to generate an image from a user input text and save it to a file. "
|
17 |
+
"like: generate an image of an object or something, or generate an image that includes some objects. "
|
18 |
+
"The input to this tool should be a string, representing the text used to generate image. ")
|
19 |
+
def inference(self, text):
|
20 |
+
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
|
21 |
+
prompt = text + ', ' + self.a_prompt
|
22 |
+
image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
|
23 |
+
image.save(image_filename)
|
24 |
+
print(
|
25 |
+
f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
|
26 |
+
return image_filename
|
modules/utils.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import gradio as gr
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
import re
|
8 |
+
import uuid
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
import argparse
|
12 |
+
|
13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
|
14 |
+
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
|
15 |
+
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
|
16 |
+
|
17 |
+
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline
|
18 |
+
from diffusers import EulerAncestralDiscreteScheduler
|
19 |
+
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
|
20 |
+
from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
|
21 |
+
|
22 |
+
from langchain.agents.initialize import initialize_agent
|
23 |
+
from langchain.agents.tools import Tool
|
24 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
25 |
+
from langchain.llms.openai import OpenAI
|
26 |
+
|
27 |
+
# 装饰器
|
28 |
+
def prompts(name, description):
|
29 |
+
def decorator(func):
|
30 |
+
func.name = name
|
31 |
+
func.description = description
|
32 |
+
return func
|
33 |
+
|
34 |
+
return decorator
|
35 |
+
|
36 |
+
# 设置种子
|
37 |
+
def seed_everything(seed):
|
38 |
+
random.seed(seed)
|
39 |
+
np.random.seed(seed)
|
40 |
+
torch.manual_seed(seed)
|
41 |
+
torch.cuda.manual_seed_all(seed)
|
42 |
+
return seed
|
43 |
+
|
44 |
+
# 对话历史截断
|
45 |
+
def cut_dialogue_history(history_memory, keep_last_n_words=500):
|
46 |
+
tokens = history_memory.split()
|
47 |
+
n_tokens = len(tokens)
|
48 |
+
print(f"hitory_memory:{history_memory}, n_tokens: {n_tokens}")
|
49 |
+
if n_tokens < keep_last_n_words:
|
50 |
+
return history_memory
|
51 |
+
else:
|
52 |
+
paragraphs = history_memory.split('\n')
|
53 |
+
last_n_tokens = n_tokens
|
54 |
+
while last_n_tokens >= keep_last_n_words:
|
55 |
+
last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
|
56 |
+
paragraphs = paragraphs[1:]
|
57 |
+
return '\n' + '\n'.join(paragraphs)
|
58 |
+
|
59 |
+
# 获取新图片
|
60 |
+
def get_new_image_name(org_img_name, func_name="update"):
|
61 |
+
head_tail = os.path.split(org_img_name)
|
62 |
+
head = head_tail[0]
|
63 |
+
tail = head_tail[1]
|
64 |
+
name_split = tail.split('.')[0].split('_')
|
65 |
+
this_new_uuid = str(uuid.uuid4())[0:4]
|
66 |
+
if len(name_split) == 1:
|
67 |
+
most_org_file_name = name_split[0]
|
68 |
+
recent_prev_file_name = name_split[0]
|
69 |
+
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
|
70 |
+
else:
|
71 |
+
assert len(name_split) == 4
|
72 |
+
most_org_file_name = name_split[3]
|
73 |
+
recent_prev_file_name = name_split[0]
|
74 |
+
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
|
75 |
+
return os.path.join(head, new_file_name)
|
modules/visual_question_answering.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.utils import *
|
2 |
+
|
3 |
+
class VisualQuestionAnswering:
|
4 |
+
def __init__(self, device, pretrained_model_dir):
|
5 |
+
print("Initializing VisualQuestionAnswering to %s" % device)
|
6 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
7 |
+
self.device = device
|
8 |
+
self.processor = BlipProcessor.from_pretrained(f"{pretrained_model_dir}/blip-vqa-base")
|
9 |
+
self.model = BlipForQuestionAnswering.from_pretrained(
|
10 |
+
f"{pretrained_model_dir}/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
|
11 |
+
|
12 |
+
@prompts(name="Answer Question About The Image",
|
13 |
+
description="useful when you need an answer for a question based on an image. "
|
14 |
+
"like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
|
15 |
+
"The input to this tool should be a comma seperated string of two, representing the image_path and the question")
|
16 |
+
def inference(self, inputs):
|
17 |
+
image_path, question = inputs.split(",")
|
18 |
+
raw_image = Image.open(image_path).convert('RGB')
|
19 |
+
inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
|
20 |
+
out = self.model.generate(**inputs)
|
21 |
+
answer = self.processor.decode(out[0], skip_special_tokens=True)
|
22 |
+
print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
|
23 |
+
f"Output Answer: {answer}")
|
24 |
+
return answer
|
requirement.txt
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain==0.0.101
|
2 |
+
torch==1.12.1
|
3 |
+
torchvision==0.13.1
|
4 |
+
gradio==3.20.1
|
5 |
+
accelerate
|
6 |
+
addict
|
7 |
+
albumentations
|
8 |
+
basicsr
|
9 |
+
controlnet-aux
|
10 |
+
diffusers
|
11 |
+
einops
|
12 |
+
imageio
|
13 |
+
imageio-ffmpeg
|
14 |
+
invisible-watermark
|
15 |
+
kornia
|
16 |
+
numpy
|
17 |
+
omegaconf
|
18 |
+
open_clip_torch
|
19 |
+
openai
|
20 |
+
opencv-python
|
21 |
+
prettytable
|
22 |
+
safetensors
|
23 |
+
streamlit
|
24 |
+
test-tube
|
25 |
+
timm
|
26 |
+
torchmetrics
|
27 |
+
transformers
|
28 |
+
webdataset
|
29 |
+
yapf
|
30 |
+
numba
|
31 |
+
librosa
|
32 |
+
scipy
|
33 |
+
unidecode
|
34 |
+
openjtalk>=0.3.0.dev2
|
35 |
+
jamo
|
36 |
+
pypinyin
|
37 |
+
jieba
|
38 |
+
protobuf
|
39 |
+
pygtrans
|
40 |
+
cn2an
|
41 |
+
inflect
|
42 |
+
eng_to_ipa
|
43 |
+
ko_pron
|
44 |
+
indic_transliteration
|
45 |
+
num_thai
|
46 |
+
opencc
|
47 |
+
vosk
|
48 |
+
sounddevice
|
text/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
from text import cleaners
|
3 |
+
|
4 |
+
|
5 |
+
def text_to_sequence(text, symbols, cleaner_names):
|
6 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
7 |
+
Args:
|
8 |
+
text: string to convert to a sequence
|
9 |
+
cleaner_names: names of the cleaner functions to run the text through
|
10 |
+
Returns:
|
11 |
+
List of integers corresponding to the symbols in the text
|
12 |
+
'''
|
13 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
14 |
+
|
15 |
+
sequence = []
|
16 |
+
|
17 |
+
clean_text = _clean_text(text, cleaner_names)
|
18 |
+
for symbol in clean_text:
|
19 |
+
if symbol not in _symbol_to_id.keys():
|
20 |
+
continue
|
21 |
+
symbol_id = _symbol_to_id[symbol]
|
22 |
+
sequence += [symbol_id]
|
23 |
+
return sequence
|
24 |
+
|
25 |
+
|
26 |
+
def _clean_text(text, cleaner_names):
|
27 |
+
for name in cleaner_names:
|
28 |
+
cleaner = getattr(cleaners, name)
|
29 |
+
if not cleaner:
|
30 |
+
raise Exception('Unknown cleaner: %s' % name)
|
31 |
+
text = cleaner(text)
|
32 |
+
return text
|
text/cantonese.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import cn2an
|
3 |
+
import opencc
|
4 |
+
|
5 |
+
|
6 |
+
converter = opencc.OpenCC('jyutjyu')
|
7 |
+
|
8 |
+
# List of (Latin alphabet, ipa) pairs:
|
9 |
+
_latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
10 |
+
('A', 'ei˥'),
|
11 |
+
('B', 'biː˥'),
|
12 |
+
('C', 'siː˥'),
|
13 |
+
('D', 'tiː˥'),
|
14 |
+
('E', 'iː˥'),
|
15 |
+
('F', 'e˥fuː˨˩'),
|
16 |
+
('G', 'tsiː˥'),
|
17 |
+
('H', 'ɪk̚˥tsʰyː˨˩'),
|
18 |
+
('I', 'ɐi˥'),
|
19 |
+
('J', 'tsei˥'),
|
20 |
+
('K', 'kʰei˥'),
|
21 |
+
('L', 'e˥llou˨˩'),
|
22 |
+
('M', 'ɛːm˥'),
|
23 |
+
('N', 'ɛːn˥'),
|
24 |
+
('O', 'ou˥'),
|
25 |
+
('P', 'pʰiː˥'),
|
26 |
+
('Q', 'kʰiːu˥'),
|
27 |
+
('R', 'aː˥lou˨˩'),
|
28 |
+
('S', 'ɛː˥siː˨˩'),
|
29 |
+
('T', 'tʰiː˥'),
|
30 |
+
('U', 'juː˥'),
|
31 |
+
('V', 'wiː˥'),
|
32 |
+
('W', 'tʊk̚˥piː˥juː˥'),
|
33 |
+
('X', 'ɪk̚˥siː˨˩'),
|
34 |
+
('Y', 'waːi˥'),
|
35 |
+
('Z', 'iː˨sɛːt̚˥')
|
36 |
+
]]
|
37 |
+
|
38 |
+
|
39 |
+
def number_to_cantonese(text):
|
40 |
+
return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text)
|
41 |
+
|
42 |
+
|
43 |
+
def latin_to_ipa(text):
|
44 |
+
for regex, replacement in _latin_to_ipa:
|
45 |
+
text = re.sub(regex, replacement, text)
|
46 |
+
return text
|
47 |
+
|
48 |
+
|
49 |
+
def cantonese_to_ipa(text):
|
50 |
+
text = number_to_cantonese(text.upper())
|
51 |
+
text = converter.convert(text).replace('-','').replace('$',' ')
|
52 |
+
text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
|
53 |
+
text = re.sub(r'[、;:]', ',', text)
|
54 |
+
text = re.sub(r'\s*,\s*', ', ', text)
|
55 |
+
text = re.sub(r'\s*。\s*', '. ', text)
|
56 |
+
text = re.sub(r'\s*?\s*', '? ', text)
|
57 |
+
text = re.sub(r'\s*!\s*', '! ', text)
|
58 |
+
text = re.sub(r'\s*$', '', text)
|
59 |
+
return text
|
text/cleaners.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
|
4 |
+
def japanese_cleaners(text):
|
5 |
+
from text.japanese import japanese_to_romaji_with_accent
|
6 |
+
text = japanese_to_romaji_with_accent(text)
|
7 |
+
text = re.sub(r'([A-Za-z])$', r'\1.', text)
|
8 |
+
return text
|
9 |
+
|
10 |
+
|
11 |
+
def japanese_cleaners2(text):
|
12 |
+
return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
|
13 |
+
|
14 |
+
|
15 |
+
def korean_cleaners(text):
|
16 |
+
'''Pipeline for Korean text'''
|
17 |
+
from text.korean import latin_to_hangul, number_to_hangul, divide_hangul
|
18 |
+
text = latin_to_hangul(text)
|
19 |
+
text = number_to_hangul(text)
|
20 |
+
text = divide_hangul(text)
|
21 |
+
text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
|
22 |
+
return text
|
23 |
+
|
24 |
+
|
25 |
+
def chinese_cleaners(text):
|
26 |
+
'''Pipeline for Chinese text'''
|
27 |
+
from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo
|
28 |
+
text = number_to_chinese(text)
|
29 |
+
text = chinese_to_bopomofo(text)
|
30 |
+
text = latin_to_bopomofo(text)
|
31 |
+
text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
|
32 |
+
return text
|
33 |
+
|
34 |
+
|
35 |
+
def zh_ja_mixture_cleaners(text):
|
36 |
+
from text.mandarin import chinese_to_romaji
|
37 |
+
from text.japanese import japanese_to_romaji_with_accent
|
38 |
+
text = re.sub(r'\[ZH\](.*?)\[ZH\]',
|
39 |
+
lambda x: chinese_to_romaji(x.group(1))+' ', text)
|
40 |
+
text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent(
|
41 |
+
x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…')+' ', text)
|
42 |
+
text = re.sub(r'\s+$', '', text)
|
43 |
+
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
|
44 |
+
return text
|
45 |
+
|
46 |
+
|
47 |
+
def sanskrit_cleaners(text):
|
48 |
+
text = text.replace('॥', '।').replace('ॐ', 'ओम्')
|
49 |
+
text = re.sub(r'([^।])$', r'\1।', text)
|
50 |
+
return text
|
51 |
+
|
52 |
+
|
53 |
+
def cjks_cleaners(text):
|
54 |
+
from text.mandarin import chinese_to_lazy_ipa
|
55 |
+
from text.japanese import japanese_to_ipa
|
56 |
+
from text.korean import korean_to_lazy_ipa
|
57 |
+
from text.sanskrit import devanagari_to_ipa
|
58 |
+
from text.english import english_to_lazy_ipa
|
59 |
+
text = re.sub(r'\[ZH\](.*?)\[ZH\]',
|
60 |
+
lambda x: chinese_to_lazy_ipa(x.group(1))+' ', text)
|
61 |
+
text = re.sub(r'\[JA\](.*?)\[JA\]',
|
62 |
+
lambda x: japanese_to_ipa(x.group(1))+' ', text)
|
63 |
+
text = re.sub(r'\[KO\](.*?)\[KO\]',
|
64 |
+
lambda x: korean_to_lazy_ipa(x.group(1))+' ', text)
|
65 |
+
text = re.sub(r'\[SA\](.*?)\[SA\]',
|
66 |
+
lambda x: devanagari_to_ipa(x.group(1))+' ', text)
|
67 |
+
text = re.sub(r'\[EN\](.*?)\[EN\]',
|
68 |
+
lambda x: english_to_lazy_ipa(x.group(1))+' ', text)
|
69 |
+
text = re.sub(r'\s+$', '', text)
|
70 |
+
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
|
71 |
+
return text
|
72 |
+
|
73 |
+
|
74 |
+
def cjke_cleaners(text):
|
75 |
+
from text.mandarin import chinese_to_lazy_ipa
|
76 |
+
from text.japanese import japanese_to_ipa
|
77 |
+
from text.korean import korean_to_ipa
|
78 |
+
from text.english import english_to_ipa2
|
79 |
+
text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace(
|
80 |
+
'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn')+' ', text)
|
81 |
+
text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace(
|
82 |
+
'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz')+' ', text)
|
83 |
+
text = re.sub(r'\[KO\](.*?)\[KO\]',
|
84 |
+
lambda x: korean_to_ipa(x.group(1))+' ', text)
|
85 |
+
text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace(
|
86 |
+
'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')+' ', text)
|
87 |
+
text = re.sub(r'\s+$', '', text)
|
88 |
+
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
|
89 |
+
return text
|
90 |
+
|
91 |
+
|
92 |
+
def cjke_cleaners2(text):
|
93 |
+
from text.mandarin import chinese_to_ipa
|
94 |
+
from text.japanese import japanese_to_ipa2
|
95 |
+
from text.korean import korean_to_ipa
|
96 |
+
from text.english import english_to_ipa2
|
97 |
+
text = re.sub(r'\[ZH\](.*?)\[ZH\]',
|
98 |
+
lambda x: chinese_to_ipa(x.group(1))+' ', text)
|
99 |
+
text = re.sub(r'\[JA\](.*?)\[JA\]',
|
100 |
+
lambda x: japanese_to_ipa2(x.group(1))+' ', text)
|
101 |
+
text = re.sub(r'\[KO\](.*?)\[KO\]',
|
102 |
+
lambda x: korean_to_ipa(x.group(1))+' ', text)
|
103 |
+
text = re.sub(r'\[EN\](.*?)\[EN\]',
|
104 |
+
lambda x: english_to_ipa2(x.group(1))+' ', text)
|
105 |
+
text = re.sub(r'\s+$', '', text)
|
106 |
+
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
|
107 |
+
return text
|
108 |
+
|
109 |
+
|
110 |
+
def thai_cleaners(text):
|
111 |
+
from text.thai import num_to_thai, latin_to_thai
|
112 |
+
text = num_to_thai(text)
|
113 |
+
text = latin_to_thai(text)
|
114 |
+
return text
|
115 |
+
|
116 |
+
|
117 |
+
def shanghainese_cleaners(text):
|
118 |
+
from text.shanghainese import shanghainese_to_ipa
|
119 |
+
text = shanghainese_to_ipa(text)
|
120 |
+
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
|
121 |
+
return text
|
122 |
+
|
123 |
+
|
124 |
+
def chinese_dialect_cleaners(text):
|
125 |
+
from text.mandarin import chinese_to_ipa2
|
126 |
+
from text.japanese import japanese_to_ipa3
|
127 |
+
from text.shanghainese import shanghainese_to_ipa
|
128 |
+
from text.cantonese import cantonese_to_ipa
|
129 |
+
from text.english import english_to_lazy_ipa2
|
130 |
+
from text.ngu_dialect import ngu_dialect_to_ipa
|
131 |
+
text = re.sub(r'\[ZH\](.*?)\[ZH\]',
|
132 |
+
lambda x: chinese_to_ipa2(x.group(1))+' ', text)
|
133 |
+
text = re.sub(r'\[JA\](.*?)\[JA\]',
|
134 |
+
lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text)
|
135 |
+
text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5',
|
136 |
+
'˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text)
|
137 |
+
text = re.sub(r'\[GD\](.*?)\[GD\]',
|
138 |
+
lambda x: cantonese_to_ipa(x.group(1))+' ', text)
|
139 |
+
text = re.sub(r'\[EN\](.*?)\[EN\]',
|
140 |
+
lambda x: english_to_lazy_ipa2(x.group(1))+' ', text)
|
141 |
+
text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group(
|
142 |
+
1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text)
|
143 |
+
text = re.sub(r'\s+$', '', text)
|
144 |
+
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
|
145 |
+
return text
|
text/english.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
|
3 |
+
'''
|
4 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
5 |
+
|
6 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
7 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
8 |
+
1. "english_cleaners" for English text
|
9 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
10 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
11 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
12 |
+
the symbols in symbols.py to match your data).
|
13 |
+
'''
|
14 |
+
|
15 |
+
|
16 |
+
# Regular expression matching whitespace:
|
17 |
+
|
18 |
+
|
19 |
+
import re
|
20 |
+
import inflect
|
21 |
+
from unidecode import unidecode
|
22 |
+
import eng_to_ipa as ipa
|
23 |
+
_inflect = inflect.engine()
|
24 |
+
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
25 |
+
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
26 |
+
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
27 |
+
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
28 |
+
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
29 |
+
_number_re = re.compile(r'[0-9]+')
|
30 |
+
|
31 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
32 |
+
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
33 |
+
('mrs', 'misess'),
|
34 |
+
('mr', 'mister'),
|
35 |
+
('dr', 'doctor'),
|
36 |
+
('st', 'saint'),
|
37 |
+
('co', 'company'),
|
38 |
+
('jr', 'junior'),
|
39 |
+
('maj', 'major'),
|
40 |
+
('gen', 'general'),
|
41 |
+
('drs', 'doctors'),
|
42 |
+
('rev', 'reverend'),
|
43 |
+
('lt', 'lieutenant'),
|
44 |
+
('hon', 'honorable'),
|
45 |
+
('sgt', 'sergeant'),
|
46 |
+
('capt', 'captain'),
|
47 |
+
('esq', 'esquire'),
|
48 |
+
('ltd', 'limited'),
|
49 |
+
('col', 'colonel'),
|
50 |
+
('ft', 'fort'),
|
51 |
+
]]
|
52 |
+
|
53 |
+
|
54 |
+
# List of (ipa, lazy ipa) pairs:
|
55 |
+
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
56 |
+
('r', 'ɹ'),
|
57 |
+
('æ', 'e'),
|
58 |
+
('ɑ', 'a'),
|
59 |
+
('ɔ', 'o'),
|
60 |
+
('ð', 'z'),
|
61 |
+
('θ', 's'),
|
62 |
+
('ɛ', 'e'),
|
63 |
+
('ɪ', 'i'),
|
64 |
+
('ʊ', 'u'),
|
65 |
+
('ʒ', 'ʥ'),
|
66 |
+
('ʤ', 'ʥ'),
|
67 |
+
('ˈ', '↓'),
|
68 |
+
]]
|
69 |
+
|
70 |
+
# List of (ipa, lazy ipa2) pairs:
|
71 |
+
_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
72 |
+
('r', 'ɹ'),
|
73 |
+
('ð', 'z'),
|
74 |
+
('θ', 's'),
|
75 |
+
('ʒ', 'ʑ'),
|
76 |
+
('ʤ', 'dʑ'),
|
77 |
+
('ˈ', '↓'),
|
78 |
+
]]
|
79 |
+
|
80 |
+
# List of (ipa, ipa2) pairs
|
81 |
+
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
82 |
+
('r', 'ɹ'),
|
83 |
+
('ʤ', 'dʒ'),
|
84 |
+
('ʧ', 'tʃ')
|
85 |
+
]]
|
86 |
+
|
87 |
+
|
88 |
+
def expand_abbreviations(text):
|
89 |
+
for regex, replacement in _abbreviations:
|
90 |
+
text = re.sub(regex, replacement, text)
|
91 |
+
return text
|
92 |
+
|
93 |
+
|
94 |
+
def collapse_whitespace(text):
|
95 |
+
return re.sub(r'\s+', ' ', text)
|
96 |
+
|
97 |
+
|
98 |
+
def _remove_commas(m):
|
99 |
+
return m.group(1).replace(',', '')
|
100 |
+
|
101 |
+
|
102 |
+
def _expand_decimal_point(m):
|
103 |
+
return m.group(1).replace('.', ' point ')
|
104 |
+
|
105 |
+
|
106 |
+
def _expand_dollars(m):
|
107 |
+
match = m.group(1)
|
108 |
+
parts = match.split('.')
|
109 |
+
if len(parts) > 2:
|
110 |
+
return match + ' dollars' # Unexpected format
|
111 |
+
dollars = int(parts[0]) if parts[0] else 0
|
112 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
113 |
+
if dollars and cents:
|
114 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
115 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
116 |
+
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
117 |
+
elif dollars:
|
118 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
119 |
+
return '%s %s' % (dollars, dollar_unit)
|
120 |
+
elif cents:
|
121 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
122 |
+
return '%s %s' % (cents, cent_unit)
|
123 |
+
else:
|
124 |
+
return 'zero dollars'
|
125 |
+
|
126 |
+
|
127 |
+
def _expand_ordinal(m):
|
128 |
+
return _inflect.number_to_words(m.group(0))
|
129 |
+
|
130 |
+
|
131 |
+
def _expand_number(m):
|
132 |
+
num = int(m.group(0))
|
133 |
+
if num > 1000 and num < 3000:
|
134 |
+
if num == 2000:
|
135 |
+
return 'two thousand'
|
136 |
+
elif num > 2000 and num < 2010:
|
137 |
+
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
138 |
+
elif num % 100 == 0:
|
139 |
+
return _inflect.number_to_words(num // 100) + ' hundred'
|
140 |
+
else:
|
141 |
+
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
142 |
+
else:
|
143 |
+
return _inflect.number_to_words(num, andword='')
|
144 |
+
|
145 |
+
|
146 |
+
def normalize_numbers(text):
|
147 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
148 |
+
text = re.sub(_pounds_re, r'\1 pounds', text)
|
149 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
150 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
151 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
152 |
+
text = re.sub(_number_re, _expand_number, text)
|
153 |
+
return text
|
154 |
+
|
155 |
+
|
156 |
+
def mark_dark_l(text):
|
157 |
+
return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
|
158 |
+
|
159 |
+
|
160 |
+
def english_to_ipa(text):
|
161 |
+
text = unidecode(text).lower()
|
162 |
+
text = expand_abbreviations(text)
|
163 |
+
text = normalize_numbers(text)
|
164 |
+
phonemes = ipa.convert(text)
|
165 |
+
phonemes = collapse_whitespace(phonemes)
|
166 |
+
return phonemes
|
167 |
+
|
168 |
+
|
169 |
+
def english_to_lazy_ipa(text):
|
170 |
+
text = english_to_ipa(text)
|
171 |
+
for regex, replacement in _lazy_ipa:
|
172 |
+
text = re.sub(regex, replacement, text)
|
173 |
+
return text
|
174 |
+
|
175 |
+
|
176 |
+
def english_to_ipa2(text):
|
177 |
+
text = english_to_ipa(text)
|
178 |
+
text = mark_dark_l(text)
|
179 |
+
for regex, replacement in _ipa_to_ipa2:
|
180 |
+
text = re.sub(regex, replacement, text)
|
181 |
+
return text.replace('...', '…')
|
182 |
+
|
183 |
+
|
184 |
+
def english_to_lazy_ipa2(text):
|
185 |
+
text = english_to_ipa(text)
|
186 |
+
for regex, replacement in _lazy_ipa2:
|
187 |
+
text = re.sub(regex, replacement, text)
|
188 |
+
return text
|
text/japanese.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from unidecode import unidecode
|
3 |
+
import text.pyopenjtalk as pyopenjtalk
|
4 |
+
|
5 |
+
|
6 |
+
# Regular expression matching Japanese without punctuation marks:
|
7 |
+
_japanese_characters = re.compile(
|
8 |
+
r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
|
9 |
+
|
10 |
+
# Regular expression matching non-Japanese characters or punctuation marks:
|
11 |
+
_japanese_marks = re.compile(
|
12 |
+
r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
|
13 |
+
|
14 |
+
# List of (symbol, Japanese) pairs for marks:
|
15 |
+
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
|
16 |
+
('%', 'パーセント')
|
17 |
+
]]
|
18 |
+
|
19 |
+
# List of (romaji, ipa) pairs for marks:
|
20 |
+
_romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
21 |
+
('ts', 'ʦ'),
|
22 |
+
('u', 'ɯ'),
|
23 |
+
('j', 'ʥ'),
|
24 |
+
('y', 'j'),
|
25 |
+
('ni', 'n^i'),
|
26 |
+
('nj', 'n^'),
|
27 |
+
('hi', 'çi'),
|
28 |
+
('hj', 'ç'),
|
29 |
+
('f', 'ɸ'),
|
30 |
+
('I', 'i*'),
|
31 |
+
('U', 'ɯ*'),
|
32 |
+
('r', 'ɾ')
|
33 |
+
]]
|
34 |
+
|
35 |
+
# List of (romaji, ipa2) pairs for marks:
|
36 |
+
_romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
37 |
+
('u', 'ɯ'),
|
38 |
+
('ʧ', 'tʃ'),
|
39 |
+
('j', 'dʑ'),
|
40 |
+
('y', 'j'),
|
41 |
+
('ni', 'n^i'),
|
42 |
+
('nj', 'n^'),
|
43 |
+
('hi', 'çi'),
|
44 |
+
('hj', 'ç'),
|
45 |
+
('f', 'ɸ'),
|
46 |
+
('I', 'i*'),
|
47 |
+
('U', 'ɯ*'),
|
48 |
+
('r', 'ɾ')
|
49 |
+
]]
|
50 |
+
|
51 |
+
# List of (consonant, sokuon) pairs:
|
52 |
+
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
|
53 |
+
(r'Q([↑↓]*[kg])', r'k#\1'),
|
54 |
+
(r'Q([↑↓]*[tdjʧ])', r't#\1'),
|
55 |
+
(r'Q([↑↓]*[sʃ])', r's\1'),
|
56 |
+
(r'Q([↑↓]*[pb])', r'p#\1')
|
57 |
+
]]
|
58 |
+
|
59 |
+
# List of (consonant, hatsuon) pairs:
|
60 |
+
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
|
61 |
+
(r'N([↑↓]*[pbm])', r'm\1'),
|
62 |
+
(r'N([↑↓]*[ʧʥj])', r'n^\1'),
|
63 |
+
(r'N([↑↓]*[tdn])', r'n\1'),
|
64 |
+
(r'N([↑↓]*[kg])', r'ŋ\1')
|
65 |
+
]]
|
66 |
+
|
67 |
+
|
68 |
+
def symbols_to_japanese(text):
|
69 |
+
for regex, replacement in _symbols_to_japanese:
|
70 |
+
text = re.sub(regex, replacement, text)
|
71 |
+
return text
|
72 |
+
|
73 |
+
|
74 |
+
def japanese_to_romaji_with_accent(text):
|
75 |
+
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
|
76 |
+
text = symbols_to_japanese(text)
|
77 |
+
sentences = re.split(_japanese_marks, text)
|
78 |
+
marks = re.findall(_japanese_marks, text)
|
79 |
+
text = ''
|
80 |
+
for i, sentence in enumerate(sentences):
|
81 |
+
if re.match(_japanese_characters, sentence):
|
82 |
+
if text != '':
|
83 |
+
text += ' '
|
84 |
+
labels = pyopenjtalk.extract_fullcontext(sentence)
|
85 |
+
for n, label in enumerate(labels):
|
86 |
+
phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
|
87 |
+
if phoneme not in ['sil', 'pau']:
|
88 |
+
text += phoneme.replace('ch', 'ʧ').replace('sh',
|
89 |
+
'ʃ').replace('cl', 'Q')
|
90 |
+
else:
|
91 |
+
continue
|
92 |
+
# n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
|
93 |
+
a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
|
94 |
+
a2 = int(re.search(r"\+(\d+)\+", label).group(1))
|
95 |
+
a3 = int(re.search(r"\+(\d+)/", label).group(1))
|
96 |
+
if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
|
97 |
+
a2_next = -1
|
98 |
+
else:
|
99 |
+
a2_next = int(
|
100 |
+
re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
|
101 |
+
# Accent phrase boundary
|
102 |
+
if a3 == 1 and a2_next == 1:
|
103 |
+
text += ' '
|
104 |
+
# Falling
|
105 |
+
elif a1 == 0 and a2_next == a2 + 1:
|
106 |
+
text += '↓'
|
107 |
+
# Rising
|
108 |
+
elif a2 == 1 and a2_next == 2:
|
109 |
+
text += '↑'
|
110 |
+
if i < len(marks):
|
111 |
+
text += unidecode(marks[i]).replace(' ', '')
|
112 |
+
return text
|
113 |
+
|
114 |
+
|
115 |
+
def get_real_sokuon(text):
|
116 |
+
for regex, replacement in _real_sokuon:
|
117 |
+
text = re.sub(regex, replacement, text)
|
118 |
+
return text
|
119 |
+
|
120 |
+
|
121 |
+
def get_real_hatsuon(text):
|
122 |
+
for regex, replacement in _real_hatsuon:
|
123 |
+
text = re.sub(regex, replacement, text)
|
124 |
+
return text
|
125 |
+
|
126 |
+
|
127 |
+
def japanese_to_ipa(text):
|
128 |
+
text = japanese_to_romaji_with_accent(text).replace('...', '…')
|
129 |
+
text = re.sub(
|
130 |
+
r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
|
131 |
+
text = get_real_sokuon(text)
|
132 |
+
text = get_real_hatsuon(text)
|
133 |
+
for regex, replacement in _romaji_to_ipa:
|
134 |
+
text = re.sub(regex, replacement, text)
|
135 |
+
return text
|
136 |
+
|
137 |
+
|
138 |
+
def japanese_to_ipa2(text):
|
139 |
+
text = japanese_to_romaji_with_accent(text).replace('...', '…')
|
140 |
+
text = get_real_sokuon(text)
|
141 |
+
text = get_real_hatsuon(text)
|
142 |
+
for regex, replacement in _romaji_to_ipa2:
|
143 |
+
text = re.sub(regex, replacement, text)
|
144 |
+
return text
|
145 |
+
|
146 |
+
|
147 |
+
def japanese_to_ipa3(text):
|
148 |
+
text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
|
149 |
+
'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
|
150 |
+
text = re.sub(
|
151 |
+
r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
|
152 |
+
text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
|
153 |
+
return text
|
text/korean.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from jamo import h2j, j2hcj
|
3 |
+
import ko_pron
|
4 |
+
|
5 |
+
|
6 |
+
# This is a list of Korean classifiers preceded by pure Korean numerals.
|
7 |
+
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
|
8 |
+
|
9 |
+
# List of (hangul, hangul divided) pairs:
|
10 |
+
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
|
11 |
+
('ㄳ', 'ㄱㅅ'),
|
12 |
+
('ㄵ', 'ㄴㅈ'),
|
13 |
+
('ㄶ', 'ㄴㅎ'),
|
14 |
+
('ㄺ', 'ㄹㄱ'),
|
15 |
+
('ㄻ', 'ㄹㅁ'),
|
16 |
+
('ㄼ', 'ㄹㅂ'),
|
17 |
+
('ㄽ', 'ㄹㅅ'),
|
18 |
+
('ㄾ', 'ㄹㅌ'),
|
19 |
+
('ㄿ', 'ㄹㅍ'),
|
20 |
+
('ㅀ', 'ㄹㅎ'),
|
21 |
+
('ㅄ', 'ㅂㅅ'),
|
22 |
+
('ㅘ', 'ㅗㅏ'),
|
23 |
+
('ㅙ', 'ㅗㅐ'),
|
24 |
+
('ㅚ', 'ㅗㅣ'),
|
25 |
+
('ㅝ', 'ㅜㅓ'),
|
26 |
+
('ㅞ', 'ㅜㅔ'),
|
27 |
+
('ㅟ', 'ㅜㅣ'),
|
28 |
+
('ㅢ', 'ㅡㅣ'),
|
29 |
+
('ㅑ', 'ㅣㅏ'),
|
30 |
+
('ㅒ', 'ㅣㅐ'),
|
31 |
+
('ㅕ', 'ㅣㅓ'),
|
32 |
+
('ㅖ', 'ㅣㅔ'),
|
33 |
+
('ㅛ', 'ㅣㅗ'),
|
34 |
+
('ㅠ', 'ㅣㅜ')
|
35 |
+
]]
|
36 |
+
|
37 |
+
# List of (Latin alphabet, hangul) pairs:
|
38 |
+
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
39 |
+
('a', '에이'),
|
40 |
+
('b', '비'),
|
41 |
+
('c', '시'),
|
42 |
+
('d', '디'),
|
43 |
+
('e', '이'),
|
44 |
+
('f', '에프'),
|
45 |
+
('g', '지'),
|
46 |
+
('h', '에이치'),
|
47 |
+
('i', '아이'),
|
48 |
+
('j', '제이'),
|
49 |
+
('k', '케이'),
|
50 |
+
('l', '엘'),
|
51 |
+
('m', '엠'),
|
52 |
+
('n', '엔'),
|
53 |
+
('o', '오'),
|
54 |
+
('p', '피'),
|
55 |
+
('q', '큐'),
|
56 |
+
('r', '아르'),
|
57 |
+
('s', '에스'),
|
58 |
+
('t', '티'),
|
59 |
+
('u', '유'),
|
60 |
+
('v', '브이'),
|
61 |
+
('w', '더블유'),
|
62 |
+
('x', '엑스'),
|
63 |
+
('y', '와이'),
|
64 |
+
('z', '제트')
|
65 |
+
]]
|
66 |
+
|
67 |
+
# List of (ipa, lazy ipa) pairs:
|
68 |
+
_ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
69 |
+
('t͡ɕ','ʧ'),
|
70 |
+
('d͡ʑ','ʥ'),
|
71 |
+
('ɲ','n^'),
|
72 |
+
('ɕ','ʃ'),
|
73 |
+
('ʷ','w'),
|
74 |
+
('ɭ','l`'),
|
75 |
+
('ʎ','ɾ'),
|
76 |
+
('ɣ','ŋ'),
|
77 |
+
('ɰ','ɯ'),
|
78 |
+
('ʝ','j'),
|
79 |
+
('ʌ','ə'),
|
80 |
+
('ɡ','g'),
|
81 |
+
('\u031a','#'),
|
82 |
+
('\u0348','='),
|
83 |
+
('\u031e',''),
|
84 |
+
('\u0320',''),
|
85 |
+
('\u0339','')
|
86 |
+
]]
|
87 |
+
|
88 |
+
|
89 |
+
def latin_to_hangul(text):
|
90 |
+
for regex, replacement in _latin_to_hangul:
|
91 |
+
text = re.sub(regex, replacement, text)
|
92 |
+
return text
|
93 |
+
|
94 |
+
|
95 |
+
def divide_hangul(text):
|
96 |
+
text = j2hcj(h2j(text))
|
97 |
+
for regex, replacement in _hangul_divided:
|
98 |
+
text = re.sub(regex, replacement, text)
|
99 |
+
return text
|
100 |
+
|
101 |
+
|
102 |
+
def hangul_number(num, sino=True):
|
103 |
+
'''Reference https://github.com/Kyubyong/g2pK'''
|
104 |
+
num = re.sub(',', '', num)
|
105 |
+
|
106 |
+
if num == '0':
|
107 |
+
return '영'
|
108 |
+
if not sino and num == '20':
|
109 |
+
return '스무'
|
110 |
+
|
111 |
+
digits = '123456789'
|
112 |
+
names = '일이삼사오육칠팔구'
|
113 |
+
digit2name = {d: n for d, n in zip(digits, names)}
|
114 |
+
|
115 |
+
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
|
116 |
+
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
|
117 |
+
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
|
118 |
+
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
|
119 |
+
|
120 |
+
spelledout = []
|
121 |
+
for i, digit in enumerate(num):
|
122 |
+
i = len(num) - i - 1
|
123 |
+
if sino:
|
124 |
+
if i == 0:
|
125 |
+
name = digit2name.get(digit, '')
|
126 |
+
elif i == 1:
|
127 |
+
name = digit2name.get(digit, '') + '십'
|
128 |
+
name = name.replace('일십', '십')
|
129 |
+
else:
|
130 |
+
if i == 0:
|
131 |
+
name = digit2mod.get(digit, '')
|
132 |
+
elif i == 1:
|
133 |
+
name = digit2dec.get(digit, '')
|
134 |
+
if digit == '0':
|
135 |
+
if i % 4 == 0:
|
136 |
+
last_three = spelledout[-min(3, len(spelledout)):]
|
137 |
+
if ''.join(last_three) == '':
|
138 |
+
spelledout.append('')
|
139 |
+
continue
|
140 |
+
else:
|
141 |
+
spelledout.append('')
|
142 |
+
continue
|
143 |
+
if i == 2:
|
144 |
+
name = digit2name.get(digit, '') + '백'
|
145 |
+
name = name.replace('일백', '백')
|
146 |
+
elif i == 3:
|
147 |
+
name = digit2name.get(digit, '') + '천'
|
148 |
+
name = name.replace('일천', '천')
|
149 |
+
elif i == 4:
|
150 |
+
name = digit2name.get(digit, '') + '만'
|
151 |
+
name = name.replace('일만', '만')
|
152 |
+
elif i == 5:
|
153 |
+
name = digit2name.get(digit, '') + '십'
|
154 |
+
name = name.replace('일십', '십')
|
155 |
+
elif i == 6:
|
156 |
+
name = digit2name.get(digit, '') + '백'
|
157 |
+
name = name.replace('일백', '백')
|
158 |
+
elif i == 7:
|
159 |
+
name = digit2name.get(digit, '') + '천'
|
160 |
+
name = name.replace('일천', '천')
|
161 |
+
elif i == 8:
|
162 |
+
name = digit2name.get(digit, '') + '억'
|
163 |
+
elif i == 9:
|
164 |
+
name = digit2name.get(digit, '') + '십'
|
165 |
+
elif i == 10:
|
166 |
+
name = digit2name.get(digit, '') + '백'
|
167 |
+
elif i == 11:
|
168 |
+
name = digit2name.get(digit, '') + '천'
|
169 |
+
elif i == 12:
|
170 |
+
name = digit2name.get(digit, '') + '조'
|
171 |
+
elif i == 13:
|
172 |
+
name = digit2name.get(digit, '') + '십'
|
173 |
+
elif i == 14:
|
174 |
+
name = digit2name.get(digit, '') + '백'
|
175 |
+
elif i == 15:
|
176 |
+
name = digit2name.get(digit, '') + '천'
|
177 |
+
spelledout.append(name)
|
178 |
+
return ''.join(elem for elem in spelledout)
|
179 |
+
|
180 |
+
|
181 |
+
def number_to_hangul(text):
|
182 |
+
'''Reference https://github.com/Kyubyong/g2pK'''
|
183 |
+
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
|
184 |
+
for token in tokens:
|
185 |
+
num, classifier = token
|
186 |
+
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
|
187 |
+
spelledout = hangul_number(num, sino=False)
|
188 |
+
else:
|
189 |
+
spelledout = hangul_number(num, sino=True)
|
190 |
+
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
|
191 |
+
# digit by digit for remaining digits
|
192 |
+
digits = '0123456789'
|
193 |
+
names = '영일이삼사오육칠팔구'
|
194 |
+
for d, n in zip(digits, names):
|
195 |
+
text = text.replace(d, n)
|
196 |
+
return text
|
197 |
+
|
198 |
+
|
199 |
+
def korean_to_lazy_ipa(text):
|
200 |
+
text = latin_to_hangul(text)
|
201 |
+
text = number_to_hangul(text)
|
202 |
+
text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
|
203 |
+
for regex, replacement in _ipa_to_lazy_ipa:
|
204 |
+
text = re.sub(regex, replacement, text)
|
205 |
+
return text
|
206 |
+
|
207 |
+
|
208 |
+
def korean_to_ipa(text):
|
209 |
+
text = korean_to_lazy_ipa(text)
|
210 |
+
return text.replace('ʧ','tʃ').replace('ʥ','dʑ')
|
text/mandarin.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
from pypinyin import lazy_pinyin, BOPOMOFO
|
5 |
+
import jieba
|
6 |
+
import cn2an
|
7 |
+
import logging
|
8 |
+
|
9 |
+
logging.getLogger('jieba').setLevel(logging.WARNING)
|
10 |
+
jieba.set_dictionary(r'./jieba/dict.txt')
|
11 |
+
jieba.initialize()
|
12 |
+
|
13 |
+
|
14 |
+
# List of (Latin alphabet, bopomofo) pairs:
|
15 |
+
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
16 |
+
('a', 'ㄟˉ'),
|
17 |
+
('b', 'ㄅㄧˋ'),
|
18 |
+
('c', 'ㄙㄧˉ'),
|
19 |
+
('d', 'ㄉㄧˋ'),
|
20 |
+
('e', 'ㄧˋ'),
|
21 |
+
('f', 'ㄝˊㄈㄨˋ'),
|
22 |
+
('g', 'ㄐㄧˋ'),
|
23 |
+
('h', 'ㄝˇㄑㄩˋ'),
|
24 |
+
('i', 'ㄞˋ'),
|
25 |
+
('j', 'ㄐㄟˋ'),
|
26 |
+
('k', 'ㄎㄟˋ'),
|
27 |
+
('l', 'ㄝˊㄛˋ'),
|
28 |
+
('m', 'ㄝˊㄇㄨˋ'),
|
29 |
+
('n', 'ㄣˉ'),
|
30 |
+
('o', 'ㄡˉ'),
|
31 |
+
('p', 'ㄆㄧˉ'),
|
32 |
+
('q', 'ㄎㄧㄡˉ'),
|
33 |
+
('r', 'ㄚˋ'),
|
34 |
+
('s', 'ㄝˊㄙˋ'),
|
35 |
+
('t', 'ㄊㄧˋ'),
|
36 |
+
('u', 'ㄧㄡˉ'),
|
37 |
+
('v', 'ㄨㄧˉ'),
|
38 |
+
('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
|
39 |
+
('x', 'ㄝˉㄎㄨˋㄙˋ'),
|
40 |
+
('y', 'ㄨㄞˋ'),
|
41 |
+
('z', 'ㄗㄟˋ')
|
42 |
+
]]
|
43 |
+
|
44 |
+
# List of (bopomofo, romaji) pairs:
|
45 |
+
_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
|
46 |
+
('ㄅㄛ', 'p⁼wo'),
|
47 |
+
('ㄆㄛ', 'pʰwo'),
|
48 |
+
('ㄇㄛ', 'mwo'),
|
49 |
+
('ㄈㄛ', 'fwo'),
|
50 |
+
('ㄅ', 'p⁼'),
|
51 |
+
('ㄆ', 'pʰ'),
|
52 |
+
('ㄇ', 'm'),
|
53 |
+
('ㄈ', 'f'),
|
54 |
+
('ㄉ', 't⁼'),
|
55 |
+
('ㄊ', 'tʰ'),
|
56 |
+
('ㄋ', 'n'),
|
57 |
+
('ㄌ', 'l'),
|
58 |
+
('ㄍ', 'k⁼'),
|
59 |
+
('ㄎ', 'kʰ'),
|
60 |
+
('ㄏ', 'h'),
|
61 |
+
('ㄐ', 'ʧ⁼'),
|
62 |
+
('ㄑ', 'ʧʰ'),
|
63 |
+
('ㄒ', 'ʃ'),
|
64 |
+
('ㄓ', 'ʦ`⁼'),
|
65 |
+
('ㄔ', 'ʦ`ʰ'),
|
66 |
+
('ㄕ', 's`'),
|
67 |
+
('ㄖ', 'ɹ`'),
|
68 |
+
('ㄗ', 'ʦ⁼'),
|
69 |
+
('ㄘ', 'ʦʰ'),
|
70 |
+
('ㄙ', 's'),
|
71 |
+
('ㄚ', 'a'),
|
72 |
+
('ㄛ', 'o'),
|
73 |
+
('ㄜ', 'ə'),
|
74 |
+
('ㄝ', 'e'),
|
75 |
+
('ㄞ', 'ai'),
|
76 |
+
('ㄟ', 'ei'),
|
77 |
+
('ㄠ', 'au'),
|
78 |
+
('ㄡ', 'ou'),
|
79 |
+
('ㄧㄢ', 'yeNN'),
|
80 |
+
('ㄢ', 'aNN'),
|
81 |
+
('ㄧㄣ', 'iNN'),
|
82 |
+
('ㄣ', 'əNN'),
|
83 |
+
('ㄤ', 'aNg'),
|
84 |
+
('ㄧㄥ', 'iNg'),
|
85 |
+
('ㄨㄥ', 'uNg'),
|
86 |
+
('ㄩㄥ', 'yuNg'),
|
87 |
+
('ㄥ', 'əNg'),
|
88 |
+
('ㄦ', 'əɻ'),
|
89 |
+
('ㄧ', 'i'),
|
90 |
+
('ㄨ', 'u'),
|
91 |
+
('ㄩ', 'ɥ'),
|
92 |
+
('ˉ', '→'),
|
93 |
+
('ˊ', '↑'),
|
94 |
+
('ˇ', '↓↑'),
|
95 |
+
('ˋ', '↓'),
|
96 |
+
('˙', ''),
|
97 |
+
(',', ','),
|
98 |
+
('。', '.'),
|
99 |
+
('!', '!'),
|
100 |
+
('?', '?'),
|
101 |
+
('—', '-')
|
102 |
+
]]
|
103 |
+
|
104 |
+
# List of (romaji, ipa) pairs:
|
105 |
+
_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
106 |
+
('ʃy', 'ʃ'),
|
107 |
+
('ʧʰy', 'ʧʰ'),
|
108 |
+
('ʧ⁼y', 'ʧ⁼'),
|
109 |
+
('NN', 'n'),
|
110 |
+
('Ng', 'ŋ'),
|
111 |
+
('y', 'j'),
|
112 |
+
('h', 'x')
|
113 |
+
]]
|
114 |
+
|
115 |
+
# List of (bopomofo, ipa) pairs:
|
116 |
+
_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
117 |
+
('ㄅㄛ', 'p⁼wo'),
|
118 |
+
('ㄆㄛ', 'pʰwo'),
|
119 |
+
('ㄇㄛ', 'mwo'),
|
120 |
+
('ㄈㄛ', 'fwo'),
|
121 |
+
('ㄅ', 'p⁼'),
|
122 |
+
('ㄆ', 'pʰ'),
|
123 |
+
('ㄇ', 'm'),
|
124 |
+
('ㄈ', 'f'),
|
125 |
+
('ㄉ', 't⁼'),
|
126 |
+
('ㄊ', 'tʰ'),
|
127 |
+
('ㄋ', 'n'),
|
128 |
+
('ㄌ', 'l'),
|
129 |
+
('ㄍ', 'k⁼'),
|
130 |
+
('ㄎ', 'kʰ'),
|
131 |
+
('ㄏ', 'x'),
|
132 |
+
('ㄐ', 'tʃ⁼'),
|
133 |
+
('ㄑ', 'tʃʰ'),
|
134 |
+
('ㄒ', 'ʃ'),
|
135 |
+
('ㄓ', 'ts`⁼'),
|
136 |
+
('ㄔ', 'ts`ʰ'),
|
137 |
+
('ㄕ', 's`'),
|
138 |
+
('ㄖ', 'ɹ`'),
|
139 |
+
('ㄗ', 'ts⁼'),
|
140 |
+
('ㄘ', 'tsʰ'),
|
141 |
+
('ㄙ', 's'),
|
142 |
+
('ㄚ', 'a'),
|
143 |
+
('ㄛ', 'o'),
|
144 |
+
('ㄜ', 'ə'),
|
145 |
+
('ㄝ', 'ɛ'),
|
146 |
+
('ㄞ', 'aɪ'),
|
147 |
+
('ㄟ', 'eɪ'),
|
148 |
+
('ㄠ', 'ɑʊ'),
|
149 |
+
('ㄡ', 'oʊ'),
|
150 |
+
('ㄧㄢ', 'jɛn'),
|
151 |
+
('ㄩㄢ', 'ɥæn'),
|
152 |
+
('ㄢ', 'an'),
|
153 |
+
('ㄧㄣ', 'in'),
|
154 |
+
('ㄩㄣ', 'ɥn'),
|
155 |
+
('ㄣ', 'ən'),
|
156 |
+
('ㄤ', 'ɑŋ'),
|
157 |
+
('ㄧㄥ', 'iŋ'),
|
158 |
+
('ㄨㄥ', 'ʊŋ'),
|
159 |
+
('ㄩㄥ', 'jʊŋ'),
|
160 |
+
('ㄥ', 'əŋ'),
|
161 |
+
('ㄦ', 'əɻ'),
|
162 |
+
('ㄧ', 'i'),
|
163 |
+
('ㄨ', 'u'),
|
164 |
+
('ㄩ', 'ɥ'),
|
165 |
+
('ˉ', '→'),
|
166 |
+
('ˊ', '↑'),
|
167 |
+
('ˇ', '↓↑'),
|
168 |
+
('ˋ', '↓'),
|
169 |
+
('˙', ''),
|
170 |
+
(',', ','),
|
171 |
+
('。', '.'),
|
172 |
+
('!', '!'),
|
173 |
+
('?', '?'),
|
174 |
+
('—', '-')
|
175 |
+
]]
|
176 |
+
|
177 |
+
# List of (bopomofo, ipa2) pairs:
|
178 |
+
_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
179 |
+
('ㄅㄛ', 'pwo'),
|
180 |
+
('ㄆㄛ', 'pʰwo'),
|
181 |
+
('ㄇㄛ', 'mwo'),
|
182 |
+
('ㄈㄛ', 'fwo'),
|
183 |
+
('ㄅ', 'p'),
|
184 |
+
('ㄆ', 'pʰ'),
|
185 |
+
('ㄇ', 'm'),
|
186 |
+
('ㄈ', 'f'),
|
187 |
+
('ㄉ', 't'),
|
188 |
+
('ㄊ', 'tʰ'),
|
189 |
+
('ㄋ', 'n'),
|
190 |
+
('ㄌ', 'l'),
|
191 |
+
('ㄍ', 'k'),
|
192 |
+
('ㄎ', 'kʰ'),
|
193 |
+
('ㄏ', 'h'),
|
194 |
+
('ㄐ', 'tɕ'),
|
195 |
+
('ㄑ', 'tɕʰ'),
|
196 |
+
('ㄒ', 'ɕ'),
|
197 |
+
('ㄓ', 'tʂ'),
|
198 |
+
('ㄔ', 'tʂʰ'),
|
199 |
+
('ㄕ', 'ʂ'),
|
200 |
+
('ㄖ', 'ɻ'),
|
201 |
+
('ㄗ', 'ts'),
|
202 |
+
('ㄘ', 'tsʰ'),
|
203 |
+
('ㄙ', 's'),
|
204 |
+
('ㄚ', 'a'),
|
205 |
+
('ㄛ', 'o'),
|
206 |
+
('ㄜ', 'ɤ'),
|
207 |
+
('ㄝ', 'ɛ'),
|
208 |
+
('ㄞ', 'aɪ'),
|
209 |
+
('ㄟ', 'eɪ'),
|
210 |
+
('ㄠ', 'ɑʊ'),
|
211 |
+
('ㄡ', 'oʊ'),
|
212 |
+
('ㄧㄢ', 'jɛn'),
|
213 |
+
('ㄩㄢ', 'yæn'),
|
214 |
+
('ㄢ', 'an'),
|
215 |
+
('ㄧㄣ', 'in'),
|
216 |
+
('ㄩㄣ', 'yn'),
|
217 |
+
('ㄣ', 'ən'),
|
218 |
+
('ㄤ', 'ɑŋ'),
|
219 |
+
('ㄧㄥ', 'iŋ'),
|
220 |
+
('ㄨㄥ', 'ʊŋ'),
|
221 |
+
('ㄩㄥ', 'jʊŋ'),
|
222 |
+
('ㄥ', 'ɤŋ'),
|
223 |
+
('ㄦ', 'əɻ'),
|
224 |
+
('ㄧ', 'i'),
|
225 |
+
('ㄨ', 'u'),
|
226 |
+
('ㄩ', 'y'),
|
227 |
+
('ˉ', '˥'),
|
228 |
+
('ˊ', '˧˥'),
|
229 |
+
('ˇ', '˨˩˦'),
|
230 |
+
('ˋ', '˥˩'),
|
231 |
+
('˙', ''),
|
232 |
+
(',', ','),
|
233 |
+
('。', '.'),
|
234 |
+
('!', '!'),
|
235 |
+
('?', '?'),
|
236 |
+
('—', '-')
|
237 |
+
]]
|
238 |
+
|
239 |
+
|
240 |
+
def number_to_chinese(text):
|
241 |
+
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
|
242 |
+
for number in numbers:
|
243 |
+
text = text.replace(number, cn2an.an2cn(number), 1)
|
244 |
+
return text
|
245 |
+
|
246 |
+
|
247 |
+
def chinese_to_bopomofo(text):
|
248 |
+
text = text.replace('、', ',').replace(';', ',').replace(':', ',')
|
249 |
+
words = jieba.lcut(text, cut_all=False)
|
250 |
+
text = ''
|
251 |
+
for word in words:
|
252 |
+
bopomofos = lazy_pinyin(word, BOPOMOFO)
|
253 |
+
if not re.search('[\u4e00-\u9fff]', word):
|
254 |
+
text += word
|
255 |
+
continue
|
256 |
+
for i in range(len(bopomofos)):
|
257 |
+
bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
|
258 |
+
if text != '':
|
259 |
+
text += ' '
|
260 |
+
text += ''.join(bopomofos)
|
261 |
+
return text
|
262 |
+
|
263 |
+
|
264 |
+
def latin_to_bopomofo(text):
|
265 |
+
for regex, replacement in _latin_to_bopomofo:
|
266 |
+
text = re.sub(regex, replacement, text)
|
267 |
+
return text
|
268 |
+
|
269 |
+
|
270 |
+
def bopomofo_to_romaji(text):
|
271 |
+
for regex, replacement in _bopomofo_to_romaji:
|
272 |
+
text = re.sub(regex, replacement, text)
|
273 |
+
return text
|
274 |
+
|
275 |
+
|
276 |
+
def bopomofo_to_ipa(text):
|
277 |
+
for regex, replacement in _bopomofo_to_ipa:
|
278 |
+
text = re.sub(regex, replacement, text)
|
279 |
+
return text
|
280 |
+
|
281 |
+
|
282 |
+
def bopomofo_to_ipa2(text):
|
283 |
+
for regex, replacement in _bopomofo_to_ipa2:
|
284 |
+
text = re.sub(regex, replacement, text)
|
285 |
+
return text
|
286 |
+
|
287 |
+
|
288 |
+
def chinese_to_romaji(text):
|
289 |
+
text = number_to_chinese(text)
|
290 |
+
text = chinese_to_bopomofo(text)
|
291 |
+
text = latin_to_bopomofo(text)
|
292 |
+
text = bopomofo_to_romaji(text)
|
293 |
+
text = re.sub('i([aoe])', r'y\1', text)
|
294 |
+
text = re.sub('u([aoəe])', r'w\1', text)
|
295 |
+
text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
|
296 |
+
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
|
297 |
+
text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
|
298 |
+
return text
|
299 |
+
|
300 |
+
|
301 |
+
def chinese_to_lazy_ipa(text):
|
302 |
+
text = chinese_to_romaji(text)
|
303 |
+
for regex, replacement in _romaji_to_ipa:
|
304 |
+
text = re.sub(regex, replacement, text)
|
305 |
+
return text
|
306 |
+
|
307 |
+
|
308 |
+
def chinese_to_ipa(text):
|
309 |
+
text = number_to_chinese(text)
|
310 |
+
text = chinese_to_bopomofo(text)
|
311 |
+
text = latin_to_bopomofo(text)
|
312 |
+
text = bopomofo_to_ipa(text)
|
313 |
+
text = re.sub('i([aoe])', r'j\1', text)
|
314 |
+
text = re.sub('u([aoəe])', r'w\1', text)
|
315 |
+
text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
|
316 |
+
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
|
317 |
+
text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
|
318 |
+
return text
|
319 |
+
|
320 |
+
|
321 |
+
def chinese_to_ipa2(text):
|
322 |
+
text = number_to_chinese(text)
|
323 |
+
text = chinese_to_bopomofo(text)
|
324 |
+
text = latin_to_bopomofo(text)
|
325 |
+
text = bopomofo_to_ipa2(text)
|
326 |
+
text = re.sub(r'i([aoe])', r'j\1', text)
|
327 |
+
text = re.sub(r'u([aoəe])', r'w\1', text)
|
328 |
+
text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
|
329 |
+
text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
|
330 |
+
return text
|
text/ngu_dialect.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import opencc
|
3 |
+
|
4 |
+
|
5 |
+
dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou',
|
6 |
+
'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing',
|
7 |
+
'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang',
|
8 |
+
'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan',
|
9 |
+
'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen',
|
10 |
+
'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'}
|
11 |
+
|
12 |
+
converters = {}
|
13 |
+
|
14 |
+
for dialect in dialects.values():
|
15 |
+
try:
|
16 |
+
converters[dialect] = opencc.OpenCC(dialect)
|
17 |
+
except:
|
18 |
+
pass
|
19 |
+
|
20 |
+
|
21 |
+
def ngu_dialect_to_ipa(text, dialect):
|
22 |
+
dialect = dialects[dialect]
|
23 |
+
text = converters[dialect].convert(text).replace('-','').replace('$',' ')
|
24 |
+
text = re.sub(r'[、;:]', ',', text)
|
25 |
+
text = re.sub(r'\s*,\s*', ', ', text)
|
26 |
+
text = re.sub(r'\s*。\s*', '. ', text)
|
27 |
+
text = re.sub(r'\s*?\s*', '? ', text)
|
28 |
+
text = re.sub(r'\s*!\s*', '! ', text)
|
29 |
+
text = re.sub(r'\s*$', '', text)
|
30 |
+
return text
|
text/sanskrit.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from indic_transliteration import sanscript
|
3 |
+
|
4 |
+
|
5 |
+
# List of (iast, ipa) pairs:
|
6 |
+
_iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
7 |
+
('a', 'ə'),
|
8 |
+
('ā', 'aː'),
|
9 |
+
('ī', 'iː'),
|
10 |
+
('ū', 'uː'),
|
11 |
+
('ṛ', 'ɹ`'),
|
12 |
+
('ṝ', 'ɹ`ː'),
|
13 |
+
('ḷ', 'l`'),
|
14 |
+
('ḹ', 'l`ː'),
|
15 |
+
('e', 'eː'),
|
16 |
+
('o', 'oː'),
|
17 |
+
('k', 'k⁼'),
|
18 |
+
('k⁼h', 'kʰ'),
|
19 |
+
('g', 'g⁼'),
|
20 |
+
('g⁼h', 'gʰ'),
|
21 |
+
('ṅ', 'ŋ'),
|
22 |
+
('c', 'ʧ⁼'),
|
23 |
+
('ʧ⁼h', 'ʧʰ'),
|
24 |
+
('j', 'ʥ⁼'),
|
25 |
+
('ʥ⁼h', 'ʥʰ'),
|
26 |
+
('ñ', 'n^'),
|
27 |
+
('ṭ', 't`⁼'),
|
28 |
+
('t`⁼h', 't`ʰ'),
|
29 |
+
('ḍ', 'd`⁼'),
|
30 |
+
('d`⁼h', 'd`ʰ'),
|
31 |
+
('ṇ', 'n`'),
|
32 |
+
('t', 't⁼'),
|
33 |
+
('t⁼h', 'tʰ'),
|
34 |
+
('d', 'd⁼'),
|
35 |
+
('d⁼h', 'dʰ'),
|
36 |
+
('p', 'p⁼'),
|
37 |
+
('p⁼h', 'pʰ'),
|
38 |
+
('b', 'b⁼'),
|
39 |
+
('b⁼h', 'bʰ'),
|
40 |
+
('y', 'j'),
|
41 |
+
('ś', 'ʃ'),
|
42 |
+
('ṣ', 's`'),
|
43 |
+
('r', 'ɾ'),
|
44 |
+
('l̤', 'l`'),
|
45 |
+
('h', 'ɦ'),
|
46 |
+
("'", ''),
|
47 |
+
('~', '^'),
|
48 |
+
('ṃ', '^')
|
49 |
+
]]
|
50 |
+
|
51 |
+
|
52 |
+
def devanagari_to_ipa(text):
|
53 |
+
text = text.replace('ॐ', 'ओम्')
|
54 |
+
text = re.sub(r'\s*।\s*$', '.', text)
|
55 |
+
text = re.sub(r'\s*।\s*', ', ', text)
|
56 |
+
text = re.sub(r'\s*॥', '.', text)
|
57 |
+
text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST)
|
58 |
+
for regex, replacement in _iast_to_ipa:
|
59 |
+
text = re.sub(regex, replacement, text)
|
60 |
+
text = re.sub('(.)[`ː]*ḥ', lambda x: x.group(0)
|
61 |
+
[:-1]+'h'+x.group(1)+'*', text)
|
62 |
+
return text
|
text/shanghainese.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import cn2an
|
3 |
+
import opencc
|
4 |
+
|
5 |
+
|
6 |
+
converter = opencc.OpenCC('zaonhe')
|
7 |
+
|
8 |
+
# List of (Latin alphabet, ipa) pairs:
|
9 |
+
_latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
10 |
+
('A', 'ᴇ'),
|
11 |
+
('B', 'bi'),
|
12 |
+
('C', 'si'),
|
13 |
+
('D', 'di'),
|
14 |
+
('E', 'i'),
|
15 |
+
('F', 'ᴇf'),
|
16 |
+
('G', 'dʑi'),
|
17 |
+
('H', 'ᴇtɕʰ'),
|
18 |
+
('I', 'ᴀi'),
|
19 |
+
('J', 'dʑᴇ'),
|
20 |
+
('K', 'kʰᴇ'),
|
21 |
+
('L', 'ᴇl'),
|
22 |
+
('M', 'ᴇm'),
|
23 |
+
('N', 'ᴇn'),
|
24 |
+
('O', 'o'),
|
25 |
+
('P', 'pʰi'),
|
26 |
+
('Q', 'kʰiu'),
|
27 |
+
('R', 'ᴀl'),
|
28 |
+
('S', 'ᴇs'),
|
29 |
+
('T', 'tʰi'),
|
30 |
+
('U', 'ɦiu'),
|
31 |
+
('V', 'vi'),
|
32 |
+
('W', 'dᴀbɤliu'),
|
33 |
+
('X', 'ᴇks'),
|
34 |
+
('Y', 'uᴀi'),
|
35 |
+
('Z', 'zᴇ')
|
36 |
+
]]
|
37 |
+
|
38 |
+
|
39 |
+
def _number_to_shanghainese(num):
|
40 |
+
num = cn2an.an2cn(num).replace('一十','十').replace('二十', '廿').replace('二', '两')
|
41 |
+
return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num)
|
42 |
+
|
43 |
+
|
44 |
+
def number_to_shanghainese(text):
|
45 |
+
return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text)
|
46 |
+
|
47 |
+
|
48 |
+
def latin_to_ipa(text):
|
49 |
+
for regex, replacement in _latin_to_ipa:
|
50 |
+
text = re.sub(regex, replacement, text)
|
51 |
+
return text
|
52 |
+
|
53 |
+
|
54 |
+
def shanghainese_to_ipa(text):
|
55 |
+
text = number_to_shanghainese(text.upper())
|
56 |
+
text = converter.convert(text).replace('-','').replace('$',' ')
|
57 |
+
text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
|
58 |
+
text = re.sub(r'[、;:]', ',', text)
|
59 |
+
text = re.sub(r'\s*,\s*', ', ', text)
|
60 |
+
text = re.sub(r'\s*。\s*', '. ', text)
|
61 |
+
text = re.sub(r'\s*?\s*', '? ', text)
|
62 |
+
text = re.sub(r'\s*!\s*', '! ', text)
|
63 |
+
text = re.sub(r'\s*$', '', text)
|
64 |
+
return text
|
text/thai.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from num_thai.thainumbers import NumThai
|
3 |
+
|
4 |
+
|
5 |
+
num = NumThai()
|
6 |
+
|
7 |
+
# List of (Latin alphabet, Thai) pairs:
|
8 |
+
_latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
9 |
+
('a', 'เอ'),
|
10 |
+
('b','บี'),
|
11 |
+
('c','ซี'),
|
12 |
+
('d','ดี'),
|
13 |
+
('e','อี'),
|
14 |
+
('f','เอฟ'),
|
15 |
+
('g','จี'),
|
16 |
+
('h','เอช'),
|
17 |
+
('i','ไอ'),
|
18 |
+
('j','เจ'),
|
19 |
+
('k','เค'),
|
20 |
+
('l','แอล'),
|
21 |
+
('m','เอ็ม'),
|
22 |
+
('n','เอ็น'),
|
23 |
+
('o','โอ'),
|
24 |
+
('p','พี'),
|
25 |
+
('q','คิว'),
|
26 |
+
('r','แอร์'),
|
27 |
+
('s','เอส'),
|
28 |
+
('t','ที'),
|
29 |
+
('u','ยู'),
|
30 |
+
('v','วี'),
|
31 |
+
('w','ดับเบิลยู'),
|
32 |
+
('x','เอ็กซ์'),
|
33 |
+
('y','วาย'),
|
34 |
+
('z','ซี')
|
35 |
+
]]
|
36 |
+
|
37 |
+
|
38 |
+
def num_to_thai(text):
|
39 |
+
return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', lambda x: ''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text)
|
40 |
+
|
41 |
+
def latin_to_thai(text):
|
42 |
+
for regex, replacement in _latin_to_thai:
|
43 |
+
text = re.sub(regex, replacement, text)
|
44 |
+
return text
|
transforms.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
8 |
+
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
9 |
+
DEFAULT_MIN_DERIVATIVE = 1e-3
|
10 |
+
|
11 |
+
|
12 |
+
def piecewise_rational_quadratic_transform(inputs,
|
13 |
+
unnormalized_widths,
|
14 |
+
unnormalized_heights,
|
15 |
+
unnormalized_derivatives,
|
16 |
+
inverse=False,
|
17 |
+
tails=None,
|
18 |
+
tail_bound=1.,
|
19 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
20 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
21 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
22 |
+
|
23 |
+
if tails is None:
|
24 |
+
spline_fn = rational_quadratic_spline
|
25 |
+
spline_kwargs = {}
|
26 |
+
else:
|
27 |
+
spline_fn = unconstrained_rational_quadratic_spline
|
28 |
+
spline_kwargs = {
|
29 |
+
'tails': tails,
|
30 |
+
'tail_bound': tail_bound
|
31 |
+
}
|
32 |
+
|
33 |
+
outputs, logabsdet = spline_fn(
|
34 |
+
inputs=inputs,
|
35 |
+
unnormalized_widths=unnormalized_widths,
|
36 |
+
unnormalized_heights=unnormalized_heights,
|
37 |
+
unnormalized_derivatives=unnormalized_derivatives,
|
38 |
+
inverse=inverse,
|
39 |
+
min_bin_width=min_bin_width,
|
40 |
+
min_bin_height=min_bin_height,
|
41 |
+
min_derivative=min_derivative,
|
42 |
+
**spline_kwargs
|
43 |
+
)
|
44 |
+
return outputs, logabsdet
|
45 |
+
|
46 |
+
|
47 |
+
def searchsorted(bin_locations, inputs, eps=1e-6):
|
48 |
+
bin_locations[..., -1] += eps
|
49 |
+
return torch.sum(
|
50 |
+
inputs[..., None] >= bin_locations,
|
51 |
+
dim=-1
|
52 |
+
) - 1
|
53 |
+
|
54 |
+
|
55 |
+
def unconstrained_rational_quadratic_spline(inputs,
|
56 |
+
unnormalized_widths,
|
57 |
+
unnormalized_heights,
|
58 |
+
unnormalized_derivatives,
|
59 |
+
inverse=False,
|
60 |
+
tails='linear',
|
61 |
+
tail_bound=1.,
|
62 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
63 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
64 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
65 |
+
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
66 |
+
outside_interval_mask = ~inside_interval_mask
|
67 |
+
|
68 |
+
outputs = torch.zeros_like(inputs)
|
69 |
+
logabsdet = torch.zeros_like(inputs)
|
70 |
+
|
71 |
+
if tails == 'linear':
|
72 |
+
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
73 |
+
constant = np.log(np.exp(1 - min_derivative) - 1)
|
74 |
+
unnormalized_derivatives[..., 0] = constant
|
75 |
+
unnormalized_derivatives[..., -1] = constant
|
76 |
+
|
77 |
+
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
78 |
+
logabsdet[outside_interval_mask] = 0
|
79 |
+
else:
|
80 |
+
raise RuntimeError('{} tails are not implemented.'.format(tails))
|
81 |
+
|
82 |
+
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
|
83 |
+
inputs=inputs[inside_interval_mask],
|
84 |
+
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
85 |
+
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
86 |
+
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
87 |
+
inverse=inverse,
|
88 |
+
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
|
89 |
+
min_bin_width=min_bin_width,
|
90 |
+
min_bin_height=min_bin_height,
|
91 |
+
min_derivative=min_derivative
|
92 |
+
)
|
93 |
+
|
94 |
+
return outputs, logabsdet
|
95 |
+
|
96 |
+
def rational_quadratic_spline(inputs,
|
97 |
+
unnormalized_widths,
|
98 |
+
unnormalized_heights,
|
99 |
+
unnormalized_derivatives,
|
100 |
+
inverse=False,
|
101 |
+
left=0., right=1., bottom=0., top=1.,
|
102 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
103 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
104 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
105 |
+
if torch.min(inputs) < left or torch.max(inputs) > right:
|
106 |
+
raise ValueError('Input to a transform is not within its domain')
|
107 |
+
|
108 |
+
num_bins = unnormalized_widths.shape[-1]
|
109 |
+
|
110 |
+
if min_bin_width * num_bins > 1.0:
|
111 |
+
raise ValueError('Minimal bin width too large for the number of bins')
|
112 |
+
if min_bin_height * num_bins > 1.0:
|
113 |
+
raise ValueError('Minimal bin height too large for the number of bins')
|
114 |
+
|
115 |
+
widths = F.softmax(unnormalized_widths, dim=-1)
|
116 |
+
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
117 |
+
cumwidths = torch.cumsum(widths, dim=-1)
|
118 |
+
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
|
119 |
+
cumwidths = (right - left) * cumwidths + left
|
120 |
+
cumwidths[..., 0] = left
|
121 |
+
cumwidths[..., -1] = right
|
122 |
+
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
123 |
+
|
124 |
+
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
125 |
+
|
126 |
+
heights = F.softmax(unnormalized_heights, dim=-1)
|
127 |
+
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
128 |
+
cumheights = torch.cumsum(heights, dim=-1)
|
129 |
+
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
|
130 |
+
cumheights = (top - bottom) * cumheights + bottom
|
131 |
+
cumheights[..., 0] = bottom
|
132 |
+
cumheights[..., -1] = top
|
133 |
+
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
134 |
+
|
135 |
+
if inverse:
|
136 |
+
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
137 |
+
else:
|
138 |
+
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
139 |
+
|
140 |
+
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
141 |
+
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
142 |
+
|
143 |
+
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
144 |
+
delta = heights / widths
|
145 |
+
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
146 |
+
|
147 |
+
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
148 |
+
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
149 |
+
|
150 |
+
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
151 |
+
|
152 |
+
if inverse:
|
153 |
+
a = (((inputs - input_cumheights) * (input_derivatives
|
154 |
+
+ input_derivatives_plus_one
|
155 |
+
- 2 * input_delta)
|
156 |
+
+ input_heights * (input_delta - input_derivatives)))
|
157 |
+
b = (input_heights * input_derivatives
|
158 |
+
- (inputs - input_cumheights) * (input_derivatives
|
159 |
+
+ input_derivatives_plus_one
|
160 |
+
- 2 * input_delta))
|
161 |
+
c = - input_delta * (inputs - input_cumheights)
|
162 |
+
|
163 |
+
discriminant = b.pow(2) - 4 * a * c
|
164 |
+
assert (discriminant >= 0).all()
|
165 |
+
|
166 |
+
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
167 |
+
outputs = root * input_bin_widths + input_cumwidths
|
168 |
+
|
169 |
+
theta_one_minus_theta = root * (1 - root)
|
170 |
+
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
171 |
+
* theta_one_minus_theta)
|
172 |
+
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
|
173 |
+
+ 2 * input_delta * theta_one_minus_theta
|
174 |
+
+ input_derivatives * (1 - root).pow(2))
|
175 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
176 |
+
|
177 |
+
return outputs, -logabsdet
|
178 |
+
else:
|
179 |
+
theta = (inputs - input_cumwidths) / input_bin_widths
|
180 |
+
theta_one_minus_theta = theta * (1 - theta)
|
181 |
+
|
182 |
+
numerator = input_heights * (input_delta * theta.pow(2)
|
183 |
+
+ input_derivatives * theta_one_minus_theta)
|
184 |
+
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
185 |
+
* theta_one_minus_theta)
|
186 |
+
outputs = input_cumheights + numerator / denominator
|
187 |
+
|
188 |
+
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
|
189 |
+
+ 2 * input_delta * theta_one_minus_theta
|
190 |
+
+ input_derivatives * (1 - theta).pow(2))
|
191 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
192 |
+
|
193 |
+
return outputs, logabsdet
|
utils_vits.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from json import loads
|
3 |
+
from torch import load, FloatTensor
|
4 |
+
from numpy import float32
|
5 |
+
import librosa
|
6 |
+
|
7 |
+
|
8 |
+
class HParams():
|
9 |
+
def __init__(self, **kwargs):
|
10 |
+
for k, v in kwargs.items():
|
11 |
+
if type(v) == dict:
|
12 |
+
v = HParams(**v)
|
13 |
+
self[k] = v
|
14 |
+
|
15 |
+
def keys(self):
|
16 |
+
return self.__dict__.keys()
|
17 |
+
|
18 |
+
def items(self):
|
19 |
+
return self.__dict__.items()
|
20 |
+
|
21 |
+
def values(self):
|
22 |
+
return self.__dict__.values()
|
23 |
+
|
24 |
+
def __len__(self):
|
25 |
+
return len(self.__dict__)
|
26 |
+
|
27 |
+
def __getitem__(self, key):
|
28 |
+
return getattr(self, key)
|
29 |
+
|
30 |
+
def __setitem__(self, key, value):
|
31 |
+
return setattr(self, key, value)
|
32 |
+
|
33 |
+
def __contains__(self, key):
|
34 |
+
return key in self.__dict__
|
35 |
+
|
36 |
+
def __repr__(self):
|
37 |
+
return self.__dict__.__repr__()
|
38 |
+
|
39 |
+
|
40 |
+
def load_checkpoint(checkpoint_path, model):
|
41 |
+
checkpoint_dict = load(checkpoint_path, map_location='cpu')
|
42 |
+
iteration = checkpoint_dict['iteration']
|
43 |
+
saved_state_dict = checkpoint_dict['model']
|
44 |
+
if hasattr(model, 'module'):
|
45 |
+
state_dict = model.module.state_dict()
|
46 |
+
else:
|
47 |
+
state_dict = model.state_dict()
|
48 |
+
new_state_dict= {}
|
49 |
+
for k, v in state_dict.items():
|
50 |
+
try:
|
51 |
+
new_state_dict[k] = saved_state_dict[k]
|
52 |
+
except:
|
53 |
+
logging.info("%s is not in the checkpoint" % k)
|
54 |
+
new_state_dict[k] = v
|
55 |
+
if hasattr(model, 'module'):
|
56 |
+
model.module.load_state_dict(new_state_dict)
|
57 |
+
else:
|
58 |
+
model.load_state_dict(new_state_dict)
|
59 |
+
logging.info("Loaded checkpoint '{}' (iteration {})" .format(
|
60 |
+
checkpoint_path, iteration))
|
61 |
+
return
|
62 |
+
|
63 |
+
|
64 |
+
def get_hparams_from_file(config_path):
|
65 |
+
with open(config_path, "r") as f:
|
66 |
+
data = f.read()
|
67 |
+
config = loads(data)
|
68 |
+
|
69 |
+
hparams = HParams(**config)
|
70 |
+
return hparams
|
71 |
+
|
72 |
+
|
73 |
+
def load_audio_to_torch(full_path, target_sampling_rate):
|
74 |
+
audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
|
75 |
+
return FloatTensor(audio.astype(float32))
|
visual_chatgpt.py
ADDED
@@ -0,0 +1,908 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import cv2
|
6 |
+
import re
|
7 |
+
import uuid
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
|
13 |
+
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
|
14 |
+
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
|
15 |
+
|
16 |
+
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline
|
17 |
+
from diffusers import EulerAncestralDiscreteScheduler
|
18 |
+
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
|
19 |
+
from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
|
20 |
+
|
21 |
+
from langchain.agents.initialize import initialize_agent
|
22 |
+
from langchain.agents.tools import Tool
|
23 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
24 |
+
from langchain.llms.openai import OpenAI
|
25 |
+
|
26 |
+
VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
27 |
+
|
28 |
+
Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated.
|
29 |
+
|
30 |
+
Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
|
31 |
+
|
32 |
+
Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
|
33 |
+
|
34 |
+
|
35 |
+
TOOLS:
|
36 |
+
------
|
37 |
+
|
38 |
+
Visual ChatGPT has access to the following tools:"""
|
39 |
+
|
40 |
+
VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
|
41 |
+
|
42 |
+
```
|
43 |
+
Thought: Do I need to use a tool? Yes
|
44 |
+
Action: the action to take, should be one of [{tool_names}]
|
45 |
+
Action Input: the input to the action
|
46 |
+
Observation: the result of the action
|
47 |
+
```
|
48 |
+
|
49 |
+
When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
|
50 |
+
|
51 |
+
```
|
52 |
+
Thought: Do I need to use a tool? No
|
53 |
+
{ai_prefix}: [your response here]
|
54 |
+
```
|
55 |
+
"""
|
56 |
+
|
57 |
+
VISUAL_CHATGPT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if it does not exist.
|
58 |
+
You will remember to provide the image file name loyally if it's provided in the last tool observation.
|
59 |
+
|
60 |
+
Begin!
|
61 |
+
|
62 |
+
Previous conversation history:
|
63 |
+
{chat_history}
|
64 |
+
|
65 |
+
New input: {input}
|
66 |
+
Since Visual ChatGPT is a text language model, Visual ChatGPT must use tools to observe images rather than imagination.
|
67 |
+
The thoughts and observations are only visible for Visual ChatGPT, Visual ChatGPT should remember to repeat important information in the final response for Human.
|
68 |
+
Thought: Do I need to use a tool? {agent_scratchpad}"""
|
69 |
+
|
70 |
+
os.makedirs('image', exist_ok=True)
|
71 |
+
|
72 |
+
|
73 |
+
def seed_everything(seed):
|
74 |
+
random.seed(seed)
|
75 |
+
np.random.seed(seed)
|
76 |
+
torch.manual_seed(seed)
|
77 |
+
torch.cuda.manual_seed_all(seed)
|
78 |
+
return seed
|
79 |
+
|
80 |
+
|
81 |
+
def prompts(name, description):
|
82 |
+
def decorator(func):
|
83 |
+
func.name = name
|
84 |
+
func.description = description
|
85 |
+
return func
|
86 |
+
|
87 |
+
return decorator
|
88 |
+
|
89 |
+
|
90 |
+
def cut_dialogue_history(history_memory, keep_last_n_words=500):
|
91 |
+
tokens = history_memory.split()
|
92 |
+
n_tokens = len(tokens)
|
93 |
+
print(f"hitory_memory:{history_memory}, n_tokens: {n_tokens}")
|
94 |
+
if n_tokens < keep_last_n_words:
|
95 |
+
return history_memory
|
96 |
+
else:
|
97 |
+
paragraphs = history_memory.split('\n')
|
98 |
+
last_n_tokens = n_tokens
|
99 |
+
while last_n_tokens >= keep_last_n_words:
|
100 |
+
last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
|
101 |
+
paragraphs = paragraphs[1:]
|
102 |
+
return '\n' + '\n'.join(paragraphs)
|
103 |
+
|
104 |
+
|
105 |
+
def get_new_image_name(org_img_name, func_name="update"):
|
106 |
+
head_tail = os.path.split(org_img_name)
|
107 |
+
head = head_tail[0]
|
108 |
+
tail = head_tail[1]
|
109 |
+
name_split = tail.split('.')[0].split('_')
|
110 |
+
this_new_uuid = str(uuid.uuid4())[0:4]
|
111 |
+
if len(name_split) == 1:
|
112 |
+
most_org_file_name = name_split[0]
|
113 |
+
recent_prev_file_name = name_split[0]
|
114 |
+
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
|
115 |
+
else:
|
116 |
+
assert len(name_split) == 4
|
117 |
+
most_org_file_name = name_split[3]
|
118 |
+
recent_prev_file_name = name_split[0]
|
119 |
+
new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
|
120 |
+
return os.path.join(head, new_file_name)
|
121 |
+
|
122 |
+
|
123 |
+
class MaskFormer:
|
124 |
+
def __init__(self, device):
|
125 |
+
print("Initializing MaskFormer to %s" % device)
|
126 |
+
self.device = device
|
127 |
+
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
128 |
+
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
|
129 |
+
|
130 |
+
def inference(self, image_path, text):
|
131 |
+
threshold = 0.5
|
132 |
+
min_area = 0.02
|
133 |
+
padding = 20
|
134 |
+
original_image = Image.open(image_path)
|
135 |
+
image = original_image.resize((512, 512))
|
136 |
+
inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
|
137 |
+
with torch.no_grad():
|
138 |
+
outputs = self.model(**inputs)
|
139 |
+
mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
|
140 |
+
area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
|
141 |
+
if area_ratio < min_area:
|
142 |
+
return None
|
143 |
+
true_indices = np.argwhere(mask)
|
144 |
+
mask_array = np.zeros_like(mask, dtype=bool)
|
145 |
+
for idx in true_indices:
|
146 |
+
padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
|
147 |
+
mask_array[padded_slice] = True
|
148 |
+
visual_mask = (mask_array * 255).astype(np.uint8)
|
149 |
+
image_mask = Image.fromarray(visual_mask)
|
150 |
+
return image_mask.resize(original_image.size)
|
151 |
+
|
152 |
+
|
153 |
+
class ImageEditing:
|
154 |
+
def __init__(self, device):
|
155 |
+
print("Initializing ImageEditing to %s" % device)
|
156 |
+
self.device = device
|
157 |
+
self.mask_former = MaskFormer(device=self.device)
|
158 |
+
self.revision = 'fp16' if 'cuda' in device else None
|
159 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
160 |
+
self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
161 |
+
"runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
|
162 |
+
|
163 |
+
@prompts(name="Remove Something From The Photo",
|
164 |
+
description="useful when you want to remove and object or something from the photo "
|
165 |
+
"from its description or location. "
|
166 |
+
"The input to this tool should be a comma seperated string of two, "
|
167 |
+
"representing the image_path and the object need to be removed. ")
|
168 |
+
def inference_remove(self, inputs):
|
169 |
+
image_path, to_be_removed_txt = inputs.split(",")
|
170 |
+
return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
|
171 |
+
|
172 |
+
@prompts(name="Replace Something From The Photo",
|
173 |
+
description="useful when you want to replace an object from the object description or "
|
174 |
+
"location with another object from its description. "
|
175 |
+
"The input to this tool should be a comma seperated string of three, "
|
176 |
+
"representing the image_path, the object to be replaced, the object to be replaced with ")
|
177 |
+
def inference_replace(self, inputs):
|
178 |
+
image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
|
179 |
+
original_image = Image.open(image_path)
|
180 |
+
original_size = original_image.size
|
181 |
+
mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
|
182 |
+
updated_image = self.inpaint(prompt=replace_with_txt, image=original_image.resize((512, 512)),
|
183 |
+
mask_image=mask_image.resize((512, 512))).images[0]
|
184 |
+
updated_image_path = get_new_image_name(image_path, func_name="replace-something")
|
185 |
+
updated_image = updated_image.resize(original_size)
|
186 |
+
updated_image.save(updated_image_path)
|
187 |
+
print(
|
188 |
+
f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
|
189 |
+
f"Output Image: {updated_image_path}")
|
190 |
+
return updated_image_path
|
191 |
+
|
192 |
+
|
193 |
+
class InstructPix2Pix:
|
194 |
+
def __init__(self, device):
|
195 |
+
print("Initializing InstructPix2Pix to %s" % device)
|
196 |
+
self.device = device
|
197 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
198 |
+
self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
|
199 |
+
safety_checker=None,
|
200 |
+
torch_dtype=self.torch_dtype).to(device)
|
201 |
+
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
|
202 |
+
|
203 |
+
@prompts(name="Instruct Image Using Text",
|
204 |
+
description="useful when you want to the style of the image to be like the text. "
|
205 |
+
"like: make it look like a painting. or make it like a robot. "
|
206 |
+
"The input to this tool should be a comma seperated string of two, "
|
207 |
+
"representing the image_path and the text. ")
|
208 |
+
def inference(self, inputs):
|
209 |
+
"""Change style of image."""
|
210 |
+
print("===>Starting InstructPix2Pix Inference")
|
211 |
+
image_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
212 |
+
original_image = Image.open(image_path)
|
213 |
+
image = self.pipe(text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2).images[0]
|
214 |
+
updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
|
215 |
+
image.save(updated_image_path)
|
216 |
+
print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
|
217 |
+
f"Output Image: {updated_image_path}")
|
218 |
+
return updated_image_path
|
219 |
+
|
220 |
+
|
221 |
+
class Text2Image:
|
222 |
+
def __init__(self, device):
|
223 |
+
print("Initializing Text2Image to %s" % device)
|
224 |
+
self.device = device
|
225 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
226 |
+
self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
|
227 |
+
torch_dtype=self.torch_dtype)
|
228 |
+
self.pipe.to(device)
|
229 |
+
self.a_prompt = 'best quality, extremely detailed'
|
230 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
231 |
+
'fewer digits, cropped, worst quality, low quality'
|
232 |
+
|
233 |
+
@prompts(name="Generate Image From User Input Text",
|
234 |
+
description="useful when you want to generate an image from a user input text and save it to a file. "
|
235 |
+
"like: generate an image of an object or something, or generate an image that includes some objects. "
|
236 |
+
"The input to this tool should be a string, representing the text used to generate image. ")
|
237 |
+
def inference(self, text):
|
238 |
+
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
|
239 |
+
prompt = text + ', ' + self.a_prompt
|
240 |
+
image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
|
241 |
+
image.save(image_filename)
|
242 |
+
print(
|
243 |
+
f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
|
244 |
+
return image_filename
|
245 |
+
|
246 |
+
|
247 |
+
class ImageCaptioning:
|
248 |
+
def __init__(self, device):
|
249 |
+
print("Initializing ImageCaptioning to %s" % device)
|
250 |
+
self.device = device
|
251 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
252 |
+
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
253 |
+
self.model = BlipForConditionalGeneration.from_pretrained(
|
254 |
+
"Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
|
255 |
+
|
256 |
+
@prompts(name="Get Photo Description",
|
257 |
+
description="useful when you want to know what is inside the photo. receives image_path as input. "
|
258 |
+
"The input to this tool should be a string, representing the image_path. ")
|
259 |
+
def inference(self, image_path):
|
260 |
+
inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device, self.torch_dtype)
|
261 |
+
out = self.model.generate(**inputs)
|
262 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True)
|
263 |
+
print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
|
264 |
+
return captions
|
265 |
+
|
266 |
+
|
267 |
+
class Image2Canny:
|
268 |
+
def __init__(self, device):
|
269 |
+
print("Initializing Image2Canny")
|
270 |
+
self.low_threshold = 100
|
271 |
+
self.high_threshold = 200
|
272 |
+
|
273 |
+
@prompts(name="Edge Detection On Image",
|
274 |
+
description="useful when you want to detect the edge of the image. "
|
275 |
+
"like: detect the edges of this image, or canny detection on image, "
|
276 |
+
"or perform edge detection on this image, or detect the canny image of this image. "
|
277 |
+
"The input to this tool should be a string, representing the image_path")
|
278 |
+
def inference(self, inputs):
|
279 |
+
image = Image.open(inputs)
|
280 |
+
image = np.array(image)
|
281 |
+
canny = cv2.Canny(image, self.low_threshold, self.high_threshold)
|
282 |
+
canny = canny[:, :, None]
|
283 |
+
canny = np.concatenate([canny, canny, canny], axis=2)
|
284 |
+
canny = Image.fromarray(canny)
|
285 |
+
updated_image_path = get_new_image_name(inputs, func_name="edge")
|
286 |
+
canny.save(updated_image_path)
|
287 |
+
print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text: {updated_image_path}")
|
288 |
+
return updated_image_path
|
289 |
+
|
290 |
+
|
291 |
+
class CannyText2Image:
|
292 |
+
def __init__(self, device):
|
293 |
+
print("Initializing CannyText2Image to %s" % device)
|
294 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
295 |
+
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-canny",
|
296 |
+
torch_dtype=self.torch_dtype)
|
297 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
298 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
299 |
+
torch_dtype=self.torch_dtype)
|
300 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
301 |
+
self.pipe.to(device)
|
302 |
+
self.seed = -1
|
303 |
+
self.a_prompt = 'best quality, extremely detailed'
|
304 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
305 |
+
'fewer digits, cropped, worst quality, low quality'
|
306 |
+
|
307 |
+
@prompts(name="Generate Image Condition On Canny Image",
|
308 |
+
description="useful when you want to generate a new real image from both the user desciption and a canny image."
|
309 |
+
" like: generate a real image of a object or something from this canny image,"
|
310 |
+
" or generate a new real image of a object or something from this edge image. "
|
311 |
+
"The input to this tool should be a comma seperated string of two, "
|
312 |
+
"representing the image_path and the user description. ")
|
313 |
+
def inference(self, inputs):
|
314 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
315 |
+
image = Image.open(image_path)
|
316 |
+
self.seed = random.randint(0, 65535)
|
317 |
+
seed_everything(self.seed)
|
318 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
319 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
320 |
+
guidance_scale=9.0).images[0]
|
321 |
+
updated_image_path = get_new_image_name(image_path, func_name="canny2image")
|
322 |
+
image.save(updated_image_path)
|
323 |
+
print(f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text: {instruct_text}, "
|
324 |
+
f"Output Text: {updated_image_path}")
|
325 |
+
return updated_image_path
|
326 |
+
|
327 |
+
|
328 |
+
class Image2Line:
|
329 |
+
def __init__(self, device):
|
330 |
+
print("Initializing Image2Line")
|
331 |
+
self.detector = MLSDdetector.from_pretrained('lllyasviel/ControlNet')
|
332 |
+
|
333 |
+
@prompts(name="Line Detection On Image",
|
334 |
+
description="useful when you want to detect the straight line of the image. "
|
335 |
+
"like: detect the straight lines of this image, or straight line detection on image, "
|
336 |
+
"or peform straight line detection on this image, or detect the straight line image of this image. "
|
337 |
+
"The input to this tool should be a string, representing the image_path")
|
338 |
+
def inference(self, inputs):
|
339 |
+
image = Image.open(inputs)
|
340 |
+
mlsd = self.detector(image)
|
341 |
+
updated_image_path = get_new_image_name(inputs, func_name="line-of")
|
342 |
+
mlsd.save(updated_image_path)
|
343 |
+
print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line: {updated_image_path}")
|
344 |
+
return updated_image_path
|
345 |
+
|
346 |
+
|
347 |
+
class LineText2Image:
|
348 |
+
def __init__(self, device):
|
349 |
+
print("Initializing LineText2Image to %s" % device)
|
350 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
351 |
+
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-mlsd",
|
352 |
+
torch_dtype=self.torch_dtype)
|
353 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
354 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
355 |
+
torch_dtype=self.torch_dtype
|
356 |
+
)
|
357 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
358 |
+
self.pipe.to(device)
|
359 |
+
self.seed = -1
|
360 |
+
self.a_prompt = 'best quality, extremely detailed'
|
361 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
362 |
+
'fewer digits, cropped, worst quality, low quality'
|
363 |
+
|
364 |
+
@prompts(name="Generate Image Condition On Line Image",
|
365 |
+
description="useful when you want to generate a new real image from both the user desciption "
|
366 |
+
"and a straight line image. "
|
367 |
+
"like: generate a real image of a object or something from this straight line image, "
|
368 |
+
"or generate a new real image of a object or something from this straight lines. "
|
369 |
+
"The input to this tool should be a comma seperated string of two, "
|
370 |
+
"representing the image_path and the user description. ")
|
371 |
+
def inference(self, inputs):
|
372 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
373 |
+
image = Image.open(image_path)
|
374 |
+
self.seed = random.randint(0, 65535)
|
375 |
+
seed_everything(self.seed)
|
376 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
377 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
378 |
+
guidance_scale=9.0).images[0]
|
379 |
+
updated_image_path = get_new_image_name(image_path, func_name="line2image")
|
380 |
+
image.save(updated_image_path)
|
381 |
+
print(f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text: {instruct_text}, "
|
382 |
+
f"Output Text: {updated_image_path}")
|
383 |
+
return updated_image_path
|
384 |
+
|
385 |
+
|
386 |
+
class Image2Hed:
|
387 |
+
def __init__(self, device):
|
388 |
+
print("Initializing Image2Hed")
|
389 |
+
self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
|
390 |
+
|
391 |
+
@prompts(name="Hed Detection On Image",
|
392 |
+
description="useful when you want to detect the soft hed boundary of the image. "
|
393 |
+
"like: detect the soft hed boundary of this image, or hed boundary detection on image, "
|
394 |
+
"or peform hed boundary detection on this image, or detect soft hed boundary image of this image. "
|
395 |
+
"The input to this tool should be a string, representing the image_path")
|
396 |
+
def inference(self, inputs):
|
397 |
+
image = Image.open(inputs)
|
398 |
+
hed = self.detector(image)
|
399 |
+
updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
|
400 |
+
hed.save(updated_image_path)
|
401 |
+
print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {updated_image_path}")
|
402 |
+
return updated_image_path
|
403 |
+
|
404 |
+
|
405 |
+
class HedText2Image:
|
406 |
+
def __init__(self, device):
|
407 |
+
print("Initializing HedText2Image to %s" % device)
|
408 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
409 |
+
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-hed",
|
410 |
+
torch_dtype=self.torch_dtype)
|
411 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
412 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
413 |
+
torch_dtype=self.torch_dtype
|
414 |
+
)
|
415 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
416 |
+
self.pipe.to(device)
|
417 |
+
self.seed = -1
|
418 |
+
self.a_prompt = 'best quality, extremely detailed'
|
419 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
420 |
+
'fewer digits, cropped, worst quality, low quality'
|
421 |
+
|
422 |
+
@prompts(name="Generate Image Condition On Soft Hed Boundary Image",
|
423 |
+
description="useful when you want to generate a new real image from both the user desciption "
|
424 |
+
"and a soft hed boundary image. "
|
425 |
+
"like: generate a real image of a object or something from this soft hed boundary image, "
|
426 |
+
"or generate a new real image of a object or something from this hed boundary. "
|
427 |
+
"The input to this tool should be a comma seperated string of two, "
|
428 |
+
"representing the image_path and the user description")
|
429 |
+
def inference(self, inputs):
|
430 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
431 |
+
image = Image.open(image_path)
|
432 |
+
self.seed = random.randint(0, 65535)
|
433 |
+
seed_everything(self.seed)
|
434 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
435 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
436 |
+
guidance_scale=9.0).images[0]
|
437 |
+
updated_image_path = get_new_image_name(image_path, func_name="hed2image")
|
438 |
+
image.save(updated_image_path)
|
439 |
+
print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text: {instruct_text}, "
|
440 |
+
f"Output Image: {updated_image_path}")
|
441 |
+
return updated_image_path
|
442 |
+
|
443 |
+
|
444 |
+
class Image2Scribble:
|
445 |
+
def __init__(self, device):
|
446 |
+
print("Initializing Image2Scribble")
|
447 |
+
self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
|
448 |
+
|
449 |
+
@prompts(name="Sketch Detection On Image",
|
450 |
+
description="useful when you want to generate a scribble of the image. "
|
451 |
+
"like: generate a scribble of this image, or generate a sketch from this image, "
|
452 |
+
"detect the sketch from this image. "
|
453 |
+
"The input to this tool should be a string, representing the image_path")
|
454 |
+
def inference(self, inputs):
|
455 |
+
image = Image.open(inputs)
|
456 |
+
scribble = self.detector(image, scribble=True)
|
457 |
+
updated_image_path = get_new_image_name(inputs, func_name="scribble")
|
458 |
+
scribble.save(updated_image_path)
|
459 |
+
print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
|
460 |
+
return updated_image_path
|
461 |
+
|
462 |
+
|
463 |
+
class ScribbleText2Image:
|
464 |
+
def __init__(self, device):
|
465 |
+
print("Initializing ScribbleText2Image to %s" % device)
|
466 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
467 |
+
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-scribble",
|
468 |
+
torch_dtype=self.torch_dtype)
|
469 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
470 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
471 |
+
torch_dtype=self.torch_dtype
|
472 |
+
)
|
473 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
474 |
+
self.pipe.to(device)
|
475 |
+
self.seed = -1
|
476 |
+
self.a_prompt = 'best quality, extremely detailed'
|
477 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
|
478 |
+
'fewer digits, cropped, worst quality, low quality'
|
479 |
+
|
480 |
+
@prompts(name="Generate Image Condition On Sketch Image",
|
481 |
+
description="useful when you want to generate a new real image from both the user desciption and "
|
482 |
+
"a scribble image or a sketch image. "
|
483 |
+
"The input to this tool should be a comma seperated string of two, "
|
484 |
+
"representing the image_path and the user description")
|
485 |
+
def inference(self, inputs):
|
486 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
487 |
+
image = Image.open(image_path)
|
488 |
+
self.seed = random.randint(0, 65535)
|
489 |
+
seed_everything(self.seed)
|
490 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
491 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
492 |
+
guidance_scale=9.0).images[0]
|
493 |
+
updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
|
494 |
+
image.save(updated_image_path)
|
495 |
+
print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
|
496 |
+
f"Output Image: {updated_image_path}")
|
497 |
+
return updated_image_path
|
498 |
+
|
499 |
+
|
500 |
+
class Image2Pose:
|
501 |
+
def __init__(self, device):
|
502 |
+
print("Initializing Image2Pose")
|
503 |
+
self.detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
|
504 |
+
|
505 |
+
@prompts(name="Pose Detection On Image",
|
506 |
+
description="useful when you want to detect the human pose of the image. "
|
507 |
+
"like: generate human poses of this image, or generate a pose image from this image. "
|
508 |
+
"The input to this tool should be a string, representing the image_path")
|
509 |
+
def inference(self, inputs):
|
510 |
+
image = Image.open(inputs)
|
511 |
+
pose = self.detector(image)
|
512 |
+
updated_image_path = get_new_image_name(inputs, func_name="human-pose")
|
513 |
+
pose.save(updated_image_path)
|
514 |
+
print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
|
515 |
+
return updated_image_path
|
516 |
+
|
517 |
+
|
518 |
+
class PoseText2Image:
|
519 |
+
def __init__(self, device):
|
520 |
+
print("Initializing PoseText2Image to %s" % device)
|
521 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
522 |
+
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
|
523 |
+
torch_dtype=self.torch_dtype)
|
524 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
525 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
526 |
+
torch_dtype=self.torch_dtype)
|
527 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
528 |
+
self.pipe.to(device)
|
529 |
+
self.num_inference_steps = 20
|
530 |
+
self.seed = -1
|
531 |
+
self.unconditional_guidance_scale = 9.0
|
532 |
+
self.a_prompt = 'best quality, extremely detailed'
|
533 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
534 |
+
' fewer digits, cropped, worst quality, low quality'
|
535 |
+
|
536 |
+
@prompts(name="Generate Image Condition On Pose Image",
|
537 |
+
description="useful when you want to generate a new real image from both the user desciption "
|
538 |
+
"and a human pose image. "
|
539 |
+
"like: generate a real image of a human from this human pose image, "
|
540 |
+
"or generate a new real image of a human from this pose. "
|
541 |
+
"The input to this tool should be a comma seperated string of two, "
|
542 |
+
"representing the image_path and the user description")
|
543 |
+
def inference(self, inputs):
|
544 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
545 |
+
image = Image.open(image_path)
|
546 |
+
self.seed = random.randint(0, 65535)
|
547 |
+
seed_everything(self.seed)
|
548 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
549 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
550 |
+
guidance_scale=9.0).images[0]
|
551 |
+
updated_image_path = get_new_image_name(image_path, func_name="pose2image")
|
552 |
+
image.save(updated_image_path)
|
553 |
+
print(f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text: {instruct_text}, "
|
554 |
+
f"Output Image: {updated_image_path}")
|
555 |
+
return updated_image_path
|
556 |
+
|
557 |
+
|
558 |
+
class Image2Seg:
|
559 |
+
def __init__(self, device):
|
560 |
+
print("Initializing Image2Seg")
|
561 |
+
self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
|
562 |
+
self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
|
563 |
+
self.ade_palette = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
564 |
+
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
565 |
+
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
566 |
+
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
567 |
+
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
568 |
+
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
569 |
+
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
570 |
+
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
571 |
+
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
572 |
+
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
573 |
+
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
574 |
+
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
575 |
+
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
576 |
+
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
577 |
+
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
578 |
+
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
579 |
+
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
580 |
+
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
581 |
+
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
582 |
+
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
583 |
+
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
584 |
+
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
585 |
+
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
586 |
+
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
587 |
+
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
588 |
+
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
589 |
+
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
590 |
+
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
591 |
+
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
592 |
+
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
593 |
+
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
594 |
+
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
595 |
+
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
596 |
+
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
597 |
+
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
598 |
+
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
599 |
+
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
600 |
+
[102, 255, 0], [92, 0, 255]]
|
601 |
+
|
602 |
+
@prompts(name="Segmentation On Image",
|
603 |
+
description="useful when you want to detect segmentations of the image. "
|
604 |
+
"like: segment this image, or generate segmentations on this image, "
|
605 |
+
"or peform segmentation on this image. "
|
606 |
+
"The input to this tool should be a string, representing the image_path")
|
607 |
+
def inference(self, inputs):
|
608 |
+
image = Image.open(inputs)
|
609 |
+
pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
|
610 |
+
with torch.no_grad():
|
611 |
+
outputs = self.image_segmentor(pixel_values)
|
612 |
+
seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
613 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
|
614 |
+
palette = np.array(self.ade_palette)
|
615 |
+
for label, color in enumerate(palette):
|
616 |
+
color_seg[seg == label, :] = color
|
617 |
+
color_seg = color_seg.astype(np.uint8)
|
618 |
+
segmentation = Image.fromarray(color_seg)
|
619 |
+
updated_image_path = get_new_image_name(inputs, func_name="segmentation")
|
620 |
+
segmentation.save(updated_image_path)
|
621 |
+
print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
|
622 |
+
return updated_image_path
|
623 |
+
|
624 |
+
|
625 |
+
class SegText2Image:
|
626 |
+
def __init__(self, device):
|
627 |
+
print("Initializing SegText2Image to %s" % device)
|
628 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
629 |
+
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-seg",
|
630 |
+
torch_dtype=self.torch_dtype)
|
631 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
632 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
633 |
+
torch_dtype=self.torch_dtype)
|
634 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
635 |
+
self.pipe.to(device)
|
636 |
+
self.seed = -1
|
637 |
+
self.a_prompt = 'best quality, extremely detailed'
|
638 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
639 |
+
' fewer digits, cropped, worst quality, low quality'
|
640 |
+
|
641 |
+
@prompts(name="Generate Image Condition On Segmentations",
|
642 |
+
description="useful when you want to generate a new real image from both the user desciption and segmentations. "
|
643 |
+
"like: generate a real image of a object or something from this segmentation image, "
|
644 |
+
"or generate a new real image of a object or something from these segmentations. "
|
645 |
+
"The input to this tool should be a comma seperated string of two, "
|
646 |
+
"representing the image_path and the user description")
|
647 |
+
def inference(self, inputs):
|
648 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
649 |
+
image = Image.open(image_path)
|
650 |
+
self.seed = random.randint(0, 65535)
|
651 |
+
seed_everything(self.seed)
|
652 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
653 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
654 |
+
guidance_scale=9.0).images[0]
|
655 |
+
updated_image_path = get_new_image_name(image_path, func_name="segment2image")
|
656 |
+
image.save(updated_image_path)
|
657 |
+
print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text: {instruct_text}, "
|
658 |
+
f"Output Image: {updated_image_path}")
|
659 |
+
return updated_image_path
|
660 |
+
|
661 |
+
|
662 |
+
class Image2Depth:
|
663 |
+
def __init__(self, device):
|
664 |
+
print("Initializing Image2Depth")
|
665 |
+
self.depth_estimator = pipeline('depth-estimation')
|
666 |
+
|
667 |
+
@prompts(name="Predict Depth On Image",
|
668 |
+
description="useful when you want to detect depth of the image. like: generate the depth from this image, "
|
669 |
+
"or detect the depth map on this image, or predict the depth for this image. "
|
670 |
+
"The input to this tool should be a string, representing the image_path")
|
671 |
+
def inference(self, inputs):
|
672 |
+
image = Image.open(inputs)
|
673 |
+
depth = self.depth_estimator(image)['depth']
|
674 |
+
depth = np.array(depth)
|
675 |
+
depth = depth[:, :, None]
|
676 |
+
depth = np.concatenate([depth, depth, depth], axis=2)
|
677 |
+
depth = Image.fromarray(depth)
|
678 |
+
updated_image_path = get_new_image_name(inputs, func_name="depth")
|
679 |
+
depth.save(updated_image_path)
|
680 |
+
print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth: {updated_image_path}")
|
681 |
+
return updated_image_path
|
682 |
+
|
683 |
+
|
684 |
+
class DepthText2Image:
|
685 |
+
def __init__(self, device):
|
686 |
+
print("Initializing DepthText2Image to %s" % device)
|
687 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
688 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
689 |
+
"fusing/stable-diffusion-v1-5-controlnet-depth", torch_dtype=self.torch_dtype)
|
690 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
691 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
692 |
+
torch_dtype=self.torch_dtype)
|
693 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
694 |
+
self.pipe.to(device)
|
695 |
+
self.seed = -1
|
696 |
+
self.a_prompt = 'best quality, extremely detailed'
|
697 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
698 |
+
' fewer digits, cropped, worst quality, low quality'
|
699 |
+
|
700 |
+
@prompts(name="Generate Image Condition On Depth",
|
701 |
+
description="useful when you want to generate a new real image from both the user desciption and depth image. "
|
702 |
+
"like: generate a real image of a object or something from this depth image, "
|
703 |
+
"or generate a new real image of a object or something from the depth map. "
|
704 |
+
"The input to this tool should be a comma seperated string of two, "
|
705 |
+
"representing the image_path and the user description")
|
706 |
+
def inference(self, inputs):
|
707 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
708 |
+
image = Image.open(image_path)
|
709 |
+
self.seed = random.randint(0, 65535)
|
710 |
+
seed_everything(self.seed)
|
711 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
712 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
713 |
+
guidance_scale=9.0).images[0]
|
714 |
+
updated_image_path = get_new_image_name(image_path, func_name="depth2image")
|
715 |
+
image.save(updated_image_path)
|
716 |
+
print(f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text: {instruct_text}, "
|
717 |
+
f"Output Image: {updated_image_path}")
|
718 |
+
return updated_image_path
|
719 |
+
|
720 |
+
|
721 |
+
class Image2Normal:
|
722 |
+
def __init__(self, device):
|
723 |
+
print("Initializing Image2Normal")
|
724 |
+
self.depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")
|
725 |
+
self.bg_threhold = 0.4
|
726 |
+
|
727 |
+
@prompts(name="Predict Normal Map On Image",
|
728 |
+
description="useful when you want to detect norm map of the image. "
|
729 |
+
"like: generate normal map from this image, or predict normal map of this image. "
|
730 |
+
"The input to this tool should be a string, representing the image_path")
|
731 |
+
def inference(self, inputs):
|
732 |
+
image = Image.open(inputs)
|
733 |
+
original_size = image.size
|
734 |
+
image = self.depth_estimator(image)['predicted_depth'][0]
|
735 |
+
image = image.numpy()
|
736 |
+
image_depth = image.copy()
|
737 |
+
image_depth -= np.min(image_depth)
|
738 |
+
image_depth /= np.max(image_depth)
|
739 |
+
x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
|
740 |
+
x[image_depth < self.bg_threhold] = 0
|
741 |
+
y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
|
742 |
+
y[image_depth < self.bg_threhold] = 0
|
743 |
+
z = np.ones_like(x) * np.pi * 2.0
|
744 |
+
image = np.stack([x, y, z], axis=2)
|
745 |
+
image /= np.sum(image ** 2.0, axis=2, keepdims=True) ** 0.5
|
746 |
+
image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
747 |
+
image = Image.fromarray(image)
|
748 |
+
image = image.resize(original_size)
|
749 |
+
updated_image_path = get_new_image_name(inputs, func_name="normal-map")
|
750 |
+
image.save(updated_image_path)
|
751 |
+
print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth: {updated_image_path}")
|
752 |
+
return updated_image_path
|
753 |
+
|
754 |
+
|
755 |
+
class NormalText2Image:
|
756 |
+
def __init__(self, device):
|
757 |
+
print("Initializing NormalText2Image to %s" % device)
|
758 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
759 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
760 |
+
"fusing/stable-diffusion-v1-5-controlnet-normal", torch_dtype=self.torch_dtype)
|
761 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
762 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
|
763 |
+
torch_dtype=self.torch_dtype)
|
764 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
765 |
+
self.pipe.to(device)
|
766 |
+
self.seed = -1
|
767 |
+
self.a_prompt = 'best quality, extremely detailed'
|
768 |
+
self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
|
769 |
+
' fewer digits, cropped, worst quality, low quality'
|
770 |
+
|
771 |
+
@prompts(name="Generate Image Condition On Normal Map",
|
772 |
+
description="useful when you want to generate a new real image from both the user desciption and normal map. "
|
773 |
+
"like: generate a real image of a object or something from this normal map, "
|
774 |
+
"or generate a new real image of a object or something from the normal map. "
|
775 |
+
"The input to this tool should be a comma seperated string of two, "
|
776 |
+
"representing the image_path and the user description")
|
777 |
+
def inference(self, inputs):
|
778 |
+
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
779 |
+
image = Image.open(image_path)
|
780 |
+
self.seed = random.randint(0, 65535)
|
781 |
+
seed_everything(self.seed)
|
782 |
+
prompt = instruct_text + ', ' + self.a_prompt
|
783 |
+
image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
|
784 |
+
guidance_scale=9.0).images[0]
|
785 |
+
updated_image_path = get_new_image_name(image_path, func_name="normal2image")
|
786 |
+
image.save(updated_image_path)
|
787 |
+
print(f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text: {instruct_text}, "
|
788 |
+
f"Output Image: {updated_image_path}")
|
789 |
+
return updated_image_path
|
790 |
+
|
791 |
+
|
792 |
+
class VisualQuestionAnswering:
|
793 |
+
def __init__(self, device):
|
794 |
+
print("Initializing VisualQuestionAnswering to %s" % device)
|
795 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
796 |
+
self.device = device
|
797 |
+
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
798 |
+
self.model = BlipForQuestionAnswering.from_pretrained(
|
799 |
+
"Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
|
800 |
+
|
801 |
+
@prompts(name="Answer Question About The Image",
|
802 |
+
description="useful when you need an answer for a question based on an image. "
|
803 |
+
"like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
|
804 |
+
"The input to this tool should be a comma seperated string of two, representing the image_path and the question")
|
805 |
+
def inference(self, inputs):
|
806 |
+
image_path, question = inputs.split(",")
|
807 |
+
raw_image = Image.open(image_path).convert('RGB')
|
808 |
+
inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
|
809 |
+
out = self.model.generate(**inputs)
|
810 |
+
answer = self.processor.decode(out[0], skip_special_tokens=True)
|
811 |
+
print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
|
812 |
+
f"Output Answer: {answer}")
|
813 |
+
return answer
|
814 |
+
|
815 |
+
|
816 |
+
class ConversationBot:
|
817 |
+
def __init__(self, load_dict):
|
818 |
+
# load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
|
819 |
+
print(f"Initializing VisualChatGPT, load_dict={load_dict}")
|
820 |
+
if 'ImageCaptioning' not in load_dict:
|
821 |
+
raise ValueError("You have to load ImageCaptioning as a basic function for VisualChatGPT")
|
822 |
+
|
823 |
+
self.llm = OpenAI(temperature=0)
|
824 |
+
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
825 |
+
|
826 |
+
self.models = dict()
|
827 |
+
for class_name, device in load_dict.items():
|
828 |
+
self.models[class_name] = globals()[class_name](device=device)
|
829 |
+
|
830 |
+
self.tools = []
|
831 |
+
for class_name, instance in self.models.items():
|
832 |
+
for e in dir(instance):
|
833 |
+
if e.startswith('inference'):
|
834 |
+
func = getattr(instance, e)
|
835 |
+
self.tools.append(Tool(name=func.name, description=func.description, func=func))
|
836 |
+
|
837 |
+
self.agent = initialize_agent(
|
838 |
+
self.tools,
|
839 |
+
self.llm,
|
840 |
+
agent="conversational-react-description",
|
841 |
+
verbose=True,
|
842 |
+
memory=self.memory,
|
843 |
+
return_intermediate_steps=True,
|
844 |
+
agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
|
845 |
+
'suffix': VISUAL_CHATGPT_SUFFIX}, )
|
846 |
+
|
847 |
+
def run_text(self, text, state):
|
848 |
+
self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
|
849 |
+
res = self.agent({"input": text})
|
850 |
+
res['output'] = res['output'].replace("\\", "/")
|
851 |
+
response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
|
852 |
+
state = state + [(text, response)]
|
853 |
+
print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
|
854 |
+
f"Current Memory: {self.agent.memory.buffer}")
|
855 |
+
return state, state
|
856 |
+
|
857 |
+
def run_image(self, image, state, txt):
|
858 |
+
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
|
859 |
+
print("======>Auto Resize Image...")
|
860 |
+
img = Image.open(image.name)
|
861 |
+
width, height = img.size
|
862 |
+
ratio = min(512 / width, 512 / height)
|
863 |
+
width_new, height_new = (round(width * ratio), round(height * ratio))
|
864 |
+
width_new = int(np.round(width_new / 64.0)) * 64
|
865 |
+
height_new = int(np.round(height_new / 64.0)) * 64
|
866 |
+
img = img.resize((width_new, height_new))
|
867 |
+
img = img.convert('RGB')
|
868 |
+
img.save(image_filename, "PNG")
|
869 |
+
print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
|
870 |
+
description = self.models['ImageCaptioning'].inference(image_filename)
|
871 |
+
Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. " \
|
872 |
+
"This information helps you to understand this image, " \
|
873 |
+
"but you should use tools to finish following tasks, " \
|
874 |
+
"rather than directly imagine from my description. If you understand, say \"Received\". \n".format(
|
875 |
+
image_filename, description)
|
876 |
+
AI_prompt = "Received. "
|
877 |
+
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
878 |
+
state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
|
879 |
+
print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
|
880 |
+
f"Current Memory: {self.agent.memory.buffer}")
|
881 |
+
return state, state, txt + ' ' + image_filename + ' '
|
882 |
+
|
883 |
+
|
884 |
+
if __name__ == '__main__':
|
885 |
+
parser = argparse.ArgumentParser()
|
886 |
+
parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
|
887 |
+
args = parser.parse_args()
|
888 |
+
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
|
889 |
+
bot = ConversationBot(load_dict=load_dict)
|
890 |
+
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
|
891 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
|
892 |
+
state = gr.State([])
|
893 |
+
with gr.Row():
|
894 |
+
with gr.Column(scale=0.7):
|
895 |
+
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
|
896 |
+
container=False)
|
897 |
+
with gr.Column(scale=0.15, min_width=0):
|
898 |
+
clear = gr.Button("Clear")
|
899 |
+
with gr.Column(scale=0.15, min_width=0):
|
900 |
+
btn = gr.UploadButton("Upload", file_types=["image"])
|
901 |
+
|
902 |
+
txt.submit(bot.run_text, [txt, state], [chatbot, state])
|
903 |
+
txt.submit(lambda: "", None, txt)
|
904 |
+
btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
|
905 |
+
clear.click(bot.memory.clear)
|
906 |
+
clear.click(lambda: [], None, chatbot)
|
907 |
+
clear.click(lambda: [], None, state)
|
908 |
+
demo.launch(server_name="0.0.0.0", server_port=7868)
|
visual_chatgpt_zh.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import cv2
|
6 |
+
import re
|
7 |
+
import uuid
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
from langchain.agents.initialize import initialize_agent
|
13 |
+
from langchain.agents.tools import Tool
|
14 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
15 |
+
from langchain.llms.openai import OpenAI
|
16 |
+
|
17 |
+
from modules.image_captioning import ImageCaptioning
|
18 |
+
from modules.image_editing import ImageEditing
|
19 |
+
from modules.instruct_px2pix import InstructPix2Pix
|
20 |
+
from modules.mask_former import MaskFormer
|
21 |
+
from modules.text2img import Text2Image
|
22 |
+
from modules.visual_question_answering import VisualQuestionAnswering
|
23 |
+
from modules.controlnet_canny import Image2Canny,CannyText2Image
|
24 |
+
from modules.controlnet_depth import Image2Depth,DepthText2Image
|
25 |
+
from modules.controlnet_hed import Image2Hed,HedText2Image
|
26 |
+
from modules.controlnet_line import Image2Line,LineText2Image
|
27 |
+
from modules.controlnet_normal import Image2Normal,NormalText2Image
|
28 |
+
from modules.controlnet_pose import Image2Pose,PoseText2Image
|
29 |
+
from modules.controlnet_scibble import Image2Scribble,ScribbleText2Image
|
30 |
+
from modules.controlnet_seg import Image2Seg,SegText2Image
|
31 |
+
|
32 |
+
from modules.utils import *
|
33 |
+
|
34 |
+
import argparse
|
35 |
+
|
36 |
+
# chatgpt前缀
|
37 |
+
VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
38 |
+
Visual ChatGPT is able to process and understand large amounts of text and image. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated.
|
39 |
+
Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
|
40 |
+
Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
|
41 |
+
TOOLS:
|
42 |
+
------
|
43 |
+
Visual ChatGPT has access to the following tools:"""
|
44 |
+
|
45 |
+
# 调教chatgpt的instruction
|
46 |
+
VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
|
47 |
+
```
|
48 |
+
Thought: Do I need to use a tool? Yes
|
49 |
+
Action: the action to take, should be one of [{tool_names}]
|
50 |
+
Action Input: the input to the action
|
51 |
+
Observation: the result of the action
|
52 |
+
```
|
53 |
+
When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
|
54 |
+
```
|
55 |
+
Thought: Do I need to use a tool? No
|
56 |
+
{ai_prefix}: [your response here]
|
57 |
+
```
|
58 |
+
"""
|
59 |
+
|
60 |
+
# chatgpt后缀
|
61 |
+
VISUAL_CHATGPT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if not exists.
|
62 |
+
You will remember to provide the image file name loyally if it's provided in the last tool observation.
|
63 |
+
Begin!
|
64 |
+
Previous conversation history:
|
65 |
+
{chat_history}
|
66 |
+
New input: {input}
|
67 |
+
Since Visual ChatGPT is a text language model, Visual ChatGPT must use tools to observe images rather than imagination.
|
68 |
+
The thoughts and observations are only visible for Visual ChatGPT, Visual ChatGPT should remember to repeat important information in the final response for Human.
|
69 |
+
Thought: Do I need to use a tool? {agent_scratchpad}"""
|
70 |
+
|
71 |
+
os.makedirs('image', exist_ok=True)
|
72 |
+
|
73 |
+
|
74 |
+
class ConversationBot:
|
75 |
+
def __init__(self, load_dict, pretrained_model_dir):
|
76 |
+
# load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
|
77 |
+
print(f"Initializing VisualChatGPT, load_dict={load_dict}")
|
78 |
+
if 'ImageCaptioning' not in load_dict:
|
79 |
+
raise ValueError("You have to load ImageCaptioning as a basic function for VisualChatGPT")
|
80 |
+
|
81 |
+
self.llm = OpenAI(temperature=0)
|
82 |
+
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
83 |
+
|
84 |
+
self.models = dict()
|
85 |
+
for class_name, device in load_dict.items():
|
86 |
+
self.models[class_name] = globals()[class_name](device=device, pretrained_model_dir=pretrained_model_dir)
|
87 |
+
|
88 |
+
self.tools = []
|
89 |
+
for class_name, instance in self.models.items():
|
90 |
+
for e in dir(instance):
|
91 |
+
if e.startswith('inference'):
|
92 |
+
func = getattr(instance, e)
|
93 |
+
self.tools.append(Tool(name=func.name, description=func.description, func=func))
|
94 |
+
|
95 |
+
self.agent = initialize_agent(
|
96 |
+
self.tools,
|
97 |
+
self.llm,
|
98 |
+
agent="conversational-react-description",
|
99 |
+
verbose=True,
|
100 |
+
memory=self.memory,
|
101 |
+
return_intermediate_steps=True,
|
102 |
+
agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
|
103 |
+
'suffix': VISUAL_CHATGPT_SUFFIX}, )
|
104 |
+
|
105 |
+
def run_text(self, text, state):
|
106 |
+
self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
|
107 |
+
res = self.agent({"input": text})
|
108 |
+
res['output'] = res['output'].replace("\\", "/")
|
109 |
+
response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
|
110 |
+
state = state + [(text, response)]
|
111 |
+
print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
|
112 |
+
f"Current Memory: {self.agent.memory.buffer}")
|
113 |
+
return state, state
|
114 |
+
|
115 |
+
def run_image(self, image, state, txt):
|
116 |
+
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
|
117 |
+
print("======>Auto Resize Image...")
|
118 |
+
img = Image.open(image.name)
|
119 |
+
width, height = img.size
|
120 |
+
ratio = min(512 / width, 512 / height)
|
121 |
+
width_new, height_new = (round(width * ratio), round(height * ratio))
|
122 |
+
width_new = int(np.round(width_new / 64.0)) * 64
|
123 |
+
height_new = int(np.round(height_new / 64.0)) * 64
|
124 |
+
img = img.resize((width_new, height_new))
|
125 |
+
img = img.convert('RGB')
|
126 |
+
img.save(image_filename, "PNG")
|
127 |
+
print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
|
128 |
+
description = self.models['ImageCaptioning'].inference(image_filename)
|
129 |
+
Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. " \
|
130 |
+
"This information helps you to understand this image, " \
|
131 |
+
"but you should use tools to finish following tasks, " \
|
132 |
+
"rather than directly imagine from my description. If you understand, say \"Received\". \n".format(
|
133 |
+
image_filename, description)
|
134 |
+
AI_prompt = "Received. "
|
135 |
+
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
136 |
+
state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
|
137 |
+
print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
|
138 |
+
f"Current Memory: {self.agent.memory.buffer}")
|
139 |
+
return state, state, txt + ' ' + image_filename + ' '
|
140 |
+
|
141 |
+
|
142 |
+
if __name__ == '__main__':
|
143 |
+
parser = argparse.ArgumentParser()
|
144 |
+
parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
|
145 |
+
parser.add_argument("--pretrained_model_dir", default="./hf_models_path",
|
146 |
+
type=str, help="huggingface下载好的模型路径")
|
147 |
+
args = parser.parse_args()
|
148 |
+
|
149 |
+
pretrained_model_dir = args.pretrained_model_dir
|
150 |
+
|
151 |
+
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
|
152 |
+
bot = ConversationBot(load_dict=load_dict, pretrained_model_dir=pretrained_model_dir)
|
153 |
+
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
|
154 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
|
155 |
+
state = gr.State([])
|
156 |
+
with gr.Row():
|
157 |
+
with gr.Column(scale=0.7):
|
158 |
+
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
|
159 |
+
container=False)
|
160 |
+
with gr.Column(scale=0.15, min_width=0):
|
161 |
+
clear = gr.Button("Clear")
|
162 |
+
with gr.Column(scale=0.15, min_width=0):
|
163 |
+
btn = gr.UploadButton("Upload", file_types=["image"])
|
164 |
+
|
165 |
+
txt.submit(bot.run_text, [txt, state], [chatbot, state])
|
166 |
+
txt.submit(lambda: "", None, txt)
|
167 |
+
btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
|
168 |
+
clear.click(bot.memory.clear)
|
169 |
+
clear.click(lambda: [], None, chatbot)
|
170 |
+
clear.click(lambda: [], None, state)
|
171 |
+
demo.launch(server_name="0.0.0.0", server_port=7868)
|