FrankZxShen commited on
Commit
aa69275
·
1 Parent(s): e724d71

Upload 55 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ assets/demo_short.gif filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/figure.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ image/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ pip-wheel-metadata/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ .python-version
87
+
88
+ # pipenv
89
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
91
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
92
+ # install all needed dependencies.
93
+ #Pipfile.lock
94
+
95
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96
+ __pypackages__/
97
+
98
+ # Celery stuff
99
+ celerybeat-schedule
100
+ celerybeat.pid
101
+
102
+ # SageMath parsed files
103
+ *.sage.py
104
+
105
+ # Environments
106
+ .env
107
+ .venv
108
+ env/
109
+ venv/
110
+ ENV/
111
+ env.bak/
112
+ venv.bak/
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+ annotator
132
+ cldm
133
+ ldm
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,91 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # visual-chatgpt-zh-vits
2
+ visual-chatgpt支持中文的windows版本
3
+
4
+ 融合vits推断模块
5
+
6
+
7
+ 官方论文: [<font size=5>Visual ChatGPT: Talking, Drawing and Editing with Visual Foundation Models</font>](https://arxiv.org/abs/2303.04671)
8
+
9
+ 官方仓库:[visual-chatgpt](https://github.com/microsoft/visual-chatgpt)
10
+
11
+ fork from:[visual-chatgpt-zh](https://github.com/wxj630/visual-chatgpt-zh)
12
+
13
+
14
+ ## Demo
15
+ <img src="./assets/demo_short.gif" width="750">
16
+
17
+ ## System Architecture
18
+
19
+
20
+ <p align="center"><img src="./assets/figure.jpg" alt="Logo"></p>
21
+
22
+
23
+ ## Quick Start
24
+
25
+ ```
26
+ # 1、下载代码
27
+ git clone https://github.com/FrankZxShen/visual-chatgpt-zh-vits.git
28
+
29
+ # 2、进入项目目录
30
+ cd visual-chatgpt-zh-vits
31
+
32
+ # 3、创建python环境并激活环境
33
+ conda create -n visgpt python=3.8
34
+ activate visgpt
35
+
36
+ # 4、安装环境依赖
37
+ pip install -r requirement.txt
38
+
39
+ # 5、确认api key
40
+ export OPENAI_API_KEY={Your_Private_Openai_Key}
41
+ # windows系统用set命令而不是export
42
+ set OPENAI_API_KEY={Your_Private_Openai_Key}
43
+
44
+ # 6、下载hf模型到指定目录
45
+ # 具体模型文件地址于hf_models
46
+ # 若需要vits推断功能将G.pth config.json放于vits_models下(目前仅支持日语?)
47
+ # Windows:下载pyopenjtalk Windows于text下
48
+
49
+ # 7、启动系统,这个例子我们加载了ImageCaptioning和Text2Image两个模型,
50
+ python visual_chatgpt_zh_vits.py
51
+ # 想要用哪个功能就可增加一些模型加载
52
+ python visual_chatgpt_zh_vits.py
53
+ --load ImageCaptioning_cuda:0,Text2Image_cuda:0 \
54
+ --pretrained_model_dir {your_hf_models_path} \
55
+
56
+ # 8、可以直接在visual_chatgpt_zh_vits.py 38行修改key 若需要vits 39行设定True
57
+ ```
58
+
59
+ 原作者:根据官方建议,不同显卡可以指定不同“--load”参数,显存不够的就可以时间换空间,把不重要的模型加载到cpu上,虽然推理慢但是好歹能跑不是?(手动狗头):
60
+ ```
61
+ # Advice for CPU Users
62
+ python visual_chatgpt.py --load ImageCaptioning_cpu,Text2Image_cpu
63
+
64
+ # Advice for 1 Tesla T4 15GB (Google Colab)
65
+ python visual_chatgpt.py --load "ImageCaptioning_cuda:0,Text2Image_cuda:0"
66
+
67
+ # Advice for 4 Tesla V100 32GB
68
+ python visual_chatgpt.py --load "ImageCaptioning_cuda:0,ImageEditing_cuda:0,
69
+ Text2Image_cuda:1,Image2Canny_cpu,CannyText2Image_cuda:1,
70
+ Image2Depth_cpu,DepthText2Image_cuda:1,VisualQuestionAnswering_cuda:2,
71
+ InstructPix2Pix_cuda:2,Image2Scribble_cpu,ScribbleText2Image_cuda:2,
72
+ Image2Seg_cpu,SegText2Image_cuda:2,Image2Pose_cpu,PoseText2Image_cuda:2,
73
+ Image2Hed_cpu,HedText2Image_cuda:3,Image2Normal_cpu,
74
+ NormalText2Image_cuda:3,Image2Line_cpu,LineText2Image_cuda:3"
75
+ ```
76
+
77
+ 实测环境 Windows RTX3070 8G:若只需要ImageCaptioning和Text2Image两个模型的功能,对显存要求极低,理论上能跑AI绘图均可以(>4G,但速度很慢)?
78
+
79
+ ## limitations
80
+
81
+ img无法显示在gradio上?
82
+
83
+ ## Acknowledgement
84
+
85
+ We appreciate the open source of the following projects:
86
+
87
+ - HuggingFace [[Project]](https://github.com/huggingface/transformers)
88
+
89
+ - ControlNet [[Paper]](https://arxiv.org/abs/2302.05543) [[Project]](https://github.com/lllyasviel/ControlNet)
90
+
91
+ - Stable Diffusion [[Paper]](https://arxiv.org/abs/2112.10752) [[Project]](https://github.com/CompVis/stable-diffusion)
assets/demo.gif ADDED

Git LFS Details

  • SHA256: aef60fb82c9d46584c298346c9e5579cc51df0b6f347bccdb01f8476af9535e8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.08 MB
assets/demo_short.gif ADDED

Git LFS Details

  • SHA256: a4a388707ffe492d7c884a54e1ae84cd30ccd4e9b97f5319478cbb87dc87de3c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
assets/figure.jpg ADDED

Git LFS Details

  • SHA256: 2f369af4e9bfec6d524395650ef4481c9cc13f12f3897fa923f859cf925338c0
  • Pointer size: 132 Bytes
  • Size of remote file: 3.63 MB
attentions.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ from vits_modules import LayerNorm
8
+
9
+
10
+ class Encoder(nn.Module):
11
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
12
+ super().__init__()
13
+ self.hidden_channels = hidden_channels
14
+ self.filter_channels = filter_channels
15
+ self.n_heads = n_heads
16
+ self.n_layers = n_layers
17
+ self.kernel_size = kernel_size
18
+ self.p_dropout = p_dropout
19
+ self.window_size = window_size
20
+
21
+ self.drop = nn.Dropout(p_dropout)
22
+ self.attn_layers = nn.ModuleList()
23
+ self.norm_layers_1 = nn.ModuleList()
24
+ self.ffn_layers = nn.ModuleList()
25
+ self.norm_layers_2 = nn.ModuleList()
26
+ for i in range(self.n_layers):
27
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
28
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
29
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
30
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
31
+
32
+ def forward(self, x, x_mask):
33
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
34
+ x = x * x_mask
35
+ for i in range(self.n_layers):
36
+ y = self.attn_layers[i](x, x, attn_mask)
37
+ y = self.drop(y)
38
+ x = self.norm_layers_1[i](x + y)
39
+
40
+ y = self.ffn_layers[i](x, x_mask)
41
+ y = self.drop(y)
42
+ x = self.norm_layers_2[i](x + y)
43
+ x = x * x_mask
44
+ return x
45
+
46
+
47
+ class Decoder(nn.Module):
48
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
49
+ super().__init__()
50
+ self.hidden_channels = hidden_channels
51
+ self.filter_channels = filter_channels
52
+ self.n_heads = n_heads
53
+ self.n_layers = n_layers
54
+ self.kernel_size = kernel_size
55
+ self.p_dropout = p_dropout
56
+ self.proximal_bias = proximal_bias
57
+ self.proximal_init = proximal_init
58
+
59
+ self.drop = nn.Dropout(p_dropout)
60
+ self.self_attn_layers = nn.ModuleList()
61
+ self.norm_layers_0 = nn.ModuleList()
62
+ self.encdec_attn_layers = nn.ModuleList()
63
+ self.norm_layers_1 = nn.ModuleList()
64
+ self.ffn_layers = nn.ModuleList()
65
+ self.norm_layers_2 = nn.ModuleList()
66
+ for i in range(self.n_layers):
67
+ self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
68
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
69
+ self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
70
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
71
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
72
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
73
+
74
+ def forward(self, x, x_mask, h, h_mask):
75
+ """
76
+ x: decoder input
77
+ h: encoder output
78
+ """
79
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
80
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
81
+ x = x * x_mask
82
+ for i in range(self.n_layers):
83
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
84
+ y = self.drop(y)
85
+ x = self.norm_layers_0[i](x + y)
86
+
87
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
88
+ y = self.drop(y)
89
+ x = self.norm_layers_1[i](x + y)
90
+
91
+ y = self.ffn_layers[i](x, x_mask)
92
+ y = self.drop(y)
93
+ x = self.norm_layers_2[i](x + y)
94
+ x = x * x_mask
95
+ return x
96
+
97
+
98
+ class MultiHeadAttention(nn.Module):
99
+ def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
100
+ super().__init__()
101
+ assert channels % n_heads == 0
102
+
103
+ self.channels = channels
104
+ self.out_channels = out_channels
105
+ self.n_heads = n_heads
106
+ self.p_dropout = p_dropout
107
+ self.window_size = window_size
108
+ self.heads_share = heads_share
109
+ self.block_length = block_length
110
+ self.proximal_bias = proximal_bias
111
+ self.proximal_init = proximal_init
112
+ self.attn = None
113
+
114
+ self.k_channels = channels // n_heads
115
+ self.conv_q = nn.Conv1d(channels, channels, 1)
116
+ self.conv_k = nn.Conv1d(channels, channels, 1)
117
+ self.conv_v = nn.Conv1d(channels, channels, 1)
118
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
119
+ self.drop = nn.Dropout(p_dropout)
120
+
121
+ if window_size is not None:
122
+ n_heads_rel = 1 if heads_share else n_heads
123
+ rel_stddev = self.k_channels**-0.5
124
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
125
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
126
+
127
+ nn.init.xavier_uniform_(self.conv_q.weight)
128
+ nn.init.xavier_uniform_(self.conv_k.weight)
129
+ nn.init.xavier_uniform_(self.conv_v.weight)
130
+ if proximal_init:
131
+ with torch.no_grad():
132
+ self.conv_k.weight.copy_(self.conv_q.weight)
133
+ self.conv_k.bias.copy_(self.conv_q.bias)
134
+
135
+ def forward(self, x, c, attn_mask=None):
136
+ q = self.conv_q(x)
137
+ k = self.conv_k(c)
138
+ v = self.conv_v(c)
139
+
140
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
141
+
142
+ x = self.conv_o(x)
143
+ return x
144
+
145
+ def attention(self, query, key, value, mask=None):
146
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
147
+ b, d, t_s, t_t = (*key.size(), query.size(2))
148
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
149
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
150
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
151
+
152
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
153
+ if self.window_size is not None:
154
+ assert t_s == t_t, "Relative attention is only available for self-attention."
155
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
156
+ rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
157
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
158
+ scores = scores + scores_local
159
+ if self.proximal_bias:
160
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
161
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
162
+ if mask is not None:
163
+ scores = scores.masked_fill(mask == 0, -1e4)
164
+ if self.block_length is not None:
165
+ assert t_s == t_t, "Local attention is only available for self-attention."
166
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
167
+ scores = scores.masked_fill(block_mask == 0, -1e4)
168
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
169
+ p_attn = self.drop(p_attn)
170
+ output = torch.matmul(p_attn, value)
171
+ if self.window_size is not None:
172
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
173
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
174
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
175
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
176
+ return output, p_attn
177
+
178
+ def _matmul_with_relative_values(self, x, y):
179
+ """
180
+ x: [b, h, l, m]
181
+ y: [h or 1, m, d]
182
+ ret: [b, h, l, d]
183
+ """
184
+ ret = torch.matmul(x, y.unsqueeze(0))
185
+ return ret
186
+
187
+ def _matmul_with_relative_keys(self, x, y):
188
+ """
189
+ x: [b, h, l, d]
190
+ y: [h or 1, m, d]
191
+ ret: [b, h, l, m]
192
+ """
193
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
194
+ return ret
195
+
196
+ def _get_relative_embeddings(self, relative_embeddings, length):
197
+ max_relative_position = 2 * self.window_size + 1
198
+ # Pad first before slice to avoid using cond ops.
199
+ pad_length = max(length - (self.window_size + 1), 0)
200
+ slice_start_position = max((self.window_size + 1) - length, 0)
201
+ slice_end_position = slice_start_position + 2 * length - 1
202
+ if pad_length > 0:
203
+ padded_relative_embeddings = F.pad(
204
+ relative_embeddings,
205
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
206
+ else:
207
+ padded_relative_embeddings = relative_embeddings
208
+ used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
209
+ return used_relative_embeddings
210
+
211
+ def _relative_position_to_absolute_position(self, x):
212
+ """
213
+ x: [b, h, l, 2*l-1]
214
+ ret: [b, h, l, l]
215
+ """
216
+ batch, heads, length, _ = x.size()
217
+ # Concat columns of pad to shift from relative to absolute indexing.
218
+ x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
219
+
220
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
221
+ x_flat = x.view([batch, heads, length * 2 * length])
222
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
223
+
224
+ # Reshape and slice out the padded elements.
225
+ x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
226
+ return x_final
227
+
228
+ def _absolute_position_to_relative_position(self, x):
229
+ """
230
+ x: [b, h, l, l]
231
+ ret: [b, h, l, 2*l-1]
232
+ """
233
+ batch, heads, length, _ = x.size()
234
+ # padd along column
235
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
236
+ x_flat = x.view([batch, heads, length**2 + length*(length -1)])
237
+ # add 0's in the beginning that will skew the elements after reshape
238
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
239
+ x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
240
+ return x_final
241
+
242
+ def _attention_bias_proximal(self, length):
243
+ """Bias for self-attention to encourage attention to close positions.
244
+ Args:
245
+ length: an integer scalar.
246
+ Returns:
247
+ a Tensor with shape [1, 1, length, length]
248
+ """
249
+ r = torch.arange(length, dtype=torch.float32)
250
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
251
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
252
+
253
+
254
+ class FFN(nn.Module):
255
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
256
+ super().__init__()
257
+ self.in_channels = in_channels
258
+ self.out_channels = out_channels
259
+ self.filter_channels = filter_channels
260
+ self.kernel_size = kernel_size
261
+ self.p_dropout = p_dropout
262
+ self.activation = activation
263
+ self.causal = causal
264
+
265
+ if causal:
266
+ self.padding = self._causal_padding
267
+ else:
268
+ self.padding = self._same_padding
269
+
270
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
271
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
272
+ self.drop = nn.Dropout(p_dropout)
273
+
274
+ def forward(self, x, x_mask):
275
+ x = self.conv_1(self.padding(x * x_mask))
276
+ if self.activation == "gelu":
277
+ x = x * torch.sigmoid(1.702 * x)
278
+ else:
279
+ x = torch.relu(x)
280
+ x = self.drop(x)
281
+ x = self.conv_2(self.padding(x * x_mask))
282
+ return x * x_mask
283
+
284
+ def _causal_padding(self, x):
285
+ if self.kernel_size == 1:
286
+ return x
287
+ pad_l = self.kernel_size - 1
288
+ pad_r = 0
289
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
290
+ x = F.pad(x, commons.convert_pad_shape(padding))
291
+ return x
292
+
293
+ def _same_padding(self, x):
294
+ if self.kernel_size == 1:
295
+ return x
296
+ pad_l = (self.kernel_size - 1) // 2
297
+ pad_r = self.kernel_size // 2
298
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
299
+ x = F.pad(x, commons.convert_pad_shape(padding))
300
+ return x
commons.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+ import torch.jit
4
+
5
+
6
+ def script_method(fn, _rcb=None):
7
+ return fn
8
+
9
+
10
+ def script(obj, optimize=True, _frames_up=0, _rcb=None):
11
+ return obj
12
+
13
+
14
+ torch.jit.script_method = script_method
15
+ torch.jit.script = script
16
+
17
+
18
+ def init_weights(m, mean=0.0, std=0.01):
19
+ classname = m.__class__.__name__
20
+ if classname.find("Conv") != -1:
21
+ m.weight.data.normal_(mean, std)
22
+
23
+
24
+ def get_padding(kernel_size, dilation=1):
25
+ return int((kernel_size*dilation - dilation)/2)
26
+
27
+
28
+ def intersperse(lst, item):
29
+ result = [item] * (len(lst) * 2 + 1)
30
+ result[1::2] = lst
31
+ return result
32
+
33
+
34
+ def slice_segments(x, ids_str, segment_size=4):
35
+ ret = torch.zeros_like(x[:, :, :segment_size])
36
+ for i in range(x.size(0)):
37
+ idx_str = ids_str[i]
38
+ idx_end = idx_str + segment_size
39
+ ret[i] = x[i, :, idx_str:idx_end]
40
+ return ret
41
+
42
+
43
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
44
+ b, d, t = x.size()
45
+ if x_lengths is None:
46
+ x_lengths = t
47
+ ids_str_max = x_lengths - segment_size + 1
48
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
49
+ ret = slice_segments(x, ids_str, segment_size)
50
+ return ret, ids_str
51
+
52
+
53
+ def subsequent_mask(length):
54
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
55
+ return mask
56
+
57
+
58
+ @torch.jit.script
59
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
60
+ n_channels_int = n_channels[0]
61
+ in_act = input_a + input_b
62
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
63
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
64
+ acts = t_act * s_act
65
+ return acts
66
+
67
+
68
+ def convert_pad_shape(pad_shape):
69
+ l = pad_shape[::-1]
70
+ pad_shape = [item for sublist in l for item in sublist]
71
+ return pad_shape
72
+
73
+
74
+ def sequence_mask(length, max_length=None):
75
+ if max_length is None:
76
+ max_length = length.max()
77
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
78
+ return x.unsqueeze(0) < length.unsqueeze(1)
79
+
80
+
81
+ def generate_path(duration, mask):
82
+ """
83
+ duration: [b, 1, t_x]
84
+ mask: [b, 1, t_y, t_x]
85
+ """
86
+ device = duration.device
87
+
88
+ b, _, t_y, t_x = mask.shape
89
+ cum_duration = torch.cumsum(duration, -1)
90
+
91
+ cum_duration_flat = cum_duration.view(b * t_x)
92
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
93
+ path = path.view(b, t_x, t_y)
94
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
95
+ path = path.unsqueeze(1).transpose(2,3) * mask
96
+ return path
hf_models/download.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git clone https://huggingface.co/Salesforce/blip-image-captioning-base
2
+
3
+ git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
4
+
5
+ git clone https://huggingface.co/runwayml/stable-diffusion-inpainting
6
+
7
+ git clone https://huggingface.co/CIDAS/clipseg-rd64-refined
8
+
9
+ git clone https://huggingface.co/timbrooks/instruct-pix2pix
10
+
11
+ git clone https://huggingface.co/Salesforce/blip-vqa-base
12
+
13
+ git clone https://huggingface.co/lllyasviel/ControlNet
14
+
15
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-canny
16
+
17
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-seg
18
+
19
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-scribble
20
+
21
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-normal
22
+
23
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-mlsd
24
+
25
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-depth
26
+
27
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-hed
28
+
29
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-openpose
30
+
31
+ git clone https://huggingface.co/openmmlab/upernet-convnext-small
32
+
33
+ git clone https://huggingface.co/Intel/dpt-hybrid-midas
hubert_model.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Optional, Tuple
3
+ import random
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
9
+
10
+ class Hubert(nn.Module):
11
+ def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
12
+ super().__init__()
13
+ self._mask = mask
14
+ self.feature_extractor = FeatureExtractor()
15
+ self.feature_projection = FeatureProjection()
16
+ self.positional_embedding = PositionalConvEmbedding()
17
+ self.norm = nn.LayerNorm(768)
18
+ self.dropout = nn.Dropout(0.1)
19
+ self.encoder = TransformerEncoder(
20
+ nn.TransformerEncoderLayer(
21
+ 768, 12, 3072, activation="gelu", batch_first=True
22
+ ),
23
+ 12,
24
+ )
25
+ self.proj = nn.Linear(768, 256)
26
+
27
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
28
+ self.label_embedding = nn.Embedding(num_label_embeddings, 256)
29
+
30
+ def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
31
+ mask = None
32
+ if self.training and self._mask:
33
+ mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
34
+ x[mask] = self.masked_spec_embed.to(x.dtype)
35
+ return x, mask
36
+
37
+ def encode(
38
+ self, x: torch.Tensor, layer: Optional[int] = None
39
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
40
+ x = self.feature_extractor(x)
41
+ x = self.feature_projection(x.transpose(1, 2))
42
+ x, mask = self.mask(x)
43
+ x = x + self.positional_embedding(x)
44
+ x = self.dropout(self.norm(x))
45
+ x = self.encoder(x, output_layer=layer)
46
+ return x, mask
47
+
48
+ def logits(self, x: torch.Tensor) -> torch.Tensor:
49
+ logits = torch.cosine_similarity(
50
+ x.unsqueeze(2),
51
+ self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
52
+ dim=-1,
53
+ )
54
+ return logits / 0.1
55
+
56
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
57
+ x, mask = self.encode(x)
58
+ x = self.proj(x)
59
+ logits = self.logits(x)
60
+ return logits, mask
61
+
62
+
63
+ class HubertSoft(Hubert):
64
+ def __init__(self):
65
+ super().__init__()
66
+
67
+ @torch.inference_mode()
68
+ def units(self, wav: torch.Tensor) -> torch.Tensor:
69
+ wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
70
+ x, _ = self.encode(wav)
71
+ return self.proj(x)
72
+
73
+
74
+ class FeatureExtractor(nn.Module):
75
+ def __init__(self):
76
+ super().__init__()
77
+ self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
78
+ self.norm0 = nn.GroupNorm(512, 512)
79
+ self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
80
+ self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
81
+ self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
82
+ self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
83
+ self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
84
+ self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
85
+
86
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
87
+ x = F.gelu(self.norm0(self.conv0(x)))
88
+ x = F.gelu(self.conv1(x))
89
+ x = F.gelu(self.conv2(x))
90
+ x = F.gelu(self.conv3(x))
91
+ x = F.gelu(self.conv4(x))
92
+ x = F.gelu(self.conv5(x))
93
+ x = F.gelu(self.conv6(x))
94
+ return x
95
+
96
+
97
+ class FeatureProjection(nn.Module):
98
+ def __init__(self):
99
+ super().__init__()
100
+ self.norm = nn.LayerNorm(512)
101
+ self.projection = nn.Linear(512, 768)
102
+ self.dropout = nn.Dropout(0.1)
103
+
104
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
105
+ x = self.norm(x)
106
+ x = self.projection(x)
107
+ x = self.dropout(x)
108
+ return x
109
+
110
+
111
+ class PositionalConvEmbedding(nn.Module):
112
+ def __init__(self):
113
+ super().__init__()
114
+ self.conv = nn.Conv1d(
115
+ 768,
116
+ 768,
117
+ kernel_size=128,
118
+ padding=128 // 2,
119
+ groups=16,
120
+ )
121
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
122
+
123
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
124
+ x = self.conv(x.transpose(1, 2))
125
+ x = F.gelu(x[:, :, :-1])
126
+ return x.transpose(1, 2)
127
+
128
+
129
+ class TransformerEncoder(nn.Module):
130
+ def __init__(
131
+ self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
132
+ ) -> None:
133
+ super(TransformerEncoder, self).__init__()
134
+ self.layers = nn.ModuleList(
135
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
136
+ )
137
+ self.num_layers = num_layers
138
+
139
+ def forward(
140
+ self,
141
+ src: torch.Tensor,
142
+ mask: torch.Tensor = None,
143
+ src_key_padding_mask: torch.Tensor = None,
144
+ output_layer: Optional[int] = None,
145
+ ) -> torch.Tensor:
146
+ output = src
147
+ for layer in self.layers[:output_layer]:
148
+ output = layer(
149
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
150
+ )
151
+ return output
152
+
153
+
154
+ def _compute_mask(
155
+ shape: Tuple[int, int],
156
+ mask_prob: float,
157
+ mask_length: int,
158
+ device: torch.device,
159
+ min_masks: int = 0,
160
+ ) -> torch.Tensor:
161
+ batch_size, sequence_length = shape
162
+
163
+ if mask_length < 1:
164
+ raise ValueError("`mask_length` has to be bigger than 0.")
165
+
166
+ if mask_length > sequence_length:
167
+ raise ValueError(
168
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
169
+ )
170
+
171
+ # compute number of masked spans in batch
172
+ num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
173
+ num_masked_spans = max(num_masked_spans, min_masks)
174
+
175
+ # make sure num masked indices <= sequence_length
176
+ if num_masked_spans * mask_length > sequence_length:
177
+ num_masked_spans = sequence_length // mask_length
178
+
179
+ # SpecAugment mask to fill
180
+ mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
181
+
182
+ # uniform distribution to sample from, make sure that offset samples are < sequence_length
183
+ uniform_dist = torch.ones(
184
+ (batch_size, sequence_length - (mask_length - 1)), device=device
185
+ )
186
+
187
+ # get random indices to mask
188
+ mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
189
+
190
+ # expand masked indices to masked spans
191
+ mask_indices = (
192
+ mask_indices.unsqueeze(dim=-1)
193
+ .expand((batch_size, num_masked_spans, mask_length))
194
+ .reshape(batch_size, num_masked_spans * mask_length)
195
+ )
196
+ offsets = (
197
+ torch.arange(mask_length, device=device)[None, None, :]
198
+ .expand((batch_size, num_masked_spans, mask_length))
199
+ .reshape(batch_size, num_masked_spans * mask_length)
200
+ )
201
+ mask_idxs = mask_indices + offsets
202
+
203
+ # scatter indices to mask
204
+ mask = mask.scatter(1, mask_idxs, True)
205
+
206
+ return mask
207
+
208
+
209
+ def hubert_soft(
210
+ path: str
211
+ ) -> HubertSoft:
212
+ r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
213
+ Args:
214
+ path (str): path of a pretrained model
215
+ """
216
+ hubert = HubertSoft()
217
+ checkpoint = torch.load(path)
218
+ consume_prefix_in_state_dict_if_present(checkpoint, "module.")
219
+ hubert.load_state_dict(checkpoint)
220
+ hubert.eval()
221
+ return hubert
image/01cc9083.png ADDED
image/1d988a81.png ADDED
image/307ade76.png ADDED
image/5ebacb6a.png ADDED
image/cdb4d5e5.png ADDED
mel_processing.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils.data
3
+ from librosa.filters import mel as librosa_mel_fn
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+
8
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
9
+ """
10
+ PARAMS
11
+ ------
12
+ C: compression factor
13
+ """
14
+ return torch.log(torch.clamp(x, min=clip_val) * C)
15
+
16
+
17
+ def dynamic_range_decompression_torch(x, C=1):
18
+ """
19
+ PARAMS
20
+ ------
21
+ C: compression factor used to compress
22
+ """
23
+ return torch.exp(x) / C
24
+
25
+
26
+ def spectral_normalize_torch(magnitudes):
27
+ output = dynamic_range_compression_torch(magnitudes)
28
+ return output
29
+
30
+
31
+ def spectral_de_normalize_torch(magnitudes):
32
+ output = dynamic_range_decompression_torch(magnitudes)
33
+ return output
34
+
35
+
36
+ mel_basis = {}
37
+ hann_window = {}
38
+
39
+
40
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41
+ if torch.min(y) < -1.:
42
+ print('min value is ', torch.min(y))
43
+ if torch.max(y) > 1.:
44
+ print('max value is ', torch.max(y))
45
+
46
+ global hann_window
47
+ dtype_device = str(y.dtype) + '_' + str(y.device)
48
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
49
+ if wnsize_dtype_device not in hann_window:
50
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
51
+
52
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
53
+ y = y.squeeze(1)
54
+
55
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
56
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
57
+
58
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
59
+ return spec
60
+
61
+
62
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
63
+ global mel_basis
64
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
65
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
66
+ if fmax_dtype_device not in mel_basis:
67
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
68
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
69
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
70
+ spec = spectral_normalize_torch(spec)
71
+ return spec
72
+
73
+
74
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
75
+ if torch.min(y) < -1.:
76
+ print('min value is ', torch.min(y))
77
+ if torch.max(y) > 1.:
78
+ print('max value is ', torch.max(y))
79
+
80
+ global mel_basis, hann_window
81
+ dtype_device = str(y.dtype) + '_' + str(y.device)
82
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
83
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
84
+ if fmax_dtype_device not in mel_basis:
85
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
86
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
87
+ if wnsize_dtype_device not in hann_window:
88
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
89
+
90
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
91
+ y = y.squeeze(1)
92
+
93
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
94
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
95
+
96
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
97
+
98
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
99
+ spec = spectral_normalize_torch(spec)
100
+
101
+ return spec
models.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import vits_modules as modules
8
+ import attentions
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d
11
+ from torch.nn.utils import weight_norm
12
+ from commons import init_weights
13
+
14
+
15
+ class StochasticDurationPredictor(nn.Module):
16
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
17
+ super().__init__()
18
+ filter_channels = in_channels # it needs to be removed from future version.
19
+ self.in_channels = in_channels
20
+ self.filter_channels = filter_channels
21
+ self.kernel_size = kernel_size
22
+ self.p_dropout = p_dropout
23
+ self.n_flows = n_flows
24
+ self.gin_channels = gin_channels
25
+
26
+ self.log_flow = modules.Log()
27
+ self.flows = nn.ModuleList()
28
+ self.flows.append(modules.ElementwiseAffine(2))
29
+ for i in range(n_flows):
30
+ self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
31
+ self.flows.append(modules.Flip())
32
+
33
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
34
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
35
+ self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
36
+ self.post_flows = nn.ModuleList()
37
+ self.post_flows.append(modules.ElementwiseAffine(2))
38
+ for i in range(4):
39
+ self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
40
+ self.post_flows.append(modules.Flip())
41
+
42
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
43
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
44
+ self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
45
+ if gin_channels != 0:
46
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
47
+
48
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
49
+ x = torch.detach(x)
50
+ x = self.pre(x)
51
+ if g is not None:
52
+ g = torch.detach(g)
53
+ x = x + self.cond(g)
54
+ x = self.convs(x, x_mask)
55
+ x = self.proj(x) * x_mask
56
+
57
+ if not reverse:
58
+ flows = self.flows
59
+ assert w is not None
60
+
61
+ logdet_tot_q = 0
62
+ h_w = self.post_pre(w)
63
+ h_w = self.post_convs(h_w, x_mask)
64
+ h_w = self.post_proj(h_w) * x_mask
65
+ e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
66
+ z_q = e_q
67
+ for flow in self.post_flows:
68
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
69
+ logdet_tot_q += logdet_q
70
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
71
+ u = torch.sigmoid(z_u) * x_mask
72
+ z0 = (w - u) * x_mask
73
+ logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
74
+ logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
75
+
76
+ logdet_tot = 0
77
+ z0, logdet = self.log_flow(z0, x_mask)
78
+ logdet_tot += logdet
79
+ z = torch.cat([z0, z1], 1)
80
+ for flow in flows:
81
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
82
+ logdet_tot = logdet_tot + logdet
83
+ nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
84
+ return nll + logq # [b]
85
+ else:
86
+ flows = list(reversed(self.flows))
87
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
88
+ z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
89
+ for flow in flows:
90
+ z = flow(z, x_mask, g=x, reverse=reverse)
91
+ z0, z1 = torch.split(z, [1, 1], 1)
92
+ logw = z0
93
+ return logw
94
+
95
+
96
+ class DurationPredictor(nn.Module):
97
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
98
+ super().__init__()
99
+
100
+ self.in_channels = in_channels
101
+ self.filter_channels = filter_channels
102
+ self.kernel_size = kernel_size
103
+ self.p_dropout = p_dropout
104
+ self.gin_channels = gin_channels
105
+
106
+ self.drop = nn.Dropout(p_dropout)
107
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
108
+ self.norm_1 = modules.LayerNorm(filter_channels)
109
+ self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
110
+ self.norm_2 = modules.LayerNorm(filter_channels)
111
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
112
+
113
+ if gin_channels != 0:
114
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
115
+
116
+ def forward(self, x, x_mask, g=None):
117
+ x = torch.detach(x)
118
+ if g is not None:
119
+ g = torch.detach(g)
120
+ x = x + self.cond(g)
121
+ x = self.conv_1(x * x_mask)
122
+ x = torch.relu(x)
123
+ x = self.norm_1(x)
124
+ x = self.drop(x)
125
+ x = self.conv_2(x * x_mask)
126
+ x = torch.relu(x)
127
+ x = self.norm_2(x)
128
+ x = self.drop(x)
129
+ x = self.proj(x * x_mask)
130
+ return x * x_mask
131
+
132
+
133
+ class TextEncoder(nn.Module):
134
+ def __init__(self,
135
+ n_vocab,
136
+ out_channels,
137
+ hidden_channels,
138
+ filter_channels,
139
+ n_heads,
140
+ n_layers,
141
+ kernel_size,
142
+ p_dropout,
143
+ emotion_embedding):
144
+ super().__init__()
145
+ self.n_vocab = n_vocab
146
+ self.out_channels = out_channels
147
+ self.hidden_channels = hidden_channels
148
+ self.filter_channels = filter_channels
149
+ self.n_heads = n_heads
150
+ self.n_layers = n_layers
151
+ self.kernel_size = kernel_size
152
+ self.p_dropout = p_dropout
153
+ self.emotion_embedding = emotion_embedding
154
+
155
+ if self.n_vocab!=0:
156
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
157
+ if emotion_embedding:
158
+ self.emo_proj = nn.Linear(1024, hidden_channels)
159
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
160
+
161
+ self.encoder = attentions.Encoder(
162
+ hidden_channels,
163
+ filter_channels,
164
+ n_heads,
165
+ n_layers,
166
+ kernel_size,
167
+ p_dropout)
168
+ self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
169
+
170
+ def forward(self, x, x_lengths, emotion_embedding=None):
171
+ if self.n_vocab!=0:
172
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
173
+ if emotion_embedding is not None:
174
+ x = x + self.emo_proj(emotion_embedding.unsqueeze(1))
175
+ x = torch.transpose(x, 1, -1) # [b, h, t]
176
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
177
+
178
+ x = self.encoder(x * x_mask, x_mask)
179
+ stats = self.proj(x) * x_mask
180
+
181
+ m, logs = torch.split(stats, self.out_channels, dim=1)
182
+ return x, m, logs, x_mask
183
+
184
+
185
+ class ResidualCouplingBlock(nn.Module):
186
+ def __init__(self,
187
+ channels,
188
+ hidden_channels,
189
+ kernel_size,
190
+ dilation_rate,
191
+ n_layers,
192
+ n_flows=4,
193
+ gin_channels=0):
194
+ super().__init__()
195
+ self.channels = channels
196
+ self.hidden_channels = hidden_channels
197
+ self.kernel_size = kernel_size
198
+ self.dilation_rate = dilation_rate
199
+ self.n_layers = n_layers
200
+ self.n_flows = n_flows
201
+ self.gin_channels = gin_channels
202
+
203
+ self.flows = nn.ModuleList()
204
+ for i in range(n_flows):
205
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
206
+ self.flows.append(modules.Flip())
207
+
208
+ def forward(self, x, x_mask, g=None, reverse=False):
209
+ if not reverse:
210
+ for flow in self.flows:
211
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
212
+ else:
213
+ for flow in reversed(self.flows):
214
+ x = flow(x, x_mask, g=g, reverse=reverse)
215
+ return x
216
+
217
+
218
+ class PosteriorEncoder(nn.Module):
219
+ def __init__(self,
220
+ in_channels,
221
+ out_channels,
222
+ hidden_channels,
223
+ kernel_size,
224
+ dilation_rate,
225
+ n_layers,
226
+ gin_channels=0):
227
+ super().__init__()
228
+ self.in_channels = in_channels
229
+ self.out_channels = out_channels
230
+ self.hidden_channels = hidden_channels
231
+ self.kernel_size = kernel_size
232
+ self.dilation_rate = dilation_rate
233
+ self.n_layers = n_layers
234
+ self.gin_channels = gin_channels
235
+
236
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
237
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
238
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
239
+
240
+ def forward(self, x, x_lengths, g=None):
241
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
242
+ x = self.pre(x) * x_mask
243
+ x = self.enc(x, x_mask, g=g)
244
+ stats = self.proj(x) * x_mask
245
+ m, logs = torch.split(stats, self.out_channels, dim=1)
246
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
247
+ return z, m, logs, x_mask
248
+
249
+
250
+ class Generator(torch.nn.Module):
251
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
252
+ super(Generator, self).__init__()
253
+ self.num_kernels = len(resblock_kernel_sizes)
254
+ self.num_upsamples = len(upsample_rates)
255
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
256
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
257
+
258
+ self.ups = nn.ModuleList()
259
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
260
+ self.ups.append(weight_norm(
261
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
262
+ k, u, padding=(k-u)//2)))
263
+
264
+ self.resblocks = nn.ModuleList()
265
+ for i in range(len(self.ups)):
266
+ ch = upsample_initial_channel//(2**(i+1))
267
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
268
+ self.resblocks.append(resblock(ch, k, d))
269
+
270
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
271
+ self.ups.apply(init_weights)
272
+
273
+ if gin_channels != 0:
274
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
275
+
276
+ def forward(self, x, g=None):
277
+ x = self.conv_pre(x)
278
+ if g is not None:
279
+ x = x + self.cond(g)
280
+
281
+ for i in range(self.num_upsamples):
282
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
283
+ x = self.ups[i](x)
284
+ xs = None
285
+ for j in range(self.num_kernels):
286
+ if xs is None:
287
+ xs = self.resblocks[i*self.num_kernels+j](x)
288
+ else:
289
+ xs += self.resblocks[i*self.num_kernels+j](x)
290
+ x = xs / self.num_kernels
291
+ x = F.leaky_relu(x)
292
+ x = self.conv_post(x)
293
+ x = torch.tanh(x)
294
+
295
+ return x
296
+
297
+
298
+ class SynthesizerTrn(nn.Module):
299
+ """
300
+ Synthesizer for Training
301
+ """
302
+
303
+ def __init__(self,
304
+ n_vocab,
305
+ spec_channels,
306
+ segment_size,
307
+ inter_channels,
308
+ hidden_channels,
309
+ filter_channels,
310
+ n_heads,
311
+ n_layers,
312
+ kernel_size,
313
+ p_dropout,
314
+ resblock,
315
+ resblock_kernel_sizes,
316
+ resblock_dilation_sizes,
317
+ upsample_rates,
318
+ upsample_initial_channel,
319
+ upsample_kernel_sizes,
320
+ n_speakers=0,
321
+ gin_channels=0,
322
+ use_sdp=True,
323
+ emotion_embedding=False,
324
+ **kwargs):
325
+
326
+ super().__init__()
327
+ self.n_vocab = n_vocab
328
+ self.spec_channels = spec_channels
329
+ self.inter_channels = inter_channels
330
+ self.hidden_channels = hidden_channels
331
+ self.filter_channels = filter_channels
332
+ self.n_heads = n_heads
333
+ self.n_layers = n_layers
334
+ self.kernel_size = kernel_size
335
+ self.p_dropout = p_dropout
336
+ self.resblock = resblock
337
+ self.resblock_kernel_sizes = resblock_kernel_sizes
338
+ self.resblock_dilation_sizes = resblock_dilation_sizes
339
+ self.upsample_rates = upsample_rates
340
+ self.upsample_initial_channel = upsample_initial_channel
341
+ self.upsample_kernel_sizes = upsample_kernel_sizes
342
+ self.segment_size = segment_size
343
+ self.n_speakers = n_speakers
344
+ self.gin_channels = gin_channels
345
+
346
+ self.use_sdp = use_sdp
347
+
348
+ self.enc_p = TextEncoder(n_vocab,
349
+ inter_channels,
350
+ hidden_channels,
351
+ filter_channels,
352
+ n_heads,
353
+ n_layers,
354
+ kernel_size,
355
+ p_dropout,
356
+ emotion_embedding)
357
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
358
+ self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
359
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
360
+
361
+ if use_sdp:
362
+ self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
363
+ else:
364
+ self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
365
+
366
+ if n_speakers > 1:
367
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
368
+
369
+ def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None):
370
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
371
+ if self.n_speakers > 0:
372
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
373
+ else:
374
+ g = None
375
+
376
+ if self.use_sdp:
377
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
378
+ else:
379
+ logw = self.dp(x, x_mask, g=g)
380
+ w = torch.exp(logw) * x_mask * length_scale
381
+ w_ceil = torch.ceil(w)
382
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
383
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
384
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
385
+ attn = commons.generate_path(w_ceil, attn_mask)
386
+
387
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
388
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
389
+
390
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
391
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
392
+ o = self.dec((z * y_mask)[:,:,:max_len], g=g)
393
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
394
+
395
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
396
+ assert self.n_speakers > 0, "n_speakers have to be larger than 0."
397
+ g_src = self.emb_g(sid_src).unsqueeze(-1)
398
+ g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
399
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
400
+ z_p = self.flow(z, y_mask, g=g_src)
401
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
402
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt)
403
+ return o_hat, y_mask, (z, z_p, z_hat)
404
+
modules/__init__.py ADDED
File without changes
modules/controlnet_canny.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Canny:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Canny")
6
+ self.low_threshold = 100
7
+ self.high_threshold = 200
8
+
9
+ @prompts(name="Edge Detection On Image",
10
+ description="useful when you want to detect the edge of the image. "
11
+ "like: detect the edges of this image, or canny detection on image, "
12
+ "or perform edge detection on this image, or detect the canny image of this image. "
13
+ "The input to this tool should be a string, representing the image_path")
14
+ def inference(self, inputs):
15
+ image = Image.open(inputs)
16
+ image = np.array(image)
17
+ canny = cv2.Canny(image, self.low_threshold, self.high_threshold)
18
+ canny = canny[:, :, None]
19
+ canny = np.concatenate([canny, canny, canny], axis=2)
20
+ canny = Image.fromarray(canny)
21
+ updated_image_path = get_new_image_name(inputs, func_name="edge")
22
+ canny.save(updated_image_path)
23
+ print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text: {updated_image_path}")
24
+ return updated_image_path
25
+
26
+ class CannyText2Image:
27
+ def __init__(self, device, pretrained_model_dir):
28
+ print("Initializing CannyText2Image to %s" % device)
29
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
30
+ self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-canny",
31
+ torch_dtype=self.torch_dtype)
32
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
33
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
34
+ torch_dtype=self.torch_dtype)
35
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
36
+ self.pipe.to(device)
37
+ self.seed = -1
38
+ self.a_prompt = 'best quality, extremely detailed'
39
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
40
+ 'fewer digits, cropped, worst quality, low quality'
41
+
42
+ @prompts(name="Generate Image Condition On Canny Image",
43
+ description="useful when you want to generate a new real image from both the user desciption and a canny image."
44
+ " like: generate a real image of a object or something from this canny image,"
45
+ " or generate a new real image of a object or something from this edge image. "
46
+ "The input to this tool should be a comma seperated string of two, "
47
+ "representing the image_path and the user description. ")
48
+ def inference(self, inputs):
49
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
50
+ image = Image.open(image_path)
51
+ self.seed = random.randint(0, 65535)
52
+ seed_everything(self.seed)
53
+ prompt = instruct_text + ', ' + self.a_prompt
54
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
55
+ guidance_scale=9.0).images[0]
56
+ updated_image_path = get_new_image_name(image_path, func_name="canny2image")
57
+ image.save(updated_image_path)
58
+ print(f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text: {instruct_text}, "
59
+ f"Output Text: {updated_image_path}")
60
+ return updated_image_path
modules/controlnet_depth.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Depth:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Depth")
6
+ self.depth_estimator = pipeline('depth-estimation')
7
+
8
+ @prompts(name="Predict Depth On Image",
9
+ description="useful when you want to detect depth of the image. like: generate the depth from this image, "
10
+ "or detect the depth map on this image, or predict the depth for this image. "
11
+ "The input to this tool should be a string, representing the image_path")
12
+ def inference(self, inputs):
13
+ image = Image.open(inputs)
14
+ depth = self.depth_estimator(image)['depth']
15
+ depth = np.array(depth)
16
+ depth = depth[:, :, None]
17
+ depth = np.concatenate([depth, depth, depth], axis=2)
18
+ depth = Image.fromarray(depth)
19
+ updated_image_path = get_new_image_name(inputs, func_name="depth")
20
+ depth.save(updated_image_path)
21
+ print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth: {updated_image_path}")
22
+ return updated_image_path
23
+
24
+
25
+ class DepthText2Image:
26
+ def __init__(self, device, pretrained_model_dir):
27
+ print("Initializing DepthText2Image to %s" % device)
28
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
29
+ self.controlnet = ControlNetModel.from_pretrained(
30
+ f"{pretrained_model_dir}/sd-controlnet-depth", torch_dtype=self.torch_dtype)
31
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
32
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
33
+ torch_dtype=self.torch_dtype)
34
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
35
+ self.pipe.to(device)
36
+ self.seed = -1
37
+ self.a_prompt = 'best quality, extremely detailed'
38
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
39
+ ' fewer digits, cropped, worst quality, low quality'
40
+
41
+ @prompts(name="Generate Image Condition On Depth",
42
+ description="useful when you want to generate a new real image from both the user desciption and depth image. "
43
+ "like: generate a real image of a object or something from this depth image, "
44
+ "or generate a new real image of a object or something from the depth map. "
45
+ "The input to this tool should be a comma seperated string of two, "
46
+ "representing the image_path and the user description")
47
+ def inference(self, inputs):
48
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
49
+ image = Image.open(image_path)
50
+ self.seed = random.randint(0, 65535)
51
+ seed_everything(self.seed)
52
+ prompt = instruct_text + ', ' + self.a_prompt
53
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
54
+ guidance_scale=9.0).images[0]
55
+ updated_image_path = get_new_image_name(image_path, func_name="depth2image")
56
+ image.save(updated_image_path)
57
+ print(f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text: {instruct_text}, "
58
+ f"Output Image: {updated_image_path}")
59
+ return updated_image_path
modules/controlnet_hed.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Hed:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Hed")
6
+ self.detector = HEDdetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
7
+
8
+ @prompts(name="Hed Detection On Image",
9
+ description="useful when you want to detect the soft hed boundary of the image. "
10
+ "like: detect the soft hed boundary of this image, or hed boundary detection on image, "
11
+ "or peform hed boundary detection on this image, or detect soft hed boundary image of this image. "
12
+ "The input to this tool should be a string, representing the image_path")
13
+ def inference(self, inputs):
14
+ image = Image.open(inputs)
15
+ hed = self.detector(image)
16
+ updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
17
+ hed.save(updated_image_path)
18
+ print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {updated_image_path}")
19
+ return updated_image_path
20
+
21
+
22
+ class HedText2Image:
23
+ def __init__(self, device, pretrained_model_dir):
24
+ print("Initializing HedText2Image to %s" % device)
25
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
26
+ self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-hed",
27
+ torch_dtype=self.torch_dtype)
28
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
29
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
30
+ torch_dtype=self.torch_dtype
31
+ )
32
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
33
+ self.pipe.to(device)
34
+ self.seed = -1
35
+ self.a_prompt = 'best quality, extremely detailed'
36
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
37
+ 'fewer digits, cropped, worst quality, low quality'
38
+
39
+ @prompts(name="Generate Image Condition On Soft Hed Boundary Image",
40
+ description="useful when you want to generate a new real image from both the user desciption "
41
+ "and a soft hed boundary image. "
42
+ "like: generate a real image of a object or something from this soft hed boundary image, "
43
+ "or generate a new real image of a object or something from this hed boundary. "
44
+ "The input to this tool should be a comma seperated string of two, "
45
+ "representing the image_path and the user description")
46
+ def inference(self, inputs):
47
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
48
+ image = Image.open(image_path)
49
+ self.seed = random.randint(0, 65535)
50
+ seed_everything(self.seed)
51
+ prompt = instruct_text + ', ' + self.a_prompt
52
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
53
+ guidance_scale=9.0).images[0]
54
+ updated_image_path = get_new_image_name(image_path, func_name="hed2image")
55
+ image.save(updated_image_path)
56
+ print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text: {instruct_text}, "
57
+ f"Output Image: {updated_image_path}")
58
+ return updated_image_path
modules/controlnet_line.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Line:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Line")
6
+ self.detector = MLSDdetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
7
+
8
+ @prompts(name="Line Detection On Image",
9
+ description="useful when you want to detect the straight line of the image. "
10
+ "like: detect the straight lines of this image, or straight line detection on image, "
11
+ "or peform straight line detection on this image, or detect the straight line image of this image. "
12
+ "The input to this tool should be a string, representing the image_path")
13
+ def inference(self, inputs):
14
+ image = Image.open(inputs)
15
+ mlsd = self.detector(image)
16
+ updated_image_path = get_new_image_name(inputs, func_name="line-of")
17
+ mlsd.save(updated_image_path)
18
+ print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line: {updated_image_path}")
19
+ return updated_image_path
20
+
21
+
22
+ class LineText2Image:
23
+ def __init__(self, device, pretrained_model_dir):
24
+ print("Initializing LineText2Image to %s" % device)
25
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
26
+ self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-mlsd",
27
+ torch_dtype=self.torch_dtype)
28
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
29
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
30
+ torch_dtype=self.torch_dtype
31
+ )
32
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
33
+ self.pipe.to(device)
34
+ self.seed = -1
35
+ self.a_prompt = 'best quality, extremely detailed'
36
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
37
+ 'fewer digits, cropped, worst quality, low quality'
38
+
39
+ @prompts(name="Generate Image Condition On Line Image",
40
+ description="useful when you want to generate a new real image from both the user desciption "
41
+ "and a straight line image. "
42
+ "like: generate a real image of a object or something from this straight line image, "
43
+ "or generate a new real image of a object or something from this straight lines. "
44
+ "The input to this tool should be a comma seperated string of two, "
45
+ "representing the image_path and the user description. ")
46
+ def inference(self, inputs):
47
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
48
+ image = Image.open(image_path)
49
+ self.seed = random.randint(0, 65535)
50
+ seed_everything(self.seed)
51
+ prompt = instruct_text + ', ' + self.a_prompt
52
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
53
+ guidance_scale=9.0).images[0]
54
+ updated_image_path = get_new_image_name(image_path, func_name="line2image")
55
+ image.save(updated_image_path)
56
+ print(f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text: {instruct_text}, "
57
+ f"Output Text: {updated_image_path}")
58
+ return updated_image_path
modules/controlnet_normal.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Normal:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Normal")
6
+ self.depth_estimator = pipeline("depth-estimation", model=f"{pretrained_model_dir}/dpt-hybrid-midas")
7
+ self.bg_threhold = 0.4
8
+
9
+ @prompts(name="Predict Normal Map On Image",
10
+ description="useful when you want to detect norm map of the image. "
11
+ "like: generate normal map from this image, or predict normal map of this image. "
12
+ "The input to this tool should be a string, representing the image_path")
13
+ def inference(self, inputs):
14
+ image = Image.open(inputs)
15
+ original_size = image.size
16
+ image = self.depth_estimator(image)['predicted_depth'][0]
17
+ image = image.numpy()
18
+ image_depth = image.copy()
19
+ image_depth -= np.min(image_depth)
20
+ image_depth /= np.max(image_depth)
21
+ x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
22
+ x[image_depth < self.bg_threhold] = 0
23
+ y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
24
+ y[image_depth < self.bg_threhold] = 0
25
+ z = np.ones_like(x) * np.pi * 2.0
26
+ image = np.stack([x, y, z], axis=2)
27
+ image /= np.sum(image ** 2.0, axis=2, keepdims=True) ** 0.5
28
+ image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
29
+ image = Image.fromarray(image)
30
+ image = image.resize(original_size)
31
+ updated_image_path = get_new_image_name(inputs, func_name="normal-map")
32
+ image.save(updated_image_path)
33
+ print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth: {updated_image_path}")
34
+ return updated_image_path
35
+
36
+
37
+ class NormalText2Image:
38
+ def __init__(self, device, pretrained_model_dir):
39
+ print("Initializing NormalText2Image to %s" % device)
40
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
41
+ self.controlnet = ControlNetModel.from_pretrained(
42
+ f"{pretrained_model_dir}/sd-controlnet-normal", torch_dtype=self.torch_dtype)
43
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
44
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
45
+ torch_dtype=self.torch_dtype)
46
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
47
+ self.pipe.to(device)
48
+ self.seed = -1
49
+ self.a_prompt = 'best quality, extremely detailed'
50
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
51
+ ' fewer digits, cropped, worst quality, low quality'
52
+
53
+ @prompts(name="Generate Image Condition On Normal Map",
54
+ description="useful when you want to generate a new real image from both the user desciption and normal map. "
55
+ "like: generate a real image of a object or something from this normal map, "
56
+ "or generate a new real image of a object or something from the normal map. "
57
+ "The input to this tool should be a comma seperated string of two, "
58
+ "representing the image_path and the user description")
59
+ def inference(self, inputs):
60
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
61
+ image = Image.open(image_path)
62
+ self.seed = random.randint(0, 65535)
63
+ seed_everything(self.seed)
64
+ prompt = instruct_text + ', ' + self.a_prompt
65
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
66
+ guidance_scale=9.0).images[0]
67
+ updated_image_path = get_new_image_name(image_path, func_name="normal2image")
68
+ image.save(updated_image_path)
69
+ print(f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text: {instruct_text}, "
70
+ f"Output Image: {updated_image_path}")
71
+ return updated_image_path
modules/controlnet_pose.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Pose:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Pose")
6
+ self.detector = OpenposeDetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
7
+
8
+ @prompts(name="Pose Detection On Image",
9
+ description="useful when you want to detect the human pose of the image. "
10
+ "like: generate human poses of this image, or generate a pose image from this image. "
11
+ "The input to this tool should be a string, representing the image_path")
12
+ def inference(self, inputs):
13
+ image = Image.open(inputs)
14
+ pose = self.detector(image)
15
+ updated_image_path = get_new_image_name(inputs, func_name="human-pose")
16
+ pose.save(updated_image_path)
17
+ print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
18
+ return updated_image_path
19
+
20
+
21
+ class PoseText2Image:
22
+ def __init__(self, device, pretrained_model_dir):
23
+ print("Initializing PoseText2Image to %s" % device)
24
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
25
+ self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-openpose",
26
+ torch_dtype=self.torch_dtype)
27
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
28
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
29
+ torch_dtype=self.torch_dtype)
30
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
31
+ self.pipe.to(device)
32
+ self.num_inference_steps = 20
33
+ self.seed = -1
34
+ self.unconditional_guidance_scale = 9.0
35
+ self.a_prompt = 'best quality, extremely detailed'
36
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
37
+ ' fewer digits, cropped, worst quality, low quality'
38
+
39
+ @prompts(name="Generate Image Condition On Pose Image",
40
+ description="useful when you want to generate a new real image from both the user desciption "
41
+ "and a human pose image. "
42
+ "like: generate a real image of a human from this human pose image, "
43
+ "or generate a new real image of a human from this pose. "
44
+ "The input to this tool should be a comma seperated string of two, "
45
+ "representing the image_path and the user description")
46
+ def inference(self, inputs):
47
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
48
+ image = Image.open(image_path)
49
+ self.seed = random.randint(0, 65535)
50
+ seed_everything(self.seed)
51
+ prompt = instruct_text + ', ' + self.a_prompt
52
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
53
+ guidance_scale=9.0).images[0]
54
+ updated_image_path = get_new_image_name(image_path, func_name="pose2image")
55
+ image.save(updated_image_path)
56
+ print(f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text: {instruct_text}, "
57
+ f"Output Image: {updated_image_path}")
58
+ return updated_image_path
modules/controlnet_scibble.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Scribble:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Scribble")
6
+ self.detector = HEDdetector.from_pretrained(f'{pretrained_model_dir}/ControlNet')
7
+
8
+ @prompts(name="Sketch Detection On Image",
9
+ description="useful when you want to generate a scribble of the image. "
10
+ "like: generate a scribble of this image, or generate a sketch from this image, "
11
+ "detect the sketch from this image. "
12
+ "The input to this tool should be a string, representing the image_path")
13
+ def inference(self, inputs):
14
+ image = Image.open(inputs)
15
+ scribble = self.detector(image, scribble=True)
16
+ updated_image_path = get_new_image_name(inputs, func_name="scribble")
17
+ scribble.save(updated_image_path)
18
+ print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
19
+ return updated_image_path
20
+
21
+
22
+ class ScribbleText2Image:
23
+ def __init__(self, device, pretrained_model_dir):
24
+ print("Initializing ScribbleText2Image to %s" % device)
25
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
26
+ self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-scribble",
27
+ torch_dtype=self.torch_dtype)
28
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
29
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
30
+ torch_dtype=self.torch_dtype
31
+ )
32
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
33
+ self.pipe.to(device)
34
+ self.seed = -1
35
+ self.a_prompt = 'best quality, extremely detailed'
36
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
37
+ 'fewer digits, cropped, worst quality, low quality'
38
+
39
+ @prompts(name="Generate Image Condition On Sketch Image",
40
+ description="useful when you want to generate a new real image from both the user desciption and "
41
+ "a scribble image or a sketch image. "
42
+ "The input to this tool should be a comma seperated string of two, "
43
+ "representing the image_path and the user description")
44
+ def inference(self, inputs):
45
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
46
+ image = Image.open(image_path)
47
+ self.seed = random.randint(0, 65535)
48
+ seed_everything(self.seed)
49
+ prompt = instruct_text + ', ' + self.a_prompt
50
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
51
+ guidance_scale=9.0).images[0]
52
+ updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
53
+ image.save(updated_image_path)
54
+ print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
55
+ f"Output Image: {updated_image_path}")
56
+ return updated_image_path
modules/controlnet_seg.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Image2Seg:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Image2Seg")
6
+ self.image_processor = AutoImageProcessor.from_pretrained(f"{pretrained_model_dir}/upernet-convnext-small")
7
+ self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained(f"{pretrained_model_dir}/upernet-convnext-small")
8
+ self.ade_palette = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
9
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
10
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
11
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
12
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
13
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
14
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
15
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
16
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
17
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
18
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
19
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
20
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
21
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
22
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
23
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
24
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
25
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
26
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
27
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
28
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
29
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
30
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
31
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
32
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
33
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
34
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
35
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
36
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
37
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
38
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
39
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
40
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
41
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
42
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
43
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
44
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
45
+ [102, 255, 0], [92, 0, 255]]
46
+
47
+ @prompts(name="Segmentation On Image",
48
+ description="useful when you want to detect segmentations of the image. "
49
+ "like: segment this image, or generate segmentations on this image, "
50
+ "or peform segmentation on this image. "
51
+ "The input to this tool should be a string, representing the image_path")
52
+ def inference(self, inputs):
53
+ image = Image.open(inputs)
54
+ pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
55
+ with torch.no_grad():
56
+ outputs = self.image_segmentor(pixel_values)
57
+ seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
58
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
59
+ palette = np.array(self.ade_palette)
60
+ for label, color in enumerate(palette):
61
+ color_seg[seg == label, :] = color
62
+ color_seg = color_seg.astype(np.uint8)
63
+ segmentation = Image.fromarray(color_seg)
64
+ updated_image_path = get_new_image_name(inputs, func_name="segmentation")
65
+ segmentation.save(updated_image_path)
66
+ print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
67
+ return updated_image_path
68
+
69
+
70
+ class SegText2Image:
71
+ def __init__(self, device, pretrained_model_dir):
72
+ print("Initializing SegText2Image to %s" % device)
73
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
74
+ self.controlnet = ControlNetModel.from_pretrained(f"{pretrained_model_dir}/sd-controlnet-seg",
75
+ torch_dtype=self.torch_dtype)
76
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
77
+ f"{pretrained_model_dir}/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
78
+ torch_dtype=self.torch_dtype)
79
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
80
+ self.pipe.to(device)
81
+ self.seed = -1
82
+ self.a_prompt = 'best quality, extremely detailed'
83
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
84
+ ' fewer digits, cropped, worst quality, low quality'
85
+
86
+ @prompts(name="Generate Image Condition On Segmentations",
87
+ description="useful when you want to generate a new real image from both the user desciption and segmentations. "
88
+ "like: generate a real image of a object or something from this segmentation image, "
89
+ "or generate a new real image of a object or something from these segmentations. "
90
+ "The input to this tool should be a comma seperated string of two, "
91
+ "representing the image_path and the user description")
92
+ def inference(self, inputs):
93
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
94
+ image = Image.open(image_path)
95
+ self.seed = random.randint(0, 65535)
96
+ seed_everything(self.seed)
97
+ prompt = instruct_text + ', ' + self.a_prompt
98
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
99
+ guidance_scale=9.0).images[0]
100
+ updated_image_path = get_new_image_name(image_path, func_name="segment2image")
101
+ image.save(updated_image_path)
102
+ print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text: {instruct_text}, "
103
+ f"Output Image: {updated_image_path}")
104
+ return updated_image_path
modules/image_captioning.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from modules.utils import *
3
+
4
+ class ImageCaptioning:
5
+ def __init__(self, device, pretrained_model_dir):
6
+ print("Initializing ImageCaptioning to %s" % device)
7
+ self.device = device
8
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
9
+ self.processor = BlipProcessor.from_pretrained(f"{pretrained_model_dir}/blip-image-captioning-base")
10
+ self.model = BlipForConditionalGeneration.from_pretrained(
11
+ f"{pretrained_model_dir}/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
12
+
13
+ @prompts(name="Get Photo Description",
14
+ description="useful when you want to know what is inside the photo. receives image_path as input. "
15
+ "The input to this tool should be a string, representing the image_path. ")
16
+ def inference(self, image_path):
17
+ inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device, self.torch_dtype)
18
+ out = self.model.generate(**inputs)
19
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
20
+ print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
21
+ return captions
modules/image_editing.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class ImageEditing:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing ImageEditing to %s" % device)
6
+ self.device = device
7
+ self.mask_former = MaskFormer(device=self.device, pretrained_model_dir=pretrained_model_dir)
8
+ self.revision = 'fp16' if 'cuda' in device else None
9
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
10
+ self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
11
+ f"{pretrained_model_dir}/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
12
+
13
+ @prompts(name="Remove Something From The Photo",
14
+ description="useful when you want to remove and object or something from the photo "
15
+ "from its description or location. "
16
+ "The input to this tool should be a comma seperated string of two, "
17
+ "representing the image_path and the object need to be removed. ")
18
+ def inference_remove(self, inputs):
19
+ image_path, to_be_removed_txt = inputs.split(",")
20
+ return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
21
+
22
+ @prompts(name="Replace Something From The Photo",
23
+ description="useful when you want to replace an object from the object description or "
24
+ "location with another object from its description. "
25
+ "The input to this tool should be a comma seperated string of three, "
26
+ "representing the image_path, the object to be replaced, the object to be replaced with ")
27
+ def inference_replace(self, inputs):
28
+ image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
29
+ original_image = Image.open(image_path)
30
+ original_size = original_image.size
31
+ mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
32
+ updated_image = self.inpaint(prompt=replace_with_txt, image=original_image.resize((512, 512)),
33
+ mask_image=mask_image.resize((512, 512))).images[0]
34
+ updated_image_path = get_new_image_name(image_path, func_name="replace-something")
35
+ updated_image = updated_image.resize(original_size)
36
+ updated_image.save(updated_image_path)
37
+ print(
38
+ f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
39
+ f"Output Image: {updated_image_path}")
40
+ return updated_image_path
modules/instruct_px2pix.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class InstructPix2Pix:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing InstructPix2Pix to %s" % device)
6
+ self.device = device
7
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
8
+ self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(f"{pretrained_model_dir}/instruct-pix2pix",
9
+ safety_checker=None,
10
+ torch_dtype=self.torch_dtype).to(device)
11
+ self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
12
+
13
+ @prompts(name="Instruct Image Using Text",
14
+ description="useful when you want to the style of the image to be like the text. "
15
+ "like: make it look like a painting. or make it like a robot. "
16
+ "The input to this tool should be a comma seperated string of two, "
17
+ "representing the image_path and the text. ")
18
+ def inference(self, inputs):
19
+ """Change style of image."""
20
+ print("===>Starting InstructPix2Pix Inference")
21
+ image_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
22
+ original_image = Image.open(image_path)
23
+ image = self.pipe(text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2).images[0]
24
+ updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
25
+ image.save(updated_image_path)
26
+ print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
27
+ f"Output Image: {updated_image_path}")
28
+ return updated_image_path
modules/mask_former.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class MaskFormer:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing MaskFormer to %s" % device)
6
+ self.device = device
7
+ self.processor = CLIPSegProcessor.from_pretrained(f"{pretrained_model_dir}/clipseg-rd64-refined")
8
+ self.model = CLIPSegForImageSegmentation.from_pretrained(f"{pretrained_model_dir}/clipseg-rd64-refined").to(device)
9
+
10
+ def inference(self, image_path, text):
11
+ threshold = 0.5
12
+ min_area = 0.02
13
+ padding = 20
14
+ original_image = Image.open(image_path)
15
+ image = original_image.resize((512, 512))
16
+ inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
17
+ with torch.no_grad():
18
+ outputs = self.model(**inputs)
19
+ mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
20
+ area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
21
+ if area_ratio < min_area:
22
+ return None
23
+ true_indices = np.argwhere(mask)
24
+ mask_array = np.zeros_like(mask, dtype=bool)
25
+ for idx in true_indices:
26
+ padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
27
+ mask_array[padded_slice] = True
28
+ visual_mask = (mask_array * 255).astype(np.uint8)
29
+ image_mask = Image.fromarray(visual_mask)
30
+ return image_mask.resize(original_image.size)
modules/text2img.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class Text2Image:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing Text2Image to %s" % device)
6
+ self.device = device
7
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
8
+ self.pipe = StableDiffusionPipeline.from_pretrained(f"{pretrained_model_dir}/stable-diffusion-v1-5",
9
+ torch_dtype=self.torch_dtype)
10
+ self.pipe.to(device)
11
+ self.a_prompt = 'best quality, extremely detailed'
12
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
13
+ 'fewer digits, cropped, worst quality, low quality'
14
+
15
+ @prompts(name="Generate Image From User Input Text",
16
+ description="useful when you want to generate an image from a user input text and save it to a file. "
17
+ "like: generate an image of an object or something, or generate an image that includes some objects. "
18
+ "The input to this tool should be a string, representing the text used to generate image. ")
19
+ def inference(self, text):
20
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
21
+ prompt = text + ', ' + self.a_prompt
22
+ image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
23
+ image.save(image_filename)
24
+ print(
25
+ f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
26
+ return image_filename
modules/utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import gradio as gr
4
+ import random
5
+ import torch
6
+ import cv2
7
+ import re
8
+ import uuid
9
+ from PIL import Image
10
+ import numpy as np
11
+ import argparse
12
+
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
14
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
15
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
16
+
17
+ from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline
18
+ from diffusers import EulerAncestralDiscreteScheduler
19
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
20
+ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
21
+
22
+ from langchain.agents.initialize import initialize_agent
23
+ from langchain.agents.tools import Tool
24
+ from langchain.chains.conversation.memory import ConversationBufferMemory
25
+ from langchain.llms.openai import OpenAI
26
+
27
+ # 装饰器
28
+ def prompts(name, description):
29
+ def decorator(func):
30
+ func.name = name
31
+ func.description = description
32
+ return func
33
+
34
+ return decorator
35
+
36
+ # 设置种子
37
+ def seed_everything(seed):
38
+ random.seed(seed)
39
+ np.random.seed(seed)
40
+ torch.manual_seed(seed)
41
+ torch.cuda.manual_seed_all(seed)
42
+ return seed
43
+
44
+ # 对话历史截断
45
+ def cut_dialogue_history(history_memory, keep_last_n_words=500):
46
+ tokens = history_memory.split()
47
+ n_tokens = len(tokens)
48
+ print(f"hitory_memory:{history_memory}, n_tokens: {n_tokens}")
49
+ if n_tokens < keep_last_n_words:
50
+ return history_memory
51
+ else:
52
+ paragraphs = history_memory.split('\n')
53
+ last_n_tokens = n_tokens
54
+ while last_n_tokens >= keep_last_n_words:
55
+ last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
56
+ paragraphs = paragraphs[1:]
57
+ return '\n' + '\n'.join(paragraphs)
58
+
59
+ # 获取新图片
60
+ def get_new_image_name(org_img_name, func_name="update"):
61
+ head_tail = os.path.split(org_img_name)
62
+ head = head_tail[0]
63
+ tail = head_tail[1]
64
+ name_split = tail.split('.')[0].split('_')
65
+ this_new_uuid = str(uuid.uuid4())[0:4]
66
+ if len(name_split) == 1:
67
+ most_org_file_name = name_split[0]
68
+ recent_prev_file_name = name_split[0]
69
+ new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
70
+ else:
71
+ assert len(name_split) == 4
72
+ most_org_file_name = name_split[3]
73
+ recent_prev_file_name = name_split[0]
74
+ new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
75
+ return os.path.join(head, new_file_name)
modules/visual_question_answering.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.utils import *
2
+
3
+ class VisualQuestionAnswering:
4
+ def __init__(self, device, pretrained_model_dir):
5
+ print("Initializing VisualQuestionAnswering to %s" % device)
6
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
7
+ self.device = device
8
+ self.processor = BlipProcessor.from_pretrained(f"{pretrained_model_dir}/blip-vqa-base")
9
+ self.model = BlipForQuestionAnswering.from_pretrained(
10
+ f"{pretrained_model_dir}/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
11
+
12
+ @prompts(name="Answer Question About The Image",
13
+ description="useful when you need an answer for a question based on an image. "
14
+ "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
15
+ "The input to this tool should be a comma seperated string of two, representing the image_path and the question")
16
+ def inference(self, inputs):
17
+ image_path, question = inputs.split(",")
18
+ raw_image = Image.open(image_path).convert('RGB')
19
+ inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
20
+ out = self.model.generate(**inputs)
21
+ answer = self.processor.decode(out[0], skip_special_tokens=True)
22
+ print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
23
+ f"Output Answer: {answer}")
24
+ return answer
requirement.txt ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain==0.0.101
2
+ torch==1.12.1
3
+ torchvision==0.13.1
4
+ gradio==3.20.1
5
+ accelerate
6
+ addict
7
+ albumentations
8
+ basicsr
9
+ controlnet-aux
10
+ diffusers
11
+ einops
12
+ imageio
13
+ imageio-ffmpeg
14
+ invisible-watermark
15
+ kornia
16
+ numpy
17
+ omegaconf
18
+ open_clip_torch
19
+ openai
20
+ opencv-python
21
+ prettytable
22
+ safetensors
23
+ streamlit
24
+ test-tube
25
+ timm
26
+ torchmetrics
27
+ transformers
28
+ webdataset
29
+ yapf
30
+ numba
31
+ librosa
32
+ scipy
33
+ unidecode
34
+ openjtalk>=0.3.0.dev2
35
+ jamo
36
+ pypinyin
37
+ jieba
38
+ protobuf
39
+ pygtrans
40
+ cn2an
41
+ inflect
42
+ eng_to_ipa
43
+ ko_pron
44
+ indic_transliteration
45
+ num_thai
46
+ opencc
47
+ vosk
48
+ sounddevice
text/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+ from text import cleaners
3
+
4
+
5
+ def text_to_sequence(text, symbols, cleaner_names):
6
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
7
+ Args:
8
+ text: string to convert to a sequence
9
+ cleaner_names: names of the cleaner functions to run the text through
10
+ Returns:
11
+ List of integers corresponding to the symbols in the text
12
+ '''
13
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
14
+
15
+ sequence = []
16
+
17
+ clean_text = _clean_text(text, cleaner_names)
18
+ for symbol in clean_text:
19
+ if symbol not in _symbol_to_id.keys():
20
+ continue
21
+ symbol_id = _symbol_to_id[symbol]
22
+ sequence += [symbol_id]
23
+ return sequence
24
+
25
+
26
+ def _clean_text(text, cleaner_names):
27
+ for name in cleaner_names:
28
+ cleaner = getattr(cleaners, name)
29
+ if not cleaner:
30
+ raise Exception('Unknown cleaner: %s' % name)
31
+ text = cleaner(text)
32
+ return text
text/cantonese.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import cn2an
3
+ import opencc
4
+
5
+
6
+ converter = opencc.OpenCC('jyutjyu')
7
+
8
+ # List of (Latin alphabet, ipa) pairs:
9
+ _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
10
+ ('A', 'ei˥'),
11
+ ('B', 'biː˥'),
12
+ ('C', 'siː˥'),
13
+ ('D', 'tiː˥'),
14
+ ('E', 'iː˥'),
15
+ ('F', 'e˥fuː˨˩'),
16
+ ('G', 'tsiː˥'),
17
+ ('H', 'ɪk̚˥tsʰyː˨˩'),
18
+ ('I', 'ɐi˥'),
19
+ ('J', 'tsei˥'),
20
+ ('K', 'kʰei˥'),
21
+ ('L', 'e˥llou˨˩'),
22
+ ('M', 'ɛːm˥'),
23
+ ('N', 'ɛːn˥'),
24
+ ('O', 'ou˥'),
25
+ ('P', 'pʰiː˥'),
26
+ ('Q', 'kʰiːu˥'),
27
+ ('R', 'aː˥lou˨˩'),
28
+ ('S', 'ɛː˥siː˨˩'),
29
+ ('T', 'tʰiː˥'),
30
+ ('U', 'juː˥'),
31
+ ('V', 'wiː˥'),
32
+ ('W', 'tʊk̚˥piː˥juː˥'),
33
+ ('X', 'ɪk̚˥siː˨˩'),
34
+ ('Y', 'waːi˥'),
35
+ ('Z', 'iː˨sɛːt̚˥')
36
+ ]]
37
+
38
+
39
+ def number_to_cantonese(text):
40
+ return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text)
41
+
42
+
43
+ def latin_to_ipa(text):
44
+ for regex, replacement in _latin_to_ipa:
45
+ text = re.sub(regex, replacement, text)
46
+ return text
47
+
48
+
49
+ def cantonese_to_ipa(text):
50
+ text = number_to_cantonese(text.upper())
51
+ text = converter.convert(text).replace('-','').replace('$',' ')
52
+ text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
53
+ text = re.sub(r'[、;:]', ',', text)
54
+ text = re.sub(r'\s*,\s*', ', ', text)
55
+ text = re.sub(r'\s*。\s*', '. ', text)
56
+ text = re.sub(r'\s*?\s*', '? ', text)
57
+ text = re.sub(r'\s*!\s*', '! ', text)
58
+ text = re.sub(r'\s*$', '', text)
59
+ return text
text/cleaners.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ def japanese_cleaners(text):
5
+ from text.japanese import japanese_to_romaji_with_accent
6
+ text = japanese_to_romaji_with_accent(text)
7
+ text = re.sub(r'([A-Za-z])$', r'\1.', text)
8
+ return text
9
+
10
+
11
+ def japanese_cleaners2(text):
12
+ return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
13
+
14
+
15
+ def korean_cleaners(text):
16
+ '''Pipeline for Korean text'''
17
+ from text.korean import latin_to_hangul, number_to_hangul, divide_hangul
18
+ text = latin_to_hangul(text)
19
+ text = number_to_hangul(text)
20
+ text = divide_hangul(text)
21
+ text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
22
+ return text
23
+
24
+
25
+ def chinese_cleaners(text):
26
+ '''Pipeline for Chinese text'''
27
+ from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo
28
+ text = number_to_chinese(text)
29
+ text = chinese_to_bopomofo(text)
30
+ text = latin_to_bopomofo(text)
31
+ text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
32
+ return text
33
+
34
+
35
+ def zh_ja_mixture_cleaners(text):
36
+ from text.mandarin import chinese_to_romaji
37
+ from text.japanese import japanese_to_romaji_with_accent
38
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
39
+ lambda x: chinese_to_romaji(x.group(1))+' ', text)
40
+ text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent(
41
+ x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…')+' ', text)
42
+ text = re.sub(r'\s+$', '', text)
43
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
44
+ return text
45
+
46
+
47
+ def sanskrit_cleaners(text):
48
+ text = text.replace('॥', '।').replace('ॐ', 'ओम्')
49
+ text = re.sub(r'([^।])$', r'\1।', text)
50
+ return text
51
+
52
+
53
+ def cjks_cleaners(text):
54
+ from text.mandarin import chinese_to_lazy_ipa
55
+ from text.japanese import japanese_to_ipa
56
+ from text.korean import korean_to_lazy_ipa
57
+ from text.sanskrit import devanagari_to_ipa
58
+ from text.english import english_to_lazy_ipa
59
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
60
+ lambda x: chinese_to_lazy_ipa(x.group(1))+' ', text)
61
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
62
+ lambda x: japanese_to_ipa(x.group(1))+' ', text)
63
+ text = re.sub(r'\[KO\](.*?)\[KO\]',
64
+ lambda x: korean_to_lazy_ipa(x.group(1))+' ', text)
65
+ text = re.sub(r'\[SA\](.*?)\[SA\]',
66
+ lambda x: devanagari_to_ipa(x.group(1))+' ', text)
67
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
68
+ lambda x: english_to_lazy_ipa(x.group(1))+' ', text)
69
+ text = re.sub(r'\s+$', '', text)
70
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
71
+ return text
72
+
73
+
74
+ def cjke_cleaners(text):
75
+ from text.mandarin import chinese_to_lazy_ipa
76
+ from text.japanese import japanese_to_ipa
77
+ from text.korean import korean_to_ipa
78
+ from text.english import english_to_ipa2
79
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace(
80
+ 'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn')+' ', text)
81
+ text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace(
82
+ 'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz')+' ', text)
83
+ text = re.sub(r'\[KO\](.*?)\[KO\]',
84
+ lambda x: korean_to_ipa(x.group(1))+' ', text)
85
+ text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace(
86
+ 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')+' ', text)
87
+ text = re.sub(r'\s+$', '', text)
88
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
89
+ return text
90
+
91
+
92
+ def cjke_cleaners2(text):
93
+ from text.mandarin import chinese_to_ipa
94
+ from text.japanese import japanese_to_ipa2
95
+ from text.korean import korean_to_ipa
96
+ from text.english import english_to_ipa2
97
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
98
+ lambda x: chinese_to_ipa(x.group(1))+' ', text)
99
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
100
+ lambda x: japanese_to_ipa2(x.group(1))+' ', text)
101
+ text = re.sub(r'\[KO\](.*?)\[KO\]',
102
+ lambda x: korean_to_ipa(x.group(1))+' ', text)
103
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
104
+ lambda x: english_to_ipa2(x.group(1))+' ', text)
105
+ text = re.sub(r'\s+$', '', text)
106
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
107
+ return text
108
+
109
+
110
+ def thai_cleaners(text):
111
+ from text.thai import num_to_thai, latin_to_thai
112
+ text = num_to_thai(text)
113
+ text = latin_to_thai(text)
114
+ return text
115
+
116
+
117
+ def shanghainese_cleaners(text):
118
+ from text.shanghainese import shanghainese_to_ipa
119
+ text = shanghainese_to_ipa(text)
120
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
121
+ return text
122
+
123
+
124
+ def chinese_dialect_cleaners(text):
125
+ from text.mandarin import chinese_to_ipa2
126
+ from text.japanese import japanese_to_ipa3
127
+ from text.shanghainese import shanghainese_to_ipa
128
+ from text.cantonese import cantonese_to_ipa
129
+ from text.english import english_to_lazy_ipa2
130
+ from text.ngu_dialect import ngu_dialect_to_ipa
131
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
132
+ lambda x: chinese_to_ipa2(x.group(1))+' ', text)
133
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
134
+ lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text)
135
+ text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5',
136
+ '˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text)
137
+ text = re.sub(r'\[GD\](.*?)\[GD\]',
138
+ lambda x: cantonese_to_ipa(x.group(1))+' ', text)
139
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
140
+ lambda x: english_to_lazy_ipa2(x.group(1))+' ', text)
141
+ text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group(
142
+ 1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text)
143
+ text = re.sub(r'\s+$', '', text)
144
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
145
+ return text
text/english.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ '''
4
+ Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ 1. "english_cleaners" for English text
9
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ the symbols in symbols.py to match your data).
13
+ '''
14
+
15
+
16
+ # Regular expression matching whitespace:
17
+
18
+
19
+ import re
20
+ import inflect
21
+ from unidecode import unidecode
22
+ import eng_to_ipa as ipa
23
+ _inflect = inflect.engine()
24
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
25
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
26
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
27
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
28
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
29
+ _number_re = re.compile(r'[0-9]+')
30
+
31
+ # List of (regular expression, replacement) pairs for abbreviations:
32
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
33
+ ('mrs', 'misess'),
34
+ ('mr', 'mister'),
35
+ ('dr', 'doctor'),
36
+ ('st', 'saint'),
37
+ ('co', 'company'),
38
+ ('jr', 'junior'),
39
+ ('maj', 'major'),
40
+ ('gen', 'general'),
41
+ ('drs', 'doctors'),
42
+ ('rev', 'reverend'),
43
+ ('lt', 'lieutenant'),
44
+ ('hon', 'honorable'),
45
+ ('sgt', 'sergeant'),
46
+ ('capt', 'captain'),
47
+ ('esq', 'esquire'),
48
+ ('ltd', 'limited'),
49
+ ('col', 'colonel'),
50
+ ('ft', 'fort'),
51
+ ]]
52
+
53
+
54
+ # List of (ipa, lazy ipa) pairs:
55
+ _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
56
+ ('r', 'ɹ'),
57
+ ('æ', 'e'),
58
+ ('ɑ', 'a'),
59
+ ('ɔ', 'o'),
60
+ ('ð', 'z'),
61
+ ('θ', 's'),
62
+ ('ɛ', 'e'),
63
+ ('ɪ', 'i'),
64
+ ('ʊ', 'u'),
65
+ ('ʒ', 'ʥ'),
66
+ ('ʤ', 'ʥ'),
67
+ ('ˈ', '↓'),
68
+ ]]
69
+
70
+ # List of (ipa, lazy ipa2) pairs:
71
+ _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
72
+ ('r', 'ɹ'),
73
+ ('ð', 'z'),
74
+ ('θ', 's'),
75
+ ('ʒ', 'ʑ'),
76
+ ('ʤ', 'dʑ'),
77
+ ('ˈ', '↓'),
78
+ ]]
79
+
80
+ # List of (ipa, ipa2) pairs
81
+ _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
82
+ ('r', 'ɹ'),
83
+ ('ʤ', 'dʒ'),
84
+ ('ʧ', 'tʃ')
85
+ ]]
86
+
87
+
88
+ def expand_abbreviations(text):
89
+ for regex, replacement in _abbreviations:
90
+ text = re.sub(regex, replacement, text)
91
+ return text
92
+
93
+
94
+ def collapse_whitespace(text):
95
+ return re.sub(r'\s+', ' ', text)
96
+
97
+
98
+ def _remove_commas(m):
99
+ return m.group(1).replace(',', '')
100
+
101
+
102
+ def _expand_decimal_point(m):
103
+ return m.group(1).replace('.', ' point ')
104
+
105
+
106
+ def _expand_dollars(m):
107
+ match = m.group(1)
108
+ parts = match.split('.')
109
+ if len(parts) > 2:
110
+ return match + ' dollars' # Unexpected format
111
+ dollars = int(parts[0]) if parts[0] else 0
112
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
113
+ if dollars and cents:
114
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
115
+ cent_unit = 'cent' if cents == 1 else 'cents'
116
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
117
+ elif dollars:
118
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
119
+ return '%s %s' % (dollars, dollar_unit)
120
+ elif cents:
121
+ cent_unit = 'cent' if cents == 1 else 'cents'
122
+ return '%s %s' % (cents, cent_unit)
123
+ else:
124
+ return 'zero dollars'
125
+
126
+
127
+ def _expand_ordinal(m):
128
+ return _inflect.number_to_words(m.group(0))
129
+
130
+
131
+ def _expand_number(m):
132
+ num = int(m.group(0))
133
+ if num > 1000 and num < 3000:
134
+ if num == 2000:
135
+ return 'two thousand'
136
+ elif num > 2000 and num < 2010:
137
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
138
+ elif num % 100 == 0:
139
+ return _inflect.number_to_words(num // 100) + ' hundred'
140
+ else:
141
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
142
+ else:
143
+ return _inflect.number_to_words(num, andword='')
144
+
145
+
146
+ def normalize_numbers(text):
147
+ text = re.sub(_comma_number_re, _remove_commas, text)
148
+ text = re.sub(_pounds_re, r'\1 pounds', text)
149
+ text = re.sub(_dollars_re, _expand_dollars, text)
150
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
151
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
152
+ text = re.sub(_number_re, _expand_number, text)
153
+ return text
154
+
155
+
156
+ def mark_dark_l(text):
157
+ return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
158
+
159
+
160
+ def english_to_ipa(text):
161
+ text = unidecode(text).lower()
162
+ text = expand_abbreviations(text)
163
+ text = normalize_numbers(text)
164
+ phonemes = ipa.convert(text)
165
+ phonemes = collapse_whitespace(phonemes)
166
+ return phonemes
167
+
168
+
169
+ def english_to_lazy_ipa(text):
170
+ text = english_to_ipa(text)
171
+ for regex, replacement in _lazy_ipa:
172
+ text = re.sub(regex, replacement, text)
173
+ return text
174
+
175
+
176
+ def english_to_ipa2(text):
177
+ text = english_to_ipa(text)
178
+ text = mark_dark_l(text)
179
+ for regex, replacement in _ipa_to_ipa2:
180
+ text = re.sub(regex, replacement, text)
181
+ return text.replace('...', '…')
182
+
183
+
184
+ def english_to_lazy_ipa2(text):
185
+ text = english_to_ipa(text)
186
+ for regex, replacement in _lazy_ipa2:
187
+ text = re.sub(regex, replacement, text)
188
+ return text
text/japanese.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from unidecode import unidecode
3
+ import text.pyopenjtalk as pyopenjtalk
4
+
5
+
6
+ # Regular expression matching Japanese without punctuation marks:
7
+ _japanese_characters = re.compile(
8
+ r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
9
+
10
+ # Regular expression matching non-Japanese characters or punctuation marks:
11
+ _japanese_marks = re.compile(
12
+ r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
13
+
14
+ # List of (symbol, Japanese) pairs for marks:
15
+ _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
16
+ ('%', 'パーセント')
17
+ ]]
18
+
19
+ # List of (romaji, ipa) pairs for marks:
20
+ _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
21
+ ('ts', 'ʦ'),
22
+ ('u', 'ɯ'),
23
+ ('j', 'ʥ'),
24
+ ('y', 'j'),
25
+ ('ni', 'n^i'),
26
+ ('nj', 'n^'),
27
+ ('hi', 'çi'),
28
+ ('hj', 'ç'),
29
+ ('f', 'ɸ'),
30
+ ('I', 'i*'),
31
+ ('U', 'ɯ*'),
32
+ ('r', 'ɾ')
33
+ ]]
34
+
35
+ # List of (romaji, ipa2) pairs for marks:
36
+ _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
37
+ ('u', 'ɯ'),
38
+ ('ʧ', 'tʃ'),
39
+ ('j', 'dʑ'),
40
+ ('y', 'j'),
41
+ ('ni', 'n^i'),
42
+ ('nj', 'n^'),
43
+ ('hi', 'çi'),
44
+ ('hj', 'ç'),
45
+ ('f', 'ɸ'),
46
+ ('I', 'i*'),
47
+ ('U', 'ɯ*'),
48
+ ('r', 'ɾ')
49
+ ]]
50
+
51
+ # List of (consonant, sokuon) pairs:
52
+ _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
53
+ (r'Q([↑↓]*[kg])', r'k#\1'),
54
+ (r'Q([↑↓]*[tdjʧ])', r't#\1'),
55
+ (r'Q([↑↓]*[sʃ])', r's\1'),
56
+ (r'Q([↑↓]*[pb])', r'p#\1')
57
+ ]]
58
+
59
+ # List of (consonant, hatsuon) pairs:
60
+ _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
61
+ (r'N([↑↓]*[pbm])', r'm\1'),
62
+ (r'N([↑↓]*[ʧʥj])', r'n^\1'),
63
+ (r'N([↑↓]*[tdn])', r'n\1'),
64
+ (r'N([↑↓]*[kg])', r'ŋ\1')
65
+ ]]
66
+
67
+
68
+ def symbols_to_japanese(text):
69
+ for regex, replacement in _symbols_to_japanese:
70
+ text = re.sub(regex, replacement, text)
71
+ return text
72
+
73
+
74
+ def japanese_to_romaji_with_accent(text):
75
+ '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
76
+ text = symbols_to_japanese(text)
77
+ sentences = re.split(_japanese_marks, text)
78
+ marks = re.findall(_japanese_marks, text)
79
+ text = ''
80
+ for i, sentence in enumerate(sentences):
81
+ if re.match(_japanese_characters, sentence):
82
+ if text != '':
83
+ text += ' '
84
+ labels = pyopenjtalk.extract_fullcontext(sentence)
85
+ for n, label in enumerate(labels):
86
+ phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
87
+ if phoneme not in ['sil', 'pau']:
88
+ text += phoneme.replace('ch', 'ʧ').replace('sh',
89
+ 'ʃ').replace('cl', 'Q')
90
+ else:
91
+ continue
92
+ # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
93
+ a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
94
+ a2 = int(re.search(r"\+(\d+)\+", label).group(1))
95
+ a3 = int(re.search(r"\+(\d+)/", label).group(1))
96
+ if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
97
+ a2_next = -1
98
+ else:
99
+ a2_next = int(
100
+ re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
101
+ # Accent phrase boundary
102
+ if a3 == 1 and a2_next == 1:
103
+ text += ' '
104
+ # Falling
105
+ elif a1 == 0 and a2_next == a2 + 1:
106
+ text += '↓'
107
+ # Rising
108
+ elif a2 == 1 and a2_next == 2:
109
+ text += '↑'
110
+ if i < len(marks):
111
+ text += unidecode(marks[i]).replace(' ', '')
112
+ return text
113
+
114
+
115
+ def get_real_sokuon(text):
116
+ for regex, replacement in _real_sokuon:
117
+ text = re.sub(regex, replacement, text)
118
+ return text
119
+
120
+
121
+ def get_real_hatsuon(text):
122
+ for regex, replacement in _real_hatsuon:
123
+ text = re.sub(regex, replacement, text)
124
+ return text
125
+
126
+
127
+ def japanese_to_ipa(text):
128
+ text = japanese_to_romaji_with_accent(text).replace('...', '…')
129
+ text = re.sub(
130
+ r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
131
+ text = get_real_sokuon(text)
132
+ text = get_real_hatsuon(text)
133
+ for regex, replacement in _romaji_to_ipa:
134
+ text = re.sub(regex, replacement, text)
135
+ return text
136
+
137
+
138
+ def japanese_to_ipa2(text):
139
+ text = japanese_to_romaji_with_accent(text).replace('...', '…')
140
+ text = get_real_sokuon(text)
141
+ text = get_real_hatsuon(text)
142
+ for regex, replacement in _romaji_to_ipa2:
143
+ text = re.sub(regex, replacement, text)
144
+ return text
145
+
146
+
147
+ def japanese_to_ipa3(text):
148
+ text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
149
+ 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
150
+ text = re.sub(
151
+ r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
152
+ text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
153
+ return text
text/korean.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from jamo import h2j, j2hcj
3
+ import ko_pron
4
+
5
+
6
+ # This is a list of Korean classifiers preceded by pure Korean numerals.
7
+ _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
8
+
9
+ # List of (hangul, hangul divided) pairs:
10
+ _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
11
+ ('ㄳ', 'ㄱㅅ'),
12
+ ('ㄵ', 'ㄴㅈ'),
13
+ ('ㄶ', 'ㄴㅎ'),
14
+ ('ㄺ', 'ㄹㄱ'),
15
+ ('ㄻ', 'ㄹㅁ'),
16
+ ('ㄼ', 'ㄹㅂ'),
17
+ ('ㄽ', 'ㄹㅅ'),
18
+ ('ㄾ', 'ㄹㅌ'),
19
+ ('ㄿ', 'ㄹㅍ'),
20
+ ('ㅀ', 'ㄹㅎ'),
21
+ ('ㅄ', 'ㅂㅅ'),
22
+ ('ㅘ', 'ㅗㅏ'),
23
+ ('ㅙ', 'ㅗㅐ'),
24
+ ('ㅚ', 'ㅗㅣ'),
25
+ ('ㅝ', 'ㅜㅓ'),
26
+ ('ㅞ', 'ㅜㅔ'),
27
+ ('ㅟ', 'ㅜㅣ'),
28
+ ('ㅢ', 'ㅡㅣ'),
29
+ ('ㅑ', 'ㅣㅏ'),
30
+ ('ㅒ', 'ㅣㅐ'),
31
+ ('ㅕ', 'ㅣㅓ'),
32
+ ('ㅖ', 'ㅣㅔ'),
33
+ ('ㅛ', 'ㅣㅗ'),
34
+ ('ㅠ', 'ㅣㅜ')
35
+ ]]
36
+
37
+ # List of (Latin alphabet, hangul) pairs:
38
+ _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
39
+ ('a', '에이'),
40
+ ('b', '비'),
41
+ ('c', '시'),
42
+ ('d', '디'),
43
+ ('e', '이'),
44
+ ('f', '에프'),
45
+ ('g', '지'),
46
+ ('h', '에이치'),
47
+ ('i', '아이'),
48
+ ('j', '제이'),
49
+ ('k', '케이'),
50
+ ('l', '엘'),
51
+ ('m', '엠'),
52
+ ('n', '엔'),
53
+ ('o', '오'),
54
+ ('p', '피'),
55
+ ('q', '큐'),
56
+ ('r', '아르'),
57
+ ('s', '에스'),
58
+ ('t', '티'),
59
+ ('u', '유'),
60
+ ('v', '브이'),
61
+ ('w', '더블유'),
62
+ ('x', '엑스'),
63
+ ('y', '와이'),
64
+ ('z', '제트')
65
+ ]]
66
+
67
+ # List of (ipa, lazy ipa) pairs:
68
+ _ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
69
+ ('t͡ɕ','ʧ'),
70
+ ('d͡ʑ','ʥ'),
71
+ ('ɲ','n^'),
72
+ ('ɕ','ʃ'),
73
+ ('ʷ','w'),
74
+ ('ɭ','l`'),
75
+ ('ʎ','ɾ'),
76
+ ('ɣ','ŋ'),
77
+ ('ɰ','ɯ'),
78
+ ('ʝ','j'),
79
+ ('ʌ','ə'),
80
+ ('ɡ','g'),
81
+ ('\u031a','#'),
82
+ ('\u0348','='),
83
+ ('\u031e',''),
84
+ ('\u0320',''),
85
+ ('\u0339','')
86
+ ]]
87
+
88
+
89
+ def latin_to_hangul(text):
90
+ for regex, replacement in _latin_to_hangul:
91
+ text = re.sub(regex, replacement, text)
92
+ return text
93
+
94
+
95
+ def divide_hangul(text):
96
+ text = j2hcj(h2j(text))
97
+ for regex, replacement in _hangul_divided:
98
+ text = re.sub(regex, replacement, text)
99
+ return text
100
+
101
+
102
+ def hangul_number(num, sino=True):
103
+ '''Reference https://github.com/Kyubyong/g2pK'''
104
+ num = re.sub(',', '', num)
105
+
106
+ if num == '0':
107
+ return '영'
108
+ if not sino and num == '20':
109
+ return '스무'
110
+
111
+ digits = '123456789'
112
+ names = '일이삼사오육칠팔구'
113
+ digit2name = {d: n for d, n in zip(digits, names)}
114
+
115
+ modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
116
+ decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
117
+ digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
118
+ digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
119
+
120
+ spelledout = []
121
+ for i, digit in enumerate(num):
122
+ i = len(num) - i - 1
123
+ if sino:
124
+ if i == 0:
125
+ name = digit2name.get(digit, '')
126
+ elif i == 1:
127
+ name = digit2name.get(digit, '') + '십'
128
+ name = name.replace('일십', '십')
129
+ else:
130
+ if i == 0:
131
+ name = digit2mod.get(digit, '')
132
+ elif i == 1:
133
+ name = digit2dec.get(digit, '')
134
+ if digit == '0':
135
+ if i % 4 == 0:
136
+ last_three = spelledout[-min(3, len(spelledout)):]
137
+ if ''.join(last_three) == '':
138
+ spelledout.append('')
139
+ continue
140
+ else:
141
+ spelledout.append('')
142
+ continue
143
+ if i == 2:
144
+ name = digit2name.get(digit, '') + '백'
145
+ name = name.replace('일백', '백')
146
+ elif i == 3:
147
+ name = digit2name.get(digit, '') + '천'
148
+ name = name.replace('일천', '천')
149
+ elif i == 4:
150
+ name = digit2name.get(digit, '') + '만'
151
+ name = name.replace('일만', '만')
152
+ elif i == 5:
153
+ name = digit2name.get(digit, '') + '십'
154
+ name = name.replace('일십', '십')
155
+ elif i == 6:
156
+ name = digit2name.get(digit, '') + '백'
157
+ name = name.replace('일백', '백')
158
+ elif i == 7:
159
+ name = digit2name.get(digit, '') + '천'
160
+ name = name.replace('일천', '천')
161
+ elif i == 8:
162
+ name = digit2name.get(digit, '') + '억'
163
+ elif i == 9:
164
+ name = digit2name.get(digit, '') + '십'
165
+ elif i == 10:
166
+ name = digit2name.get(digit, '') + '백'
167
+ elif i == 11:
168
+ name = digit2name.get(digit, '') + '천'
169
+ elif i == 12:
170
+ name = digit2name.get(digit, '') + '조'
171
+ elif i == 13:
172
+ name = digit2name.get(digit, '') + '십'
173
+ elif i == 14:
174
+ name = digit2name.get(digit, '') + '백'
175
+ elif i == 15:
176
+ name = digit2name.get(digit, '') + '천'
177
+ spelledout.append(name)
178
+ return ''.join(elem for elem in spelledout)
179
+
180
+
181
+ def number_to_hangul(text):
182
+ '''Reference https://github.com/Kyubyong/g2pK'''
183
+ tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
184
+ for token in tokens:
185
+ num, classifier = token
186
+ if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
187
+ spelledout = hangul_number(num, sino=False)
188
+ else:
189
+ spelledout = hangul_number(num, sino=True)
190
+ text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
191
+ # digit by digit for remaining digits
192
+ digits = '0123456789'
193
+ names = '영일이삼사오육칠팔구'
194
+ for d, n in zip(digits, names):
195
+ text = text.replace(d, n)
196
+ return text
197
+
198
+
199
+ def korean_to_lazy_ipa(text):
200
+ text = latin_to_hangul(text)
201
+ text = number_to_hangul(text)
202
+ text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
203
+ for regex, replacement in _ipa_to_lazy_ipa:
204
+ text = re.sub(regex, replacement, text)
205
+ return text
206
+
207
+
208
+ def korean_to_ipa(text):
209
+ text = korean_to_lazy_ipa(text)
210
+ return text.replace('ʧ','tʃ').replace('ʥ','dʑ')
text/mandarin.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ from pypinyin import lazy_pinyin, BOPOMOFO
5
+ import jieba
6
+ import cn2an
7
+ import logging
8
+
9
+ logging.getLogger('jieba').setLevel(logging.WARNING)
10
+ jieba.set_dictionary(r'./jieba/dict.txt')
11
+ jieba.initialize()
12
+
13
+
14
+ # List of (Latin alphabet, bopomofo) pairs:
15
+ _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
16
+ ('a', 'ㄟˉ'),
17
+ ('b', 'ㄅㄧˋ'),
18
+ ('c', 'ㄙㄧˉ'),
19
+ ('d', 'ㄉㄧˋ'),
20
+ ('e', 'ㄧˋ'),
21
+ ('f', 'ㄝˊㄈㄨˋ'),
22
+ ('g', 'ㄐㄧˋ'),
23
+ ('h', 'ㄝˇㄑㄩˋ'),
24
+ ('i', 'ㄞˋ'),
25
+ ('j', 'ㄐㄟˋ'),
26
+ ('k', 'ㄎㄟˋ'),
27
+ ('l', 'ㄝˊㄛˋ'),
28
+ ('m', 'ㄝˊㄇㄨˋ'),
29
+ ('n', 'ㄣˉ'),
30
+ ('o', 'ㄡˉ'),
31
+ ('p', 'ㄆㄧˉ'),
32
+ ('q', 'ㄎㄧㄡˉ'),
33
+ ('r', 'ㄚˋ'),
34
+ ('s', 'ㄝˊㄙˋ'),
35
+ ('t', 'ㄊㄧˋ'),
36
+ ('u', 'ㄧㄡˉ'),
37
+ ('v', 'ㄨㄧˉ'),
38
+ ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
39
+ ('x', 'ㄝˉㄎㄨˋㄙˋ'),
40
+ ('y', 'ㄨㄞˋ'),
41
+ ('z', 'ㄗㄟˋ')
42
+ ]]
43
+
44
+ # List of (bopomofo, romaji) pairs:
45
+ _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
46
+ ('ㄅㄛ', 'p⁼wo'),
47
+ ('ㄆㄛ', 'pʰwo'),
48
+ ('ㄇㄛ', 'mwo'),
49
+ ('ㄈㄛ', 'fwo'),
50
+ ('ㄅ', 'p⁼'),
51
+ ('ㄆ', 'pʰ'),
52
+ ('ㄇ', 'm'),
53
+ ('ㄈ', 'f'),
54
+ ('ㄉ', 't⁼'),
55
+ ('ㄊ', 'tʰ'),
56
+ ('ㄋ', 'n'),
57
+ ('ㄌ', 'l'),
58
+ ('ㄍ', 'k⁼'),
59
+ ('ㄎ', 'kʰ'),
60
+ ('ㄏ', 'h'),
61
+ ('ㄐ', 'ʧ⁼'),
62
+ ('ㄑ', 'ʧʰ'),
63
+ ('ㄒ', 'ʃ'),
64
+ ('ㄓ', 'ʦ`⁼'),
65
+ ('ㄔ', 'ʦ`ʰ'),
66
+ ('ㄕ', 's`'),
67
+ ('ㄖ', 'ɹ`'),
68
+ ('ㄗ', 'ʦ⁼'),
69
+ ('ㄘ', 'ʦʰ'),
70
+ ('ㄙ', 's'),
71
+ ('ㄚ', 'a'),
72
+ ('ㄛ', 'o'),
73
+ ('ㄜ', 'ə'),
74
+ ('ㄝ', 'e'),
75
+ ('ㄞ', 'ai'),
76
+ ('ㄟ', 'ei'),
77
+ ('ㄠ', 'au'),
78
+ ('ㄡ', 'ou'),
79
+ ('ㄧㄢ', 'yeNN'),
80
+ ('ㄢ', 'aNN'),
81
+ ('ㄧㄣ', 'iNN'),
82
+ ('ㄣ', 'əNN'),
83
+ ('ㄤ', 'aNg'),
84
+ ('ㄧㄥ', 'iNg'),
85
+ ('ㄨㄥ', 'uNg'),
86
+ ('ㄩㄥ', 'yuNg'),
87
+ ('ㄥ', 'əNg'),
88
+ ('ㄦ', 'əɻ'),
89
+ ('ㄧ', 'i'),
90
+ ('ㄨ', 'u'),
91
+ ('ㄩ', 'ɥ'),
92
+ ('ˉ', '→'),
93
+ ('ˊ', '↑'),
94
+ ('ˇ', '↓↑'),
95
+ ('ˋ', '↓'),
96
+ ('˙', ''),
97
+ (',', ','),
98
+ ('。', '.'),
99
+ ('!', '!'),
100
+ ('?', '?'),
101
+ ('—', '-')
102
+ ]]
103
+
104
+ # List of (romaji, ipa) pairs:
105
+ _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
106
+ ('ʃy', 'ʃ'),
107
+ ('ʧʰy', 'ʧʰ'),
108
+ ('ʧ⁼y', 'ʧ⁼'),
109
+ ('NN', 'n'),
110
+ ('Ng', 'ŋ'),
111
+ ('y', 'j'),
112
+ ('h', 'x')
113
+ ]]
114
+
115
+ # List of (bopomofo, ipa) pairs:
116
+ _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
117
+ ('ㄅㄛ', 'p⁼wo'),
118
+ ('ㄆㄛ', 'pʰwo'),
119
+ ('ㄇㄛ', 'mwo'),
120
+ ('ㄈㄛ', 'fwo'),
121
+ ('ㄅ', 'p⁼'),
122
+ ('ㄆ', 'pʰ'),
123
+ ('ㄇ', 'm'),
124
+ ('ㄈ', 'f'),
125
+ ('ㄉ', 't⁼'),
126
+ ('ㄊ', 'tʰ'),
127
+ ('ㄋ', 'n'),
128
+ ('ㄌ', 'l'),
129
+ ('ㄍ', 'k⁼'),
130
+ ('ㄎ', 'kʰ'),
131
+ ('ㄏ', 'x'),
132
+ ('ㄐ', 'tʃ⁼'),
133
+ ('ㄑ', 'tʃʰ'),
134
+ ('ㄒ', 'ʃ'),
135
+ ('ㄓ', 'ts`⁼'),
136
+ ('ㄔ', 'ts`ʰ'),
137
+ ('ㄕ', 's`'),
138
+ ('ㄖ', 'ɹ`'),
139
+ ('ㄗ', 'ts⁼'),
140
+ ('ㄘ', 'tsʰ'),
141
+ ('ㄙ', 's'),
142
+ ('ㄚ', 'a'),
143
+ ('ㄛ', 'o'),
144
+ ('ㄜ', 'ə'),
145
+ ('ㄝ', 'ɛ'),
146
+ ('ㄞ', 'aɪ'),
147
+ ('ㄟ', 'eɪ'),
148
+ ('ㄠ', 'ɑʊ'),
149
+ ('ㄡ', 'oʊ'),
150
+ ('ㄧㄢ', 'jɛn'),
151
+ ('ㄩㄢ', 'ɥæn'),
152
+ ('ㄢ', 'an'),
153
+ ('ㄧㄣ', 'in'),
154
+ ('ㄩㄣ', 'ɥn'),
155
+ ('ㄣ', 'ən'),
156
+ ('ㄤ', 'ɑŋ'),
157
+ ('ㄧㄥ', 'iŋ'),
158
+ ('ㄨㄥ', 'ʊŋ'),
159
+ ('ㄩㄥ', 'jʊŋ'),
160
+ ('ㄥ', 'əŋ'),
161
+ ('ㄦ', 'əɻ'),
162
+ ('ㄧ', 'i'),
163
+ ('ㄨ', 'u'),
164
+ ('ㄩ', 'ɥ'),
165
+ ('ˉ', '→'),
166
+ ('ˊ', '↑'),
167
+ ('ˇ', '↓↑'),
168
+ ('ˋ', '↓'),
169
+ ('˙', ''),
170
+ (',', ','),
171
+ ('。', '.'),
172
+ ('!', '!'),
173
+ ('?', '?'),
174
+ ('—', '-')
175
+ ]]
176
+
177
+ # List of (bopomofo, ipa2) pairs:
178
+ _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
179
+ ('ㄅㄛ', 'pwo'),
180
+ ('ㄆㄛ', 'pʰwo'),
181
+ ('ㄇㄛ', 'mwo'),
182
+ ('ㄈㄛ', 'fwo'),
183
+ ('ㄅ', 'p'),
184
+ ('ㄆ', 'pʰ'),
185
+ ('ㄇ', 'm'),
186
+ ('ㄈ', 'f'),
187
+ ('ㄉ', 't'),
188
+ ('ㄊ', 'tʰ'),
189
+ ('ㄋ', 'n'),
190
+ ('ㄌ', 'l'),
191
+ ('ㄍ', 'k'),
192
+ ('ㄎ', 'kʰ'),
193
+ ('ㄏ', 'h'),
194
+ ('ㄐ', 'tɕ'),
195
+ ('ㄑ', 'tɕʰ'),
196
+ ('ㄒ', 'ɕ'),
197
+ ('ㄓ', 'tʂ'),
198
+ ('ㄔ', 'tʂʰ'),
199
+ ('ㄕ', 'ʂ'),
200
+ ('ㄖ', 'ɻ'),
201
+ ('ㄗ', 'ts'),
202
+ ('ㄘ', 'tsʰ'),
203
+ ('ㄙ', 's'),
204
+ ('ㄚ', 'a'),
205
+ ('ㄛ', 'o'),
206
+ ('ㄜ', 'ɤ'),
207
+ ('ㄝ', 'ɛ'),
208
+ ('ㄞ', 'aɪ'),
209
+ ('ㄟ', 'eɪ'),
210
+ ('ㄠ', 'ɑʊ'),
211
+ ('ㄡ', 'oʊ'),
212
+ ('ㄧㄢ', 'jɛn'),
213
+ ('ㄩㄢ', 'yæn'),
214
+ ('ㄢ', 'an'),
215
+ ('ㄧㄣ', 'in'),
216
+ ('ㄩㄣ', 'yn'),
217
+ ('ㄣ', 'ən'),
218
+ ('ㄤ', 'ɑŋ'),
219
+ ('ㄧㄥ', 'iŋ'),
220
+ ('ㄨㄥ', 'ʊŋ'),
221
+ ('ㄩㄥ', 'jʊŋ'),
222
+ ('ㄥ', 'ɤŋ'),
223
+ ('ㄦ', 'əɻ'),
224
+ ('ㄧ', 'i'),
225
+ ('ㄨ', 'u'),
226
+ ('ㄩ', 'y'),
227
+ ('ˉ', '˥'),
228
+ ('ˊ', '˧˥'),
229
+ ('ˇ', '˨˩˦'),
230
+ ('ˋ', '˥˩'),
231
+ ('˙', ''),
232
+ (',', ','),
233
+ ('。', '.'),
234
+ ('!', '!'),
235
+ ('?', '?'),
236
+ ('—', '-')
237
+ ]]
238
+
239
+
240
+ def number_to_chinese(text):
241
+ numbers = re.findall(r'\d+(?:\.?\d+)?', text)
242
+ for number in numbers:
243
+ text = text.replace(number, cn2an.an2cn(number), 1)
244
+ return text
245
+
246
+
247
+ def chinese_to_bopomofo(text):
248
+ text = text.replace('、', ',').replace(';', ',').replace(':', ',')
249
+ words = jieba.lcut(text, cut_all=False)
250
+ text = ''
251
+ for word in words:
252
+ bopomofos = lazy_pinyin(word, BOPOMOFO)
253
+ if not re.search('[\u4e00-\u9fff]', word):
254
+ text += word
255
+ continue
256
+ for i in range(len(bopomofos)):
257
+ bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
258
+ if text != '':
259
+ text += ' '
260
+ text += ''.join(bopomofos)
261
+ return text
262
+
263
+
264
+ def latin_to_bopomofo(text):
265
+ for regex, replacement in _latin_to_bopomofo:
266
+ text = re.sub(regex, replacement, text)
267
+ return text
268
+
269
+
270
+ def bopomofo_to_romaji(text):
271
+ for regex, replacement in _bopomofo_to_romaji:
272
+ text = re.sub(regex, replacement, text)
273
+ return text
274
+
275
+
276
+ def bopomofo_to_ipa(text):
277
+ for regex, replacement in _bopomofo_to_ipa:
278
+ text = re.sub(regex, replacement, text)
279
+ return text
280
+
281
+
282
+ def bopomofo_to_ipa2(text):
283
+ for regex, replacement in _bopomofo_to_ipa2:
284
+ text = re.sub(regex, replacement, text)
285
+ return text
286
+
287
+
288
+ def chinese_to_romaji(text):
289
+ text = number_to_chinese(text)
290
+ text = chinese_to_bopomofo(text)
291
+ text = latin_to_bopomofo(text)
292
+ text = bopomofo_to_romaji(text)
293
+ text = re.sub('i([aoe])', r'y\1', text)
294
+ text = re.sub('u([aoəe])', r'w\1', text)
295
+ text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
296
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
297
+ text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
298
+ return text
299
+
300
+
301
+ def chinese_to_lazy_ipa(text):
302
+ text = chinese_to_romaji(text)
303
+ for regex, replacement in _romaji_to_ipa:
304
+ text = re.sub(regex, replacement, text)
305
+ return text
306
+
307
+
308
+ def chinese_to_ipa(text):
309
+ text = number_to_chinese(text)
310
+ text = chinese_to_bopomofo(text)
311
+ text = latin_to_bopomofo(text)
312
+ text = bopomofo_to_ipa(text)
313
+ text = re.sub('i([aoe])', r'j\1', text)
314
+ text = re.sub('u([aoəe])', r'w\1', text)
315
+ text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
316
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
317
+ text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
318
+ return text
319
+
320
+
321
+ def chinese_to_ipa2(text):
322
+ text = number_to_chinese(text)
323
+ text = chinese_to_bopomofo(text)
324
+ text = latin_to_bopomofo(text)
325
+ text = bopomofo_to_ipa2(text)
326
+ text = re.sub(r'i([aoe])', r'j\1', text)
327
+ text = re.sub(r'u([aoəe])', r'w\1', text)
328
+ text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
329
+ text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
330
+ return text
text/ngu_dialect.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import opencc
3
+
4
+
5
+ dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou',
6
+ 'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing',
7
+ 'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang',
8
+ 'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan',
9
+ 'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen',
10
+ 'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'}
11
+
12
+ converters = {}
13
+
14
+ for dialect in dialects.values():
15
+ try:
16
+ converters[dialect] = opencc.OpenCC(dialect)
17
+ except:
18
+ pass
19
+
20
+
21
+ def ngu_dialect_to_ipa(text, dialect):
22
+ dialect = dialects[dialect]
23
+ text = converters[dialect].convert(text).replace('-','').replace('$',' ')
24
+ text = re.sub(r'[、;:]', ',', text)
25
+ text = re.sub(r'\s*,\s*', ', ', text)
26
+ text = re.sub(r'\s*。\s*', '. ', text)
27
+ text = re.sub(r'\s*?\s*', '? ', text)
28
+ text = re.sub(r'\s*!\s*', '! ', text)
29
+ text = re.sub(r'\s*$', '', text)
30
+ return text
text/sanskrit.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from indic_transliteration import sanscript
3
+
4
+
5
+ # List of (iast, ipa) pairs:
6
+ _iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
7
+ ('a', 'ə'),
8
+ ('ā', 'aː'),
9
+ ('ī', 'iː'),
10
+ ('ū', 'uː'),
11
+ ('ṛ', 'ɹ`'),
12
+ ('ṝ', 'ɹ`ː'),
13
+ ('ḷ', 'l`'),
14
+ ('ḹ', 'l`ː'),
15
+ ('e', 'eː'),
16
+ ('o', 'oː'),
17
+ ('k', 'k⁼'),
18
+ ('k⁼h', 'kʰ'),
19
+ ('g', 'g⁼'),
20
+ ('g⁼h', 'gʰ'),
21
+ ('ṅ', 'ŋ'),
22
+ ('c', 'ʧ⁼'),
23
+ ('ʧ⁼h', 'ʧʰ'),
24
+ ('j', 'ʥ⁼'),
25
+ ('ʥ⁼h', 'ʥʰ'),
26
+ ('ñ', 'n^'),
27
+ ('ṭ', 't`⁼'),
28
+ ('t`⁼h', 't`ʰ'),
29
+ ('ḍ', 'd`⁼'),
30
+ ('d`⁼h', 'd`ʰ'),
31
+ ('ṇ', 'n`'),
32
+ ('t', 't⁼'),
33
+ ('t⁼h', 'tʰ'),
34
+ ('d', 'd⁼'),
35
+ ('d⁼h', 'dʰ'),
36
+ ('p', 'p⁼'),
37
+ ('p⁼h', 'pʰ'),
38
+ ('b', 'b⁼'),
39
+ ('b⁼h', 'bʰ'),
40
+ ('y', 'j'),
41
+ ('ś', 'ʃ'),
42
+ ('ṣ', 's`'),
43
+ ('r', 'ɾ'),
44
+ ('l̤', 'l`'),
45
+ ('h', 'ɦ'),
46
+ ("'", ''),
47
+ ('~', '^'),
48
+ ('ṃ', '^')
49
+ ]]
50
+
51
+
52
+ def devanagari_to_ipa(text):
53
+ text = text.replace('ॐ', 'ओम्')
54
+ text = re.sub(r'\s*।\s*$', '.', text)
55
+ text = re.sub(r'\s*।\s*', ', ', text)
56
+ text = re.sub(r'\s*॥', '.', text)
57
+ text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST)
58
+ for regex, replacement in _iast_to_ipa:
59
+ text = re.sub(regex, replacement, text)
60
+ text = re.sub('(.)[`ː]*ḥ', lambda x: x.group(0)
61
+ [:-1]+'h'+x.group(1)+'*', text)
62
+ return text
text/shanghainese.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import cn2an
3
+ import opencc
4
+
5
+
6
+ converter = opencc.OpenCC('zaonhe')
7
+
8
+ # List of (Latin alphabet, ipa) pairs:
9
+ _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
10
+ ('A', 'ᴇ'),
11
+ ('B', 'bi'),
12
+ ('C', 'si'),
13
+ ('D', 'di'),
14
+ ('E', 'i'),
15
+ ('F', 'ᴇf'),
16
+ ('G', 'dʑi'),
17
+ ('H', 'ᴇtɕʰ'),
18
+ ('I', 'ᴀi'),
19
+ ('J', 'dʑᴇ'),
20
+ ('K', 'kʰᴇ'),
21
+ ('L', 'ᴇl'),
22
+ ('M', 'ᴇm'),
23
+ ('N', 'ᴇn'),
24
+ ('O', 'o'),
25
+ ('P', 'pʰi'),
26
+ ('Q', 'kʰiu'),
27
+ ('R', 'ᴀl'),
28
+ ('S', 'ᴇs'),
29
+ ('T', 'tʰi'),
30
+ ('U', 'ɦiu'),
31
+ ('V', 'vi'),
32
+ ('W', 'dᴀbɤliu'),
33
+ ('X', 'ᴇks'),
34
+ ('Y', 'uᴀi'),
35
+ ('Z', 'zᴇ')
36
+ ]]
37
+
38
+
39
+ def _number_to_shanghainese(num):
40
+ num = cn2an.an2cn(num).replace('一十','十').replace('二十', '廿').replace('二', '两')
41
+ return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num)
42
+
43
+
44
+ def number_to_shanghainese(text):
45
+ return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text)
46
+
47
+
48
+ def latin_to_ipa(text):
49
+ for regex, replacement in _latin_to_ipa:
50
+ text = re.sub(regex, replacement, text)
51
+ return text
52
+
53
+
54
+ def shanghainese_to_ipa(text):
55
+ text = number_to_shanghainese(text.upper())
56
+ text = converter.convert(text).replace('-','').replace('$',' ')
57
+ text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
58
+ text = re.sub(r'[、;:]', ',', text)
59
+ text = re.sub(r'\s*,\s*', ', ', text)
60
+ text = re.sub(r'\s*。\s*', '. ', text)
61
+ text = re.sub(r'\s*?\s*', '? ', text)
62
+ text = re.sub(r'\s*!\s*', '! ', text)
63
+ text = re.sub(r'\s*$', '', text)
64
+ return text
text/thai.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from num_thai.thainumbers import NumThai
3
+
4
+
5
+ num = NumThai()
6
+
7
+ # List of (Latin alphabet, Thai) pairs:
8
+ _latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
9
+ ('a', 'เอ'),
10
+ ('b','บี'),
11
+ ('c','ซี'),
12
+ ('d','ดี'),
13
+ ('e','อี'),
14
+ ('f','เอฟ'),
15
+ ('g','จี'),
16
+ ('h','เอช'),
17
+ ('i','ไอ'),
18
+ ('j','เจ'),
19
+ ('k','เค'),
20
+ ('l','แอล'),
21
+ ('m','เอ็ม'),
22
+ ('n','เอ็น'),
23
+ ('o','โอ'),
24
+ ('p','พี'),
25
+ ('q','คิว'),
26
+ ('r','แอร์'),
27
+ ('s','เอส'),
28
+ ('t','ที'),
29
+ ('u','ยู'),
30
+ ('v','วี'),
31
+ ('w','ดับเบิลยู'),
32
+ ('x','เอ็กซ์'),
33
+ ('y','วาย'),
34
+ ('z','ซี')
35
+ ]]
36
+
37
+
38
+ def num_to_thai(text):
39
+ return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', lambda x: ''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text)
40
+
41
+ def latin_to_thai(text):
42
+ for regex, replacement in _latin_to_thai:
43
+ text = re.sub(regex, replacement, text)
44
+ return text
transforms.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(inputs,
13
+ unnormalized_widths,
14
+ unnormalized_heights,
15
+ unnormalized_derivatives,
16
+ inverse=False,
17
+ tails=None,
18
+ tail_bound=1.,
19
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
22
+
23
+ if tails is None:
24
+ spline_fn = rational_quadratic_spline
25
+ spline_kwargs = {}
26
+ else:
27
+ spline_fn = unconstrained_rational_quadratic_spline
28
+ spline_kwargs = {
29
+ 'tails': tails,
30
+ 'tail_bound': tail_bound
31
+ }
32
+
33
+ outputs, logabsdet = spline_fn(
34
+ inputs=inputs,
35
+ unnormalized_widths=unnormalized_widths,
36
+ unnormalized_heights=unnormalized_heights,
37
+ unnormalized_derivatives=unnormalized_derivatives,
38
+ inverse=inverse,
39
+ min_bin_width=min_bin_width,
40
+ min_bin_height=min_bin_height,
41
+ min_derivative=min_derivative,
42
+ **spline_kwargs
43
+ )
44
+ return outputs, logabsdet
45
+
46
+
47
+ def searchsorted(bin_locations, inputs, eps=1e-6):
48
+ bin_locations[..., -1] += eps
49
+ return torch.sum(
50
+ inputs[..., None] >= bin_locations,
51
+ dim=-1
52
+ ) - 1
53
+
54
+
55
+ def unconstrained_rational_quadratic_spline(inputs,
56
+ unnormalized_widths,
57
+ unnormalized_heights,
58
+ unnormalized_derivatives,
59
+ inverse=False,
60
+ tails='linear',
61
+ tail_bound=1.,
62
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
65
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66
+ outside_interval_mask = ~inside_interval_mask
67
+
68
+ outputs = torch.zeros_like(inputs)
69
+ logabsdet = torch.zeros_like(inputs)
70
+
71
+ if tails == 'linear':
72
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73
+ constant = np.log(np.exp(1 - min_derivative) - 1)
74
+ unnormalized_derivatives[..., 0] = constant
75
+ unnormalized_derivatives[..., -1] = constant
76
+
77
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
78
+ logabsdet[outside_interval_mask] = 0
79
+ else:
80
+ raise RuntimeError('{} tails are not implemented.'.format(tails))
81
+
82
+ outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89
+ min_bin_width=min_bin_width,
90
+ min_bin_height=min_bin_height,
91
+ min_derivative=min_derivative
92
+ )
93
+
94
+ return outputs, logabsdet
95
+
96
+ def rational_quadratic_spline(inputs,
97
+ unnormalized_widths,
98
+ unnormalized_heights,
99
+ unnormalized_derivatives,
100
+ inverse=False,
101
+ left=0., right=1., bottom=0., top=1.,
102
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
105
+ if torch.min(inputs) < left or torch.max(inputs) > right:
106
+ raise ValueError('Input to a transform is not within its domain')
107
+
108
+ num_bins = unnormalized_widths.shape[-1]
109
+
110
+ if min_bin_width * num_bins > 1.0:
111
+ raise ValueError('Minimal bin width too large for the number of bins')
112
+ if min_bin_height * num_bins > 1.0:
113
+ raise ValueError('Minimal bin height too large for the number of bins')
114
+
115
+ widths = F.softmax(unnormalized_widths, dim=-1)
116
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
117
+ cumwidths = torch.cumsum(widths, dim=-1)
118
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119
+ cumwidths = (right - left) * cumwidths + left
120
+ cumwidths[..., 0] = left
121
+ cumwidths[..., -1] = right
122
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123
+
124
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
125
+
126
+ heights = F.softmax(unnormalized_heights, dim=-1)
127
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
128
+ cumheights = torch.cumsum(heights, dim=-1)
129
+ cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130
+ cumheights = (top - bottom) * cumheights + bottom
131
+ cumheights[..., 0] = bottom
132
+ cumheights[..., -1] = top
133
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
134
+
135
+ if inverse:
136
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
137
+ else:
138
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
139
+
140
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
141
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
142
+
143
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
144
+ delta = heights / widths
145
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
146
+
147
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
149
+
150
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
151
+
152
+ if inverse:
153
+ a = (((inputs - input_cumheights) * (input_derivatives
154
+ + input_derivatives_plus_one
155
+ - 2 * input_delta)
156
+ + input_heights * (input_delta - input_derivatives)))
157
+ b = (input_heights * input_derivatives
158
+ - (inputs - input_cumheights) * (input_derivatives
159
+ + input_derivatives_plus_one
160
+ - 2 * input_delta))
161
+ c = - input_delta * (inputs - input_cumheights)
162
+
163
+ discriminant = b.pow(2) - 4 * a * c
164
+ assert (discriminant >= 0).all()
165
+
166
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
167
+ outputs = root * input_bin_widths + input_cumwidths
168
+
169
+ theta_one_minus_theta = root * (1 - root)
170
+ denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171
+ * theta_one_minus_theta)
172
+ derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173
+ + 2 * input_delta * theta_one_minus_theta
174
+ + input_derivatives * (1 - root).pow(2))
175
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176
+
177
+ return outputs, -logabsdet
178
+ else:
179
+ theta = (inputs - input_cumwidths) / input_bin_widths
180
+ theta_one_minus_theta = theta * (1 - theta)
181
+
182
+ numerator = input_heights * (input_delta * theta.pow(2)
183
+ + input_derivatives * theta_one_minus_theta)
184
+ denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185
+ * theta_one_minus_theta)
186
+ outputs = input_cumheights + numerator / denominator
187
+
188
+ derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189
+ + 2 * input_delta * theta_one_minus_theta
190
+ + input_derivatives * (1 - theta).pow(2))
191
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
192
+
193
+ return outputs, logabsdet
utils_vits.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from json import loads
3
+ from torch import load, FloatTensor
4
+ from numpy import float32
5
+ import librosa
6
+
7
+
8
+ class HParams():
9
+ def __init__(self, **kwargs):
10
+ for k, v in kwargs.items():
11
+ if type(v) == dict:
12
+ v = HParams(**v)
13
+ self[k] = v
14
+
15
+ def keys(self):
16
+ return self.__dict__.keys()
17
+
18
+ def items(self):
19
+ return self.__dict__.items()
20
+
21
+ def values(self):
22
+ return self.__dict__.values()
23
+
24
+ def __len__(self):
25
+ return len(self.__dict__)
26
+
27
+ def __getitem__(self, key):
28
+ return getattr(self, key)
29
+
30
+ def __setitem__(self, key, value):
31
+ return setattr(self, key, value)
32
+
33
+ def __contains__(self, key):
34
+ return key in self.__dict__
35
+
36
+ def __repr__(self):
37
+ return self.__dict__.__repr__()
38
+
39
+
40
+ def load_checkpoint(checkpoint_path, model):
41
+ checkpoint_dict = load(checkpoint_path, map_location='cpu')
42
+ iteration = checkpoint_dict['iteration']
43
+ saved_state_dict = checkpoint_dict['model']
44
+ if hasattr(model, 'module'):
45
+ state_dict = model.module.state_dict()
46
+ else:
47
+ state_dict = model.state_dict()
48
+ new_state_dict= {}
49
+ for k, v in state_dict.items():
50
+ try:
51
+ new_state_dict[k] = saved_state_dict[k]
52
+ except:
53
+ logging.info("%s is not in the checkpoint" % k)
54
+ new_state_dict[k] = v
55
+ if hasattr(model, 'module'):
56
+ model.module.load_state_dict(new_state_dict)
57
+ else:
58
+ model.load_state_dict(new_state_dict)
59
+ logging.info("Loaded checkpoint '{}' (iteration {})" .format(
60
+ checkpoint_path, iteration))
61
+ return
62
+
63
+
64
+ def get_hparams_from_file(config_path):
65
+ with open(config_path, "r") as f:
66
+ data = f.read()
67
+ config = loads(data)
68
+
69
+ hparams = HParams(**config)
70
+ return hparams
71
+
72
+
73
+ def load_audio_to_torch(full_path, target_sampling_rate):
74
+ audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
75
+ return FloatTensor(audio.astype(float32))
visual_chatgpt.py ADDED
@@ -0,0 +1,908 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import random
4
+ import torch
5
+ import cv2
6
+ import re
7
+ import uuid
8
+ from PIL import Image
9
+ import numpy as np
10
+ import argparse
11
+
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
13
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
14
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
15
+
16
+ from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline
17
+ from diffusers import EulerAncestralDiscreteScheduler
18
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
19
+ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
20
+
21
+ from langchain.agents.initialize import initialize_agent
22
+ from langchain.agents.tools import Tool
23
+ from langchain.chains.conversation.memory import ConversationBufferMemory
24
+ from langchain.llms.openai import OpenAI
25
+
26
+ VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
27
+
28
+ Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated.
29
+
30
+ Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
31
+
32
+ Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
33
+
34
+
35
+ TOOLS:
36
+ ------
37
+
38
+ Visual ChatGPT has access to the following tools:"""
39
+
40
+ VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
41
+
42
+ ```
43
+ Thought: Do I need to use a tool? Yes
44
+ Action: the action to take, should be one of [{tool_names}]
45
+ Action Input: the input to the action
46
+ Observation: the result of the action
47
+ ```
48
+
49
+ When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
50
+
51
+ ```
52
+ Thought: Do I need to use a tool? No
53
+ {ai_prefix}: [your response here]
54
+ ```
55
+ """
56
+
57
+ VISUAL_CHATGPT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if it does not exist.
58
+ You will remember to provide the image file name loyally if it's provided in the last tool observation.
59
+
60
+ Begin!
61
+
62
+ Previous conversation history:
63
+ {chat_history}
64
+
65
+ New input: {input}
66
+ Since Visual ChatGPT is a text language model, Visual ChatGPT must use tools to observe images rather than imagination.
67
+ The thoughts and observations are only visible for Visual ChatGPT, Visual ChatGPT should remember to repeat important information in the final response for Human.
68
+ Thought: Do I need to use a tool? {agent_scratchpad}"""
69
+
70
+ os.makedirs('image', exist_ok=True)
71
+
72
+
73
+ def seed_everything(seed):
74
+ random.seed(seed)
75
+ np.random.seed(seed)
76
+ torch.manual_seed(seed)
77
+ torch.cuda.manual_seed_all(seed)
78
+ return seed
79
+
80
+
81
+ def prompts(name, description):
82
+ def decorator(func):
83
+ func.name = name
84
+ func.description = description
85
+ return func
86
+
87
+ return decorator
88
+
89
+
90
+ def cut_dialogue_history(history_memory, keep_last_n_words=500):
91
+ tokens = history_memory.split()
92
+ n_tokens = len(tokens)
93
+ print(f"hitory_memory:{history_memory}, n_tokens: {n_tokens}")
94
+ if n_tokens < keep_last_n_words:
95
+ return history_memory
96
+ else:
97
+ paragraphs = history_memory.split('\n')
98
+ last_n_tokens = n_tokens
99
+ while last_n_tokens >= keep_last_n_words:
100
+ last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
101
+ paragraphs = paragraphs[1:]
102
+ return '\n' + '\n'.join(paragraphs)
103
+
104
+
105
+ def get_new_image_name(org_img_name, func_name="update"):
106
+ head_tail = os.path.split(org_img_name)
107
+ head = head_tail[0]
108
+ tail = head_tail[1]
109
+ name_split = tail.split('.')[0].split('_')
110
+ this_new_uuid = str(uuid.uuid4())[0:4]
111
+ if len(name_split) == 1:
112
+ most_org_file_name = name_split[0]
113
+ recent_prev_file_name = name_split[0]
114
+ new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
115
+ else:
116
+ assert len(name_split) == 4
117
+ most_org_file_name = name_split[3]
118
+ recent_prev_file_name = name_split[0]
119
+ new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
120
+ return os.path.join(head, new_file_name)
121
+
122
+
123
+ class MaskFormer:
124
+ def __init__(self, device):
125
+ print("Initializing MaskFormer to %s" % device)
126
+ self.device = device
127
+ self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
128
+ self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
129
+
130
+ def inference(self, image_path, text):
131
+ threshold = 0.5
132
+ min_area = 0.02
133
+ padding = 20
134
+ original_image = Image.open(image_path)
135
+ image = original_image.resize((512, 512))
136
+ inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
137
+ with torch.no_grad():
138
+ outputs = self.model(**inputs)
139
+ mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
140
+ area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
141
+ if area_ratio < min_area:
142
+ return None
143
+ true_indices = np.argwhere(mask)
144
+ mask_array = np.zeros_like(mask, dtype=bool)
145
+ for idx in true_indices:
146
+ padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
147
+ mask_array[padded_slice] = True
148
+ visual_mask = (mask_array * 255).astype(np.uint8)
149
+ image_mask = Image.fromarray(visual_mask)
150
+ return image_mask.resize(original_image.size)
151
+
152
+
153
+ class ImageEditing:
154
+ def __init__(self, device):
155
+ print("Initializing ImageEditing to %s" % device)
156
+ self.device = device
157
+ self.mask_former = MaskFormer(device=self.device)
158
+ self.revision = 'fp16' if 'cuda' in device else None
159
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
160
+ self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
161
+ "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
162
+
163
+ @prompts(name="Remove Something From The Photo",
164
+ description="useful when you want to remove and object or something from the photo "
165
+ "from its description or location. "
166
+ "The input to this tool should be a comma seperated string of two, "
167
+ "representing the image_path and the object need to be removed. ")
168
+ def inference_remove(self, inputs):
169
+ image_path, to_be_removed_txt = inputs.split(",")
170
+ return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
171
+
172
+ @prompts(name="Replace Something From The Photo",
173
+ description="useful when you want to replace an object from the object description or "
174
+ "location with another object from its description. "
175
+ "The input to this tool should be a comma seperated string of three, "
176
+ "representing the image_path, the object to be replaced, the object to be replaced with ")
177
+ def inference_replace(self, inputs):
178
+ image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
179
+ original_image = Image.open(image_path)
180
+ original_size = original_image.size
181
+ mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
182
+ updated_image = self.inpaint(prompt=replace_with_txt, image=original_image.resize((512, 512)),
183
+ mask_image=mask_image.resize((512, 512))).images[0]
184
+ updated_image_path = get_new_image_name(image_path, func_name="replace-something")
185
+ updated_image = updated_image.resize(original_size)
186
+ updated_image.save(updated_image_path)
187
+ print(
188
+ f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
189
+ f"Output Image: {updated_image_path}")
190
+ return updated_image_path
191
+
192
+
193
+ class InstructPix2Pix:
194
+ def __init__(self, device):
195
+ print("Initializing InstructPix2Pix to %s" % device)
196
+ self.device = device
197
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
198
+ self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
199
+ safety_checker=None,
200
+ torch_dtype=self.torch_dtype).to(device)
201
+ self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
202
+
203
+ @prompts(name="Instruct Image Using Text",
204
+ description="useful when you want to the style of the image to be like the text. "
205
+ "like: make it look like a painting. or make it like a robot. "
206
+ "The input to this tool should be a comma seperated string of two, "
207
+ "representing the image_path and the text. ")
208
+ def inference(self, inputs):
209
+ """Change style of image."""
210
+ print("===>Starting InstructPix2Pix Inference")
211
+ image_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
212
+ original_image = Image.open(image_path)
213
+ image = self.pipe(text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2).images[0]
214
+ updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
215
+ image.save(updated_image_path)
216
+ print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
217
+ f"Output Image: {updated_image_path}")
218
+ return updated_image_path
219
+
220
+
221
+ class Text2Image:
222
+ def __init__(self, device):
223
+ print("Initializing Text2Image to %s" % device)
224
+ self.device = device
225
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
226
+ self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
227
+ torch_dtype=self.torch_dtype)
228
+ self.pipe.to(device)
229
+ self.a_prompt = 'best quality, extremely detailed'
230
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
231
+ 'fewer digits, cropped, worst quality, low quality'
232
+
233
+ @prompts(name="Generate Image From User Input Text",
234
+ description="useful when you want to generate an image from a user input text and save it to a file. "
235
+ "like: generate an image of an object or something, or generate an image that includes some objects. "
236
+ "The input to this tool should be a string, representing the text used to generate image. ")
237
+ def inference(self, text):
238
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
239
+ prompt = text + ', ' + self.a_prompt
240
+ image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
241
+ image.save(image_filename)
242
+ print(
243
+ f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
244
+ return image_filename
245
+
246
+
247
+ class ImageCaptioning:
248
+ def __init__(self, device):
249
+ print("Initializing ImageCaptioning to %s" % device)
250
+ self.device = device
251
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
252
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
253
+ self.model = BlipForConditionalGeneration.from_pretrained(
254
+ "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
255
+
256
+ @prompts(name="Get Photo Description",
257
+ description="useful when you want to know what is inside the photo. receives image_path as input. "
258
+ "The input to this tool should be a string, representing the image_path. ")
259
+ def inference(self, image_path):
260
+ inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device, self.torch_dtype)
261
+ out = self.model.generate(**inputs)
262
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
263
+ print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
264
+ return captions
265
+
266
+
267
+ class Image2Canny:
268
+ def __init__(self, device):
269
+ print("Initializing Image2Canny")
270
+ self.low_threshold = 100
271
+ self.high_threshold = 200
272
+
273
+ @prompts(name="Edge Detection On Image",
274
+ description="useful when you want to detect the edge of the image. "
275
+ "like: detect the edges of this image, or canny detection on image, "
276
+ "or perform edge detection on this image, or detect the canny image of this image. "
277
+ "The input to this tool should be a string, representing the image_path")
278
+ def inference(self, inputs):
279
+ image = Image.open(inputs)
280
+ image = np.array(image)
281
+ canny = cv2.Canny(image, self.low_threshold, self.high_threshold)
282
+ canny = canny[:, :, None]
283
+ canny = np.concatenate([canny, canny, canny], axis=2)
284
+ canny = Image.fromarray(canny)
285
+ updated_image_path = get_new_image_name(inputs, func_name="edge")
286
+ canny.save(updated_image_path)
287
+ print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text: {updated_image_path}")
288
+ return updated_image_path
289
+
290
+
291
+ class CannyText2Image:
292
+ def __init__(self, device):
293
+ print("Initializing CannyText2Image to %s" % device)
294
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
295
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-canny",
296
+ torch_dtype=self.torch_dtype)
297
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
298
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
299
+ torch_dtype=self.torch_dtype)
300
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
301
+ self.pipe.to(device)
302
+ self.seed = -1
303
+ self.a_prompt = 'best quality, extremely detailed'
304
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
305
+ 'fewer digits, cropped, worst quality, low quality'
306
+
307
+ @prompts(name="Generate Image Condition On Canny Image",
308
+ description="useful when you want to generate a new real image from both the user desciption and a canny image."
309
+ " like: generate a real image of a object or something from this canny image,"
310
+ " or generate a new real image of a object or something from this edge image. "
311
+ "The input to this tool should be a comma seperated string of two, "
312
+ "representing the image_path and the user description. ")
313
+ def inference(self, inputs):
314
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
315
+ image = Image.open(image_path)
316
+ self.seed = random.randint(0, 65535)
317
+ seed_everything(self.seed)
318
+ prompt = instruct_text + ', ' + self.a_prompt
319
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
320
+ guidance_scale=9.0).images[0]
321
+ updated_image_path = get_new_image_name(image_path, func_name="canny2image")
322
+ image.save(updated_image_path)
323
+ print(f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text: {instruct_text}, "
324
+ f"Output Text: {updated_image_path}")
325
+ return updated_image_path
326
+
327
+
328
+ class Image2Line:
329
+ def __init__(self, device):
330
+ print("Initializing Image2Line")
331
+ self.detector = MLSDdetector.from_pretrained('lllyasviel/ControlNet')
332
+
333
+ @prompts(name="Line Detection On Image",
334
+ description="useful when you want to detect the straight line of the image. "
335
+ "like: detect the straight lines of this image, or straight line detection on image, "
336
+ "or peform straight line detection on this image, or detect the straight line image of this image. "
337
+ "The input to this tool should be a string, representing the image_path")
338
+ def inference(self, inputs):
339
+ image = Image.open(inputs)
340
+ mlsd = self.detector(image)
341
+ updated_image_path = get_new_image_name(inputs, func_name="line-of")
342
+ mlsd.save(updated_image_path)
343
+ print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line: {updated_image_path}")
344
+ return updated_image_path
345
+
346
+
347
+ class LineText2Image:
348
+ def __init__(self, device):
349
+ print("Initializing LineText2Image to %s" % device)
350
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
351
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-mlsd",
352
+ torch_dtype=self.torch_dtype)
353
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
354
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
355
+ torch_dtype=self.torch_dtype
356
+ )
357
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
358
+ self.pipe.to(device)
359
+ self.seed = -1
360
+ self.a_prompt = 'best quality, extremely detailed'
361
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
362
+ 'fewer digits, cropped, worst quality, low quality'
363
+
364
+ @prompts(name="Generate Image Condition On Line Image",
365
+ description="useful when you want to generate a new real image from both the user desciption "
366
+ "and a straight line image. "
367
+ "like: generate a real image of a object or something from this straight line image, "
368
+ "or generate a new real image of a object or something from this straight lines. "
369
+ "The input to this tool should be a comma seperated string of two, "
370
+ "representing the image_path and the user description. ")
371
+ def inference(self, inputs):
372
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
373
+ image = Image.open(image_path)
374
+ self.seed = random.randint(0, 65535)
375
+ seed_everything(self.seed)
376
+ prompt = instruct_text + ', ' + self.a_prompt
377
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
378
+ guidance_scale=9.0).images[0]
379
+ updated_image_path = get_new_image_name(image_path, func_name="line2image")
380
+ image.save(updated_image_path)
381
+ print(f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text: {instruct_text}, "
382
+ f"Output Text: {updated_image_path}")
383
+ return updated_image_path
384
+
385
+
386
+ class Image2Hed:
387
+ def __init__(self, device):
388
+ print("Initializing Image2Hed")
389
+ self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
390
+
391
+ @prompts(name="Hed Detection On Image",
392
+ description="useful when you want to detect the soft hed boundary of the image. "
393
+ "like: detect the soft hed boundary of this image, or hed boundary detection on image, "
394
+ "or peform hed boundary detection on this image, or detect soft hed boundary image of this image. "
395
+ "The input to this tool should be a string, representing the image_path")
396
+ def inference(self, inputs):
397
+ image = Image.open(inputs)
398
+ hed = self.detector(image)
399
+ updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
400
+ hed.save(updated_image_path)
401
+ print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {updated_image_path}")
402
+ return updated_image_path
403
+
404
+
405
+ class HedText2Image:
406
+ def __init__(self, device):
407
+ print("Initializing HedText2Image to %s" % device)
408
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
409
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-hed",
410
+ torch_dtype=self.torch_dtype)
411
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
412
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
413
+ torch_dtype=self.torch_dtype
414
+ )
415
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
416
+ self.pipe.to(device)
417
+ self.seed = -1
418
+ self.a_prompt = 'best quality, extremely detailed'
419
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
420
+ 'fewer digits, cropped, worst quality, low quality'
421
+
422
+ @prompts(name="Generate Image Condition On Soft Hed Boundary Image",
423
+ description="useful when you want to generate a new real image from both the user desciption "
424
+ "and a soft hed boundary image. "
425
+ "like: generate a real image of a object or something from this soft hed boundary image, "
426
+ "or generate a new real image of a object or something from this hed boundary. "
427
+ "The input to this tool should be a comma seperated string of two, "
428
+ "representing the image_path and the user description")
429
+ def inference(self, inputs):
430
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
431
+ image = Image.open(image_path)
432
+ self.seed = random.randint(0, 65535)
433
+ seed_everything(self.seed)
434
+ prompt = instruct_text + ', ' + self.a_prompt
435
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
436
+ guidance_scale=9.0).images[0]
437
+ updated_image_path = get_new_image_name(image_path, func_name="hed2image")
438
+ image.save(updated_image_path)
439
+ print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text: {instruct_text}, "
440
+ f"Output Image: {updated_image_path}")
441
+ return updated_image_path
442
+
443
+
444
+ class Image2Scribble:
445
+ def __init__(self, device):
446
+ print("Initializing Image2Scribble")
447
+ self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
448
+
449
+ @prompts(name="Sketch Detection On Image",
450
+ description="useful when you want to generate a scribble of the image. "
451
+ "like: generate a scribble of this image, or generate a sketch from this image, "
452
+ "detect the sketch from this image. "
453
+ "The input to this tool should be a string, representing the image_path")
454
+ def inference(self, inputs):
455
+ image = Image.open(inputs)
456
+ scribble = self.detector(image, scribble=True)
457
+ updated_image_path = get_new_image_name(inputs, func_name="scribble")
458
+ scribble.save(updated_image_path)
459
+ print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
460
+ return updated_image_path
461
+
462
+
463
+ class ScribbleText2Image:
464
+ def __init__(self, device):
465
+ print("Initializing ScribbleText2Image to %s" % device)
466
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
467
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-scribble",
468
+ torch_dtype=self.torch_dtype)
469
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
470
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
471
+ torch_dtype=self.torch_dtype
472
+ )
473
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
474
+ self.pipe.to(device)
475
+ self.seed = -1
476
+ self.a_prompt = 'best quality, extremely detailed'
477
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
478
+ 'fewer digits, cropped, worst quality, low quality'
479
+
480
+ @prompts(name="Generate Image Condition On Sketch Image",
481
+ description="useful when you want to generate a new real image from both the user desciption and "
482
+ "a scribble image or a sketch image. "
483
+ "The input to this tool should be a comma seperated string of two, "
484
+ "representing the image_path and the user description")
485
+ def inference(self, inputs):
486
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
487
+ image = Image.open(image_path)
488
+ self.seed = random.randint(0, 65535)
489
+ seed_everything(self.seed)
490
+ prompt = instruct_text + ', ' + self.a_prompt
491
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
492
+ guidance_scale=9.0).images[0]
493
+ updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
494
+ image.save(updated_image_path)
495
+ print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
496
+ f"Output Image: {updated_image_path}")
497
+ return updated_image_path
498
+
499
+
500
+ class Image2Pose:
501
+ def __init__(self, device):
502
+ print("Initializing Image2Pose")
503
+ self.detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
504
+
505
+ @prompts(name="Pose Detection On Image",
506
+ description="useful when you want to detect the human pose of the image. "
507
+ "like: generate human poses of this image, or generate a pose image from this image. "
508
+ "The input to this tool should be a string, representing the image_path")
509
+ def inference(self, inputs):
510
+ image = Image.open(inputs)
511
+ pose = self.detector(image)
512
+ updated_image_path = get_new_image_name(inputs, func_name="human-pose")
513
+ pose.save(updated_image_path)
514
+ print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
515
+ return updated_image_path
516
+
517
+
518
+ class PoseText2Image:
519
+ def __init__(self, device):
520
+ print("Initializing PoseText2Image to %s" % device)
521
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
522
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
523
+ torch_dtype=self.torch_dtype)
524
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
525
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
526
+ torch_dtype=self.torch_dtype)
527
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
528
+ self.pipe.to(device)
529
+ self.num_inference_steps = 20
530
+ self.seed = -1
531
+ self.unconditional_guidance_scale = 9.0
532
+ self.a_prompt = 'best quality, extremely detailed'
533
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
534
+ ' fewer digits, cropped, worst quality, low quality'
535
+
536
+ @prompts(name="Generate Image Condition On Pose Image",
537
+ description="useful when you want to generate a new real image from both the user desciption "
538
+ "and a human pose image. "
539
+ "like: generate a real image of a human from this human pose image, "
540
+ "or generate a new real image of a human from this pose. "
541
+ "The input to this tool should be a comma seperated string of two, "
542
+ "representing the image_path and the user description")
543
+ def inference(self, inputs):
544
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
545
+ image = Image.open(image_path)
546
+ self.seed = random.randint(0, 65535)
547
+ seed_everything(self.seed)
548
+ prompt = instruct_text + ', ' + self.a_prompt
549
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
550
+ guidance_scale=9.0).images[0]
551
+ updated_image_path = get_new_image_name(image_path, func_name="pose2image")
552
+ image.save(updated_image_path)
553
+ print(f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text: {instruct_text}, "
554
+ f"Output Image: {updated_image_path}")
555
+ return updated_image_path
556
+
557
+
558
+ class Image2Seg:
559
+ def __init__(self, device):
560
+ print("Initializing Image2Seg")
561
+ self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
562
+ self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
563
+ self.ade_palette = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
564
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
565
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
566
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
567
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
568
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
569
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
570
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
571
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
572
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
573
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
574
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
575
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
576
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
577
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
578
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
579
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
580
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
581
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
582
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
583
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
584
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
585
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
586
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
587
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
588
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
589
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
590
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
591
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
592
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
593
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
594
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
595
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
596
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
597
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
598
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
599
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
600
+ [102, 255, 0], [92, 0, 255]]
601
+
602
+ @prompts(name="Segmentation On Image",
603
+ description="useful when you want to detect segmentations of the image. "
604
+ "like: segment this image, or generate segmentations on this image, "
605
+ "or peform segmentation on this image. "
606
+ "The input to this tool should be a string, representing the image_path")
607
+ def inference(self, inputs):
608
+ image = Image.open(inputs)
609
+ pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
610
+ with torch.no_grad():
611
+ outputs = self.image_segmentor(pixel_values)
612
+ seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
613
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
614
+ palette = np.array(self.ade_palette)
615
+ for label, color in enumerate(palette):
616
+ color_seg[seg == label, :] = color
617
+ color_seg = color_seg.astype(np.uint8)
618
+ segmentation = Image.fromarray(color_seg)
619
+ updated_image_path = get_new_image_name(inputs, func_name="segmentation")
620
+ segmentation.save(updated_image_path)
621
+ print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
622
+ return updated_image_path
623
+
624
+
625
+ class SegText2Image:
626
+ def __init__(self, device):
627
+ print("Initializing SegText2Image to %s" % device)
628
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
629
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-seg",
630
+ torch_dtype=self.torch_dtype)
631
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
632
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
633
+ torch_dtype=self.torch_dtype)
634
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
635
+ self.pipe.to(device)
636
+ self.seed = -1
637
+ self.a_prompt = 'best quality, extremely detailed'
638
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
639
+ ' fewer digits, cropped, worst quality, low quality'
640
+
641
+ @prompts(name="Generate Image Condition On Segmentations",
642
+ description="useful when you want to generate a new real image from both the user desciption and segmentations. "
643
+ "like: generate a real image of a object or something from this segmentation image, "
644
+ "or generate a new real image of a object or something from these segmentations. "
645
+ "The input to this tool should be a comma seperated string of two, "
646
+ "representing the image_path and the user description")
647
+ def inference(self, inputs):
648
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
649
+ image = Image.open(image_path)
650
+ self.seed = random.randint(0, 65535)
651
+ seed_everything(self.seed)
652
+ prompt = instruct_text + ', ' + self.a_prompt
653
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
654
+ guidance_scale=9.0).images[0]
655
+ updated_image_path = get_new_image_name(image_path, func_name="segment2image")
656
+ image.save(updated_image_path)
657
+ print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text: {instruct_text}, "
658
+ f"Output Image: {updated_image_path}")
659
+ return updated_image_path
660
+
661
+
662
+ class Image2Depth:
663
+ def __init__(self, device):
664
+ print("Initializing Image2Depth")
665
+ self.depth_estimator = pipeline('depth-estimation')
666
+
667
+ @prompts(name="Predict Depth On Image",
668
+ description="useful when you want to detect depth of the image. like: generate the depth from this image, "
669
+ "or detect the depth map on this image, or predict the depth for this image. "
670
+ "The input to this tool should be a string, representing the image_path")
671
+ def inference(self, inputs):
672
+ image = Image.open(inputs)
673
+ depth = self.depth_estimator(image)['depth']
674
+ depth = np.array(depth)
675
+ depth = depth[:, :, None]
676
+ depth = np.concatenate([depth, depth, depth], axis=2)
677
+ depth = Image.fromarray(depth)
678
+ updated_image_path = get_new_image_name(inputs, func_name="depth")
679
+ depth.save(updated_image_path)
680
+ print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth: {updated_image_path}")
681
+ return updated_image_path
682
+
683
+
684
+ class DepthText2Image:
685
+ def __init__(self, device):
686
+ print("Initializing DepthText2Image to %s" % device)
687
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
688
+ self.controlnet = ControlNetModel.from_pretrained(
689
+ "fusing/stable-diffusion-v1-5-controlnet-depth", torch_dtype=self.torch_dtype)
690
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
691
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
692
+ torch_dtype=self.torch_dtype)
693
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
694
+ self.pipe.to(device)
695
+ self.seed = -1
696
+ self.a_prompt = 'best quality, extremely detailed'
697
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
698
+ ' fewer digits, cropped, worst quality, low quality'
699
+
700
+ @prompts(name="Generate Image Condition On Depth",
701
+ description="useful when you want to generate a new real image from both the user desciption and depth image. "
702
+ "like: generate a real image of a object or something from this depth image, "
703
+ "or generate a new real image of a object or something from the depth map. "
704
+ "The input to this tool should be a comma seperated string of two, "
705
+ "representing the image_path and the user description")
706
+ def inference(self, inputs):
707
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
708
+ image = Image.open(image_path)
709
+ self.seed = random.randint(0, 65535)
710
+ seed_everything(self.seed)
711
+ prompt = instruct_text + ', ' + self.a_prompt
712
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
713
+ guidance_scale=9.0).images[0]
714
+ updated_image_path = get_new_image_name(image_path, func_name="depth2image")
715
+ image.save(updated_image_path)
716
+ print(f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text: {instruct_text}, "
717
+ f"Output Image: {updated_image_path}")
718
+ return updated_image_path
719
+
720
+
721
+ class Image2Normal:
722
+ def __init__(self, device):
723
+ print("Initializing Image2Normal")
724
+ self.depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")
725
+ self.bg_threhold = 0.4
726
+
727
+ @prompts(name="Predict Normal Map On Image",
728
+ description="useful when you want to detect norm map of the image. "
729
+ "like: generate normal map from this image, or predict normal map of this image. "
730
+ "The input to this tool should be a string, representing the image_path")
731
+ def inference(self, inputs):
732
+ image = Image.open(inputs)
733
+ original_size = image.size
734
+ image = self.depth_estimator(image)['predicted_depth'][0]
735
+ image = image.numpy()
736
+ image_depth = image.copy()
737
+ image_depth -= np.min(image_depth)
738
+ image_depth /= np.max(image_depth)
739
+ x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
740
+ x[image_depth < self.bg_threhold] = 0
741
+ y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
742
+ y[image_depth < self.bg_threhold] = 0
743
+ z = np.ones_like(x) * np.pi * 2.0
744
+ image = np.stack([x, y, z], axis=2)
745
+ image /= np.sum(image ** 2.0, axis=2, keepdims=True) ** 0.5
746
+ image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
747
+ image = Image.fromarray(image)
748
+ image = image.resize(original_size)
749
+ updated_image_path = get_new_image_name(inputs, func_name="normal-map")
750
+ image.save(updated_image_path)
751
+ print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth: {updated_image_path}")
752
+ return updated_image_path
753
+
754
+
755
+ class NormalText2Image:
756
+ def __init__(self, device):
757
+ print("Initializing NormalText2Image to %s" % device)
758
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
759
+ self.controlnet = ControlNetModel.from_pretrained(
760
+ "fusing/stable-diffusion-v1-5-controlnet-normal", torch_dtype=self.torch_dtype)
761
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
762
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
763
+ torch_dtype=self.torch_dtype)
764
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
765
+ self.pipe.to(device)
766
+ self.seed = -1
767
+ self.a_prompt = 'best quality, extremely detailed'
768
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
769
+ ' fewer digits, cropped, worst quality, low quality'
770
+
771
+ @prompts(name="Generate Image Condition On Normal Map",
772
+ description="useful when you want to generate a new real image from both the user desciption and normal map. "
773
+ "like: generate a real image of a object or something from this normal map, "
774
+ "or generate a new real image of a object or something from the normal map. "
775
+ "The input to this tool should be a comma seperated string of two, "
776
+ "representing the image_path and the user description")
777
+ def inference(self, inputs):
778
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
779
+ image = Image.open(image_path)
780
+ self.seed = random.randint(0, 65535)
781
+ seed_everything(self.seed)
782
+ prompt = instruct_text + ', ' + self.a_prompt
783
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
784
+ guidance_scale=9.0).images[0]
785
+ updated_image_path = get_new_image_name(image_path, func_name="normal2image")
786
+ image.save(updated_image_path)
787
+ print(f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text: {instruct_text}, "
788
+ f"Output Image: {updated_image_path}")
789
+ return updated_image_path
790
+
791
+
792
+ class VisualQuestionAnswering:
793
+ def __init__(self, device):
794
+ print("Initializing VisualQuestionAnswering to %s" % device)
795
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
796
+ self.device = device
797
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
798
+ self.model = BlipForQuestionAnswering.from_pretrained(
799
+ "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
800
+
801
+ @prompts(name="Answer Question About The Image",
802
+ description="useful when you need an answer for a question based on an image. "
803
+ "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
804
+ "The input to this tool should be a comma seperated string of two, representing the image_path and the question")
805
+ def inference(self, inputs):
806
+ image_path, question = inputs.split(",")
807
+ raw_image = Image.open(image_path).convert('RGB')
808
+ inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
809
+ out = self.model.generate(**inputs)
810
+ answer = self.processor.decode(out[0], skip_special_tokens=True)
811
+ print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
812
+ f"Output Answer: {answer}")
813
+ return answer
814
+
815
+
816
+ class ConversationBot:
817
+ def __init__(self, load_dict):
818
+ # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
819
+ print(f"Initializing VisualChatGPT, load_dict={load_dict}")
820
+ if 'ImageCaptioning' not in load_dict:
821
+ raise ValueError("You have to load ImageCaptioning as a basic function for VisualChatGPT")
822
+
823
+ self.llm = OpenAI(temperature=0)
824
+ self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
825
+
826
+ self.models = dict()
827
+ for class_name, device in load_dict.items():
828
+ self.models[class_name] = globals()[class_name](device=device)
829
+
830
+ self.tools = []
831
+ for class_name, instance in self.models.items():
832
+ for e in dir(instance):
833
+ if e.startswith('inference'):
834
+ func = getattr(instance, e)
835
+ self.tools.append(Tool(name=func.name, description=func.description, func=func))
836
+
837
+ self.agent = initialize_agent(
838
+ self.tools,
839
+ self.llm,
840
+ agent="conversational-react-description",
841
+ verbose=True,
842
+ memory=self.memory,
843
+ return_intermediate_steps=True,
844
+ agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
845
+ 'suffix': VISUAL_CHATGPT_SUFFIX}, )
846
+
847
+ def run_text(self, text, state):
848
+ self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
849
+ res = self.agent({"input": text})
850
+ res['output'] = res['output'].replace("\\", "/")
851
+ response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
852
+ state = state + [(text, response)]
853
+ print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
854
+ f"Current Memory: {self.agent.memory.buffer}")
855
+ return state, state
856
+
857
+ def run_image(self, image, state, txt):
858
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
859
+ print("======>Auto Resize Image...")
860
+ img = Image.open(image.name)
861
+ width, height = img.size
862
+ ratio = min(512 / width, 512 / height)
863
+ width_new, height_new = (round(width * ratio), round(height * ratio))
864
+ width_new = int(np.round(width_new / 64.0)) * 64
865
+ height_new = int(np.round(height_new / 64.0)) * 64
866
+ img = img.resize((width_new, height_new))
867
+ img = img.convert('RGB')
868
+ img.save(image_filename, "PNG")
869
+ print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
870
+ description = self.models['ImageCaptioning'].inference(image_filename)
871
+ Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. " \
872
+ "This information helps you to understand this image, " \
873
+ "but you should use tools to finish following tasks, " \
874
+ "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(
875
+ image_filename, description)
876
+ AI_prompt = "Received. "
877
+ self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
878
+ state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
879
+ print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
880
+ f"Current Memory: {self.agent.memory.buffer}")
881
+ return state, state, txt + ' ' + image_filename + ' '
882
+
883
+
884
+ if __name__ == '__main__':
885
+ parser = argparse.ArgumentParser()
886
+ parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
887
+ args = parser.parse_args()
888
+ load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
889
+ bot = ConversationBot(load_dict=load_dict)
890
+ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
891
+ chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
892
+ state = gr.State([])
893
+ with gr.Row():
894
+ with gr.Column(scale=0.7):
895
+ txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
896
+ container=False)
897
+ with gr.Column(scale=0.15, min_width=0):
898
+ clear = gr.Button("Clear")
899
+ with gr.Column(scale=0.15, min_width=0):
900
+ btn = gr.UploadButton("Upload", file_types=["image"])
901
+
902
+ txt.submit(bot.run_text, [txt, state], [chatbot, state])
903
+ txt.submit(lambda: "", None, txt)
904
+ btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
905
+ clear.click(bot.memory.clear)
906
+ clear.click(lambda: [], None, chatbot)
907
+ clear.click(lambda: [], None, state)
908
+ demo.launch(server_name="0.0.0.0", server_port=7868)
visual_chatgpt_zh.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import random
4
+ import torch
5
+ import cv2
6
+ import re
7
+ import uuid
8
+ from PIL import Image
9
+ import numpy as np
10
+ import argparse
11
+
12
+ from langchain.agents.initialize import initialize_agent
13
+ from langchain.agents.tools import Tool
14
+ from langchain.chains.conversation.memory import ConversationBufferMemory
15
+ from langchain.llms.openai import OpenAI
16
+
17
+ from modules.image_captioning import ImageCaptioning
18
+ from modules.image_editing import ImageEditing
19
+ from modules.instruct_px2pix import InstructPix2Pix
20
+ from modules.mask_former import MaskFormer
21
+ from modules.text2img import Text2Image
22
+ from modules.visual_question_answering import VisualQuestionAnswering
23
+ from modules.controlnet_canny import Image2Canny,CannyText2Image
24
+ from modules.controlnet_depth import Image2Depth,DepthText2Image
25
+ from modules.controlnet_hed import Image2Hed,HedText2Image
26
+ from modules.controlnet_line import Image2Line,LineText2Image
27
+ from modules.controlnet_normal import Image2Normal,NormalText2Image
28
+ from modules.controlnet_pose import Image2Pose,PoseText2Image
29
+ from modules.controlnet_scibble import Image2Scribble,ScribbleText2Image
30
+ from modules.controlnet_seg import Image2Seg,SegText2Image
31
+
32
+ from modules.utils import *
33
+
34
+ import argparse
35
+
36
+ # chatgpt前缀
37
+ VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
38
+ Visual ChatGPT is able to process and understand large amounts of text and image. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated.
39
+ Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
40
+ Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
41
+ TOOLS:
42
+ ------
43
+ Visual ChatGPT has access to the following tools:"""
44
+
45
+ # 调教chatgpt的instruction
46
+ VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
47
+ ```
48
+ Thought: Do I need to use a tool? Yes
49
+ Action: the action to take, should be one of [{tool_names}]
50
+ Action Input: the input to the action
51
+ Observation: the result of the action
52
+ ```
53
+ When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
54
+ ```
55
+ Thought: Do I need to use a tool? No
56
+ {ai_prefix}: [your response here]
57
+ ```
58
+ """
59
+
60
+ # chatgpt后缀
61
+ VISUAL_CHATGPT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if not exists.
62
+ You will remember to provide the image file name loyally if it's provided in the last tool observation.
63
+ Begin!
64
+ Previous conversation history:
65
+ {chat_history}
66
+ New input: {input}
67
+ Since Visual ChatGPT is a text language model, Visual ChatGPT must use tools to observe images rather than imagination.
68
+ The thoughts and observations are only visible for Visual ChatGPT, Visual ChatGPT should remember to repeat important information in the final response for Human.
69
+ Thought: Do I need to use a tool? {agent_scratchpad}"""
70
+
71
+ os.makedirs('image', exist_ok=True)
72
+
73
+
74
+ class ConversationBot:
75
+ def __init__(self, load_dict, pretrained_model_dir):
76
+ # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
77
+ print(f"Initializing VisualChatGPT, load_dict={load_dict}")
78
+ if 'ImageCaptioning' not in load_dict:
79
+ raise ValueError("You have to load ImageCaptioning as a basic function for VisualChatGPT")
80
+
81
+ self.llm = OpenAI(temperature=0)
82
+ self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
83
+
84
+ self.models = dict()
85
+ for class_name, device in load_dict.items():
86
+ self.models[class_name] = globals()[class_name](device=device, pretrained_model_dir=pretrained_model_dir)
87
+
88
+ self.tools = []
89
+ for class_name, instance in self.models.items():
90
+ for e in dir(instance):
91
+ if e.startswith('inference'):
92
+ func = getattr(instance, e)
93
+ self.tools.append(Tool(name=func.name, description=func.description, func=func))
94
+
95
+ self.agent = initialize_agent(
96
+ self.tools,
97
+ self.llm,
98
+ agent="conversational-react-description",
99
+ verbose=True,
100
+ memory=self.memory,
101
+ return_intermediate_steps=True,
102
+ agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
103
+ 'suffix': VISUAL_CHATGPT_SUFFIX}, )
104
+
105
+ def run_text(self, text, state):
106
+ self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
107
+ res = self.agent({"input": text})
108
+ res['output'] = res['output'].replace("\\", "/")
109
+ response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
110
+ state = state + [(text, response)]
111
+ print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
112
+ f"Current Memory: {self.agent.memory.buffer}")
113
+ return state, state
114
+
115
+ def run_image(self, image, state, txt):
116
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
117
+ print("======>Auto Resize Image...")
118
+ img = Image.open(image.name)
119
+ width, height = img.size
120
+ ratio = min(512 / width, 512 / height)
121
+ width_new, height_new = (round(width * ratio), round(height * ratio))
122
+ width_new = int(np.round(width_new / 64.0)) * 64
123
+ height_new = int(np.round(height_new / 64.0)) * 64
124
+ img = img.resize((width_new, height_new))
125
+ img = img.convert('RGB')
126
+ img.save(image_filename, "PNG")
127
+ print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
128
+ description = self.models['ImageCaptioning'].inference(image_filename)
129
+ Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. " \
130
+ "This information helps you to understand this image, " \
131
+ "but you should use tools to finish following tasks, " \
132
+ "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(
133
+ image_filename, description)
134
+ AI_prompt = "Received. "
135
+ self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
136
+ state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
137
+ print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
138
+ f"Current Memory: {self.agent.memory.buffer}")
139
+ return state, state, txt + ' ' + image_filename + ' '
140
+
141
+
142
+ if __name__ == '__main__':
143
+ parser = argparse.ArgumentParser()
144
+ parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
145
+ parser.add_argument("--pretrained_model_dir", default="./hf_models_path",
146
+ type=str, help="huggingface下载好的模型路径")
147
+ args = parser.parse_args()
148
+
149
+ pretrained_model_dir = args.pretrained_model_dir
150
+
151
+ load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
152
+ bot = ConversationBot(load_dict=load_dict, pretrained_model_dir=pretrained_model_dir)
153
+ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
154
+ chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
155
+ state = gr.State([])
156
+ with gr.Row():
157
+ with gr.Column(scale=0.7):
158
+ txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
159
+ container=False)
160
+ with gr.Column(scale=0.15, min_width=0):
161
+ clear = gr.Button("Clear")
162
+ with gr.Column(scale=0.15, min_width=0):
163
+ btn = gr.UploadButton("Upload", file_types=["image"])
164
+
165
+ txt.submit(bot.run_text, [txt, state], [chatbot, state])
166
+ txt.submit(lambda: "", None, txt)
167
+ btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
168
+ clear.click(bot.memory.clear)
169
+ clear.click(lambda: [], None, chatbot)
170
+ clear.click(lambda: [], None, state)
171
+ demo.launch(server_name="0.0.0.0", server_port=7868)