Spaces:
Runtime error
Runtime error
kiramayatu
commited on
Commit
•
67b73a1
1
Parent(s):
e4a4487
Upload 16 files
Browse files- .idea/.gitignore +3 -0
- .idea/VITS_voice_conversion.iml +12 -0
- .idea/inspectionProfiles/Project_Default.xml +154 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- configs/modified_finetune_speaker.json +172 -0
- configs/uma_trilingual.json +54 -0
- inference/G_latest.pth +3 -0
- inference/ONNXVITS_inference.py +36 -0
- inference/VC_inference.py +139 -0
- inference/finetune_speaker.json +147 -0
- monotonic_align/__init__.py +19 -0
- monotonic_align/core.pyx +42 -0
- monotonic_align/setup.py +9 -0
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
.idea/VITS_voice_conversion.iml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="jdk" jdkName="Python 3.7 (VITS)" jdkType="Python SDK" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
<component name="PyDocumentationSettings">
|
9 |
+
<option name="format" value="PLAIN" />
|
10 |
+
<option name="myDocStringFormat" value="Plain" />
|
11 |
+
</component>
|
12 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<profile version="1.0">
|
3 |
+
<option name="myName" value="Project Default" />
|
4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
5 |
+
<option name="ignoredPackages">
|
6 |
+
<value>
|
7 |
+
<list size="132">
|
8 |
+
<item index="0" class="java.lang.String" itemvalue="ccxt" />
|
9 |
+
<item index="1" class="java.lang.String" itemvalue="lz4" />
|
10 |
+
<item index="2" class="java.lang.String" itemvalue="pre-commit" />
|
11 |
+
<item index="3" class="java.lang.String" itemvalue="elegantrl" />
|
12 |
+
<item index="4" class="java.lang.String" itemvalue="setuptools" />
|
13 |
+
<item index="5" class="java.lang.String" itemvalue="ray" />
|
14 |
+
<item index="6" class="java.lang.String" itemvalue="gputil" />
|
15 |
+
<item index="7" class="java.lang.String" itemvalue="google-pasta" />
|
16 |
+
<item index="8" class="java.lang.String" itemvalue="tensorflow-estimator" />
|
17 |
+
<item index="9" class="java.lang.String" itemvalue="scikit-learn" />
|
18 |
+
<item index="10" class="java.lang.String" itemvalue="tabulate" />
|
19 |
+
<item index="11" class="java.lang.String" itemvalue="multitasking" />
|
20 |
+
<item index="12" class="java.lang.String" itemvalue="pickleshare" />
|
21 |
+
<item index="13" class="java.lang.String" itemvalue="pyasn1-modules" />
|
22 |
+
<item index="14" class="java.lang.String" itemvalue="ipython-genutils" />
|
23 |
+
<item index="15" class="java.lang.String" itemvalue="Pygments" />
|
24 |
+
<item index="16" class="java.lang.String" itemvalue="mccabe" />
|
25 |
+
<item index="17" class="java.lang.String" itemvalue="astunparse" />
|
26 |
+
<item index="18" class="java.lang.String" itemvalue="lxml" />
|
27 |
+
<item index="19" class="java.lang.String" itemvalue="Werkzeug" />
|
28 |
+
<item index="20" class="java.lang.String" itemvalue="tensorboard-data-server" />
|
29 |
+
<item index="21" class="java.lang.String" itemvalue="jupyter-client" />
|
30 |
+
<item index="22" class="java.lang.String" itemvalue="pexpect" />
|
31 |
+
<item index="23" class="java.lang.String" itemvalue="click" />
|
32 |
+
<item index="24" class="java.lang.String" itemvalue="ipykernel" />
|
33 |
+
<item index="25" class="java.lang.String" itemvalue="pandas-datareader" />
|
34 |
+
<item index="26" class="java.lang.String" itemvalue="psutil" />
|
35 |
+
<item index="27" class="java.lang.String" itemvalue="jedi" />
|
36 |
+
<item index="28" class="java.lang.String" itemvalue="regex" />
|
37 |
+
<item index="29" class="java.lang.String" itemvalue="tensorboard" />
|
38 |
+
<item index="30" class="java.lang.String" itemvalue="platformdirs" />
|
39 |
+
<item index="31" class="java.lang.String" itemvalue="matplotlib" />
|
40 |
+
<item index="32" class="java.lang.String" itemvalue="idna" />
|
41 |
+
<item index="33" class="java.lang.String" itemvalue="rsa" />
|
42 |
+
<item index="34" class="java.lang.String" itemvalue="decorator" />
|
43 |
+
<item index="35" class="java.lang.String" itemvalue="numpy" />
|
44 |
+
<item index="36" class="java.lang.String" itemvalue="pyasn1" />
|
45 |
+
<item index="37" class="java.lang.String" itemvalue="requests" />
|
46 |
+
<item index="38" class="java.lang.String" itemvalue="tensorflow" />
|
47 |
+
<item index="39" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
|
48 |
+
<item index="40" class="java.lang.String" itemvalue="Deprecated" />
|
49 |
+
<item index="41" class="java.lang.String" itemvalue="nest-asyncio" />
|
50 |
+
<item index="42" class="java.lang.String" itemvalue="prompt-toolkit" />
|
51 |
+
<item index="43" class="java.lang.String" itemvalue="keras-tuner" />
|
52 |
+
<item index="44" class="java.lang.String" itemvalue="scipy" />
|
53 |
+
<item index="45" class="java.lang.String" itemvalue="dataclasses" />
|
54 |
+
<item index="46" class="java.lang.String" itemvalue="tornado" />
|
55 |
+
<item index="47" class="java.lang.String" itemvalue="google-auth-oauthlib" />
|
56 |
+
<item index="48" class="java.lang.String" itemvalue="black" />
|
57 |
+
<item index="49" class="java.lang.String" itemvalue="toml" />
|
58 |
+
<item index="50" class="java.lang.String" itemvalue="Quandl" />
|
59 |
+
<item index="51" class="java.lang.String" itemvalue="pandas" />
|
60 |
+
<item index="52" class="java.lang.String" itemvalue="termcolor" />
|
61 |
+
<item index="53" class="java.lang.String" itemvalue="pylint" />
|
62 |
+
<item index="54" class="java.lang.String" itemvalue="typing_extensions" />
|
63 |
+
<item index="55" class="java.lang.String" itemvalue="cachetools" />
|
64 |
+
<item index="56" class="java.lang.String" itemvalue="debugpy" />
|
65 |
+
<item index="57" class="java.lang.String" itemvalue="isort" />
|
66 |
+
<item index="58" class="java.lang.String" itemvalue="pytz" />
|
67 |
+
<item index="59" class="java.lang.String" itemvalue="inflection" />
|
68 |
+
<item index="60" class="java.lang.String" itemvalue="Pillow" />
|
69 |
+
<item index="61" class="java.lang.String" itemvalue="traitlets" />
|
70 |
+
<item index="62" class="java.lang.String" itemvalue="absl-py" />
|
71 |
+
<item index="63" class="java.lang.String" itemvalue="protobuf" />
|
72 |
+
<item index="64" class="java.lang.String" itemvalue="joblib" />
|
73 |
+
<item index="65" class="java.lang.String" itemvalue="threadpoolctl" />
|
74 |
+
<item index="66" class="java.lang.String" itemvalue="opt-einsum" />
|
75 |
+
<item index="67" class="java.lang.String" itemvalue="python-dateutil" />
|
76 |
+
<item index="68" class="java.lang.String" itemvalue="gpflow" />
|
77 |
+
<item index="69" class="java.lang.String" itemvalue="astroid" />
|
78 |
+
<item index="70" class="java.lang.String" itemvalue="cycler" />
|
79 |
+
<item index="71" class="java.lang.String" itemvalue="gast" />
|
80 |
+
<item index="72" class="java.lang.String" itemvalue="kt-legacy" />
|
81 |
+
<item index="73" class="java.lang.String" itemvalue="appdirs" />
|
82 |
+
<item index="74" class="java.lang.String" itemvalue="tensorflow-probability" />
|
83 |
+
<item index="75" class="java.lang.String" itemvalue="pip" />
|
84 |
+
<item index="76" class="java.lang.String" itemvalue="pyzmq" />
|
85 |
+
<item index="77" class="java.lang.String" itemvalue="certifi" />
|
86 |
+
<item index="78" class="java.lang.String" itemvalue="oauthlib" />
|
87 |
+
<item index="79" class="java.lang.String" itemvalue="pyparsing" />
|
88 |
+
<item index="80" class="java.lang.String" itemvalue="Markdown" />
|
89 |
+
<item index="81" class="java.lang.String" itemvalue="h5py" />
|
90 |
+
<item index="82" class="java.lang.String" itemvalue="wrapt" />
|
91 |
+
<item index="83" class="java.lang.String" itemvalue="kiwisolver" />
|
92 |
+
<item index="84" class="java.lang.String" itemvalue="empyrical" />
|
93 |
+
<item index="85" class="java.lang.String" itemvalue="backcall" />
|
94 |
+
<item index="86" class="java.lang.String" itemvalue="charset-normalizer" />
|
95 |
+
<item index="87" class="java.lang.String" itemvalue="multipledispatch" />
|
96 |
+
<item index="88" class="java.lang.String" itemvalue="pathspec" />
|
97 |
+
<item index="89" class="java.lang.String" itemvalue="jupyter-core" />
|
98 |
+
<item index="90" class="java.lang.String" itemvalue="matplotlib-inline" />
|
99 |
+
<item index="91" class="java.lang.String" itemvalue="ptyprocess" />
|
100 |
+
<item index="92" class="java.lang.String" itemvalue="more-itertools" />
|
101 |
+
<item index="93" class="java.lang.String" itemvalue="mypy-extensions" />
|
102 |
+
<item index="94" class="java.lang.String" itemvalue="cloudpickle" />
|
103 |
+
<item index="95" class="java.lang.String" itemvalue="wcwidth" />
|
104 |
+
<item index="96" class="java.lang.String" itemvalue="requests-oauthlib" />
|
105 |
+
<item index="97" class="java.lang.String" itemvalue="Keras-Preprocessing" />
|
106 |
+
<item index="98" class="java.lang.String" itemvalue="yfinance" />
|
107 |
+
<item index="99" class="java.lang.String" itemvalue="tomli" />
|
108 |
+
<item index="100" class="java.lang.String" itemvalue="urllib3" />
|
109 |
+
<item index="101" class="java.lang.String" itemvalue="six" />
|
110 |
+
<item index="102" class="java.lang.String" itemvalue="parso" />
|
111 |
+
<item index="103" class="java.lang.String" itemvalue="wheel" />
|
112 |
+
<item index="104" class="java.lang.String" itemvalue="ipython" />
|
113 |
+
<item index="105" class="java.lang.String" itemvalue="packaging" />
|
114 |
+
<item index="106" class="java.lang.String" itemvalue="lazy-object-proxy" />
|
115 |
+
<item index="107" class="java.lang.String" itemvalue="grpcio" />
|
116 |
+
<item index="108" class="java.lang.String" itemvalue="dm-tree" />
|
117 |
+
<item index="109" class="java.lang.String" itemvalue="google-auth" />
|
118 |
+
<item index="110" class="java.lang.String" itemvalue="seaborn" />
|
119 |
+
<item index="111" class="java.lang.String" itemvalue="thop" />
|
120 |
+
<item index="112" class="java.lang.String" itemvalue="torch" />
|
121 |
+
<item index="113" class="java.lang.String" itemvalue="torchvision" />
|
122 |
+
<item index="114" class="java.lang.String" itemvalue="d2l" />
|
123 |
+
<item index="115" class="java.lang.String" itemvalue="keyboard" />
|
124 |
+
<item index="116" class="java.lang.String" itemvalue="transformers" />
|
125 |
+
<item index="117" class="java.lang.String" itemvalue="phonemizer" />
|
126 |
+
<item index="118" class="java.lang.String" itemvalue="Unidecode" />
|
127 |
+
<item index="119" class="java.lang.String" itemvalue="nltk" />
|
128 |
+
<item index="120" class="java.lang.String" itemvalue="pinecone-client" />
|
129 |
+
<item index="121" class="java.lang.String" itemvalue="sentence-transformers" />
|
130 |
+
<item index="122" class="java.lang.String" itemvalue="whisper" />
|
131 |
+
<item index="123" class="java.lang.String" itemvalue="datasets" />
|
132 |
+
<item index="124" class="java.lang.String" itemvalue="pyaudio" />
|
133 |
+
<item index="125" class="java.lang.String" itemvalue="torchsummary" />
|
134 |
+
<item index="126" class="java.lang.String" itemvalue="openjtalk" />
|
135 |
+
<item index="127" class="java.lang.String" itemvalue="hydra-core" />
|
136 |
+
<item index="128" class="java.lang.String" itemvalue="museval" />
|
137 |
+
<item index="129" class="java.lang.String" itemvalue="mypy" />
|
138 |
+
<item index="130" class="java.lang.String" itemvalue="hydra-colorlog" />
|
139 |
+
<item index="131" class="java.lang.String" itemvalue="flake8" />
|
140 |
+
</list>
|
141 |
+
</value>
|
142 |
+
</option>
|
143 |
+
</inspection_tool>
|
144 |
+
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
145 |
+
<option name="ignoredIdentifiers">
|
146 |
+
<list>
|
147 |
+
<option value="sentiment_classification.model_predictions.audio_path" />
|
148 |
+
<option value="sentiment_classification.model_predictions.sample_rate" />
|
149 |
+
<option value="sentiment_classification.model_predictions.num_samples" />
|
150 |
+
</list>
|
151 |
+
</option>
|
152 |
+
</inspection_tool>
|
153 |
+
</profile>
|
154 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (VITS)" project-jdk-type="Python SDK" />
|
4 |
+
</project>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/VITS_voice_conversion.iml" filepath="$PROJECT_DIR$/.idea/VITS_voice_conversion.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
configs/modified_finetune_speaker.json
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 10,
|
4 |
+
"eval_interval": 100,
|
5 |
+
"seed": 1234,
|
6 |
+
"epochs": 10000,
|
7 |
+
"learning_rate": 0.0002,
|
8 |
+
"betas": [
|
9 |
+
0.8,
|
10 |
+
0.99
|
11 |
+
],
|
12 |
+
"eps": 1e-09,
|
13 |
+
"batch_size": 16,
|
14 |
+
"fp16_run": true,
|
15 |
+
"lr_decay": 0.999875,
|
16 |
+
"segment_size": 8192,
|
17 |
+
"init_lr_ratio": 1,
|
18 |
+
"warmup_epochs": 0,
|
19 |
+
"c_mel": 45,
|
20 |
+
"c_kl": 1.0
|
21 |
+
},
|
22 |
+
"data": {
|
23 |
+
"training_files": "final_annotation_train.txt",
|
24 |
+
"validation_files": "final_annotation_val.txt",
|
25 |
+
"text_cleaners": [
|
26 |
+
"chinese_cleaners"
|
27 |
+
],
|
28 |
+
"max_wav_value": 32768.0,
|
29 |
+
"sampling_rate": 22050,
|
30 |
+
"filter_length": 1024,
|
31 |
+
"hop_length": 256,
|
32 |
+
"win_length": 1024,
|
33 |
+
"n_mel_channels": 80,
|
34 |
+
"mel_fmin": 0.0,
|
35 |
+
"mel_fmax": null,
|
36 |
+
"add_blank": true,
|
37 |
+
"n_speakers": 2,
|
38 |
+
"cleaned_text": true
|
39 |
+
},
|
40 |
+
"model": {
|
41 |
+
"inter_channels": 192,
|
42 |
+
"hidden_channels": 192,
|
43 |
+
"filter_channels": 768,
|
44 |
+
"n_heads": 2,
|
45 |
+
"n_layers": 6,
|
46 |
+
"kernel_size": 3,
|
47 |
+
"p_dropout": 0.1,
|
48 |
+
"resblock": "1",
|
49 |
+
"resblock_kernel_sizes": [
|
50 |
+
3,
|
51 |
+
7,
|
52 |
+
11
|
53 |
+
],
|
54 |
+
"resblock_dilation_sizes": [
|
55 |
+
[
|
56 |
+
1,
|
57 |
+
3,
|
58 |
+
5
|
59 |
+
],
|
60 |
+
[
|
61 |
+
1,
|
62 |
+
3,
|
63 |
+
5
|
64 |
+
],
|
65 |
+
[
|
66 |
+
1,
|
67 |
+
3,
|
68 |
+
5
|
69 |
+
]
|
70 |
+
],
|
71 |
+
"upsample_rates": [
|
72 |
+
8,
|
73 |
+
8,
|
74 |
+
2,
|
75 |
+
2
|
76 |
+
],
|
77 |
+
"upsample_initial_channel": 512,
|
78 |
+
"upsample_kernel_sizes": [
|
79 |
+
16,
|
80 |
+
16,
|
81 |
+
4,
|
82 |
+
4
|
83 |
+
],
|
84 |
+
"n_layers_q": 3,
|
85 |
+
"use_spectral_norm": false,
|
86 |
+
"gin_channels": 256
|
87 |
+
},
|
88 |
+
"symbols": [
|
89 |
+
"_",
|
90 |
+
"\uff1b",
|
91 |
+
"\uff1a",
|
92 |
+
"\uff0c",
|
93 |
+
"\u3002",
|
94 |
+
"\uff01",
|
95 |
+
"\uff1f",
|
96 |
+
"-",
|
97 |
+
"\u201c",
|
98 |
+
"\u201d",
|
99 |
+
"\u300a",
|
100 |
+
"\u300b",
|
101 |
+
"\u3001",
|
102 |
+
"\uff08",
|
103 |
+
"\uff09",
|
104 |
+
"\u2026",
|
105 |
+
"\u2014",
|
106 |
+
" ",
|
107 |
+
"A",
|
108 |
+
"B",
|
109 |
+
"C",
|
110 |
+
"D",
|
111 |
+
"E",
|
112 |
+
"F",
|
113 |
+
"G",
|
114 |
+
"H",
|
115 |
+
"I",
|
116 |
+
"J",
|
117 |
+
"K",
|
118 |
+
"L",
|
119 |
+
"M",
|
120 |
+
"N",
|
121 |
+
"O",
|
122 |
+
"P",
|
123 |
+
"Q",
|
124 |
+
"R",
|
125 |
+
"S",
|
126 |
+
"T",
|
127 |
+
"U",
|
128 |
+
"V",
|
129 |
+
"W",
|
130 |
+
"X",
|
131 |
+
"Y",
|
132 |
+
"Z",
|
133 |
+
"a",
|
134 |
+
"b",
|
135 |
+
"c",
|
136 |
+
"d",
|
137 |
+
"e",
|
138 |
+
"f",
|
139 |
+
"g",
|
140 |
+
"h",
|
141 |
+
"i",
|
142 |
+
"j",
|
143 |
+
"k",
|
144 |
+
"l",
|
145 |
+
"m",
|
146 |
+
"n",
|
147 |
+
"o",
|
148 |
+
"p",
|
149 |
+
"q",
|
150 |
+
"r",
|
151 |
+
"s",
|
152 |
+
"t",
|
153 |
+
"u",
|
154 |
+
"v",
|
155 |
+
"w",
|
156 |
+
"x",
|
157 |
+
"y",
|
158 |
+
"z",
|
159 |
+
"1",
|
160 |
+
"2",
|
161 |
+
"3",
|
162 |
+
"4",
|
163 |
+
"5",
|
164 |
+
"0",
|
165 |
+
"\uff22",
|
166 |
+
"\uff30"
|
167 |
+
],
|
168 |
+
"speakers": {
|
169 |
+
"dingzhen": 0,
|
170 |
+
"taffy": 1
|
171 |
+
}
|
172 |
+
}
|
configs/uma_trilingual.json
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"eval_interval": 1000,
|
5 |
+
"seed": 1234,
|
6 |
+
"epochs": 10000,
|
7 |
+
"learning_rate": 2e-4,
|
8 |
+
"betas": [0.8, 0.99],
|
9 |
+
"eps": 1e-9,
|
10 |
+
"batch_size": 16,
|
11 |
+
"fp16_run": true,
|
12 |
+
"lr_decay": 0.999875,
|
13 |
+
"segment_size": 8192,
|
14 |
+
"init_lr_ratio": 1,
|
15 |
+
"warmup_epochs": 0,
|
16 |
+
"c_mel": 45,
|
17 |
+
"c_kl": 1.0
|
18 |
+
},
|
19 |
+
"data": {
|
20 |
+
"training_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.train.txt.cleaned",
|
21 |
+
"validation_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.val.txt.cleaned",
|
22 |
+
"text_cleaners":["cjke_cleaners2"],
|
23 |
+
"max_wav_value": 32768.0,
|
24 |
+
"sampling_rate": 22050,
|
25 |
+
"filter_length": 1024,
|
26 |
+
"hop_length": 256,
|
27 |
+
"win_length": 1024,
|
28 |
+
"n_mel_channels": 80,
|
29 |
+
"mel_fmin": 0.0,
|
30 |
+
"mel_fmax": null,
|
31 |
+
"add_blank": true,
|
32 |
+
"n_speakers": 999,
|
33 |
+
"cleaned_text": true
|
34 |
+
},
|
35 |
+
"model": {
|
36 |
+
"inter_channels": 192,
|
37 |
+
"hidden_channels": 192,
|
38 |
+
"filter_channels": 768,
|
39 |
+
"n_heads": 2,
|
40 |
+
"n_layers": 6,
|
41 |
+
"kernel_size": 3,
|
42 |
+
"p_dropout": 0.1,
|
43 |
+
"resblock": "1",
|
44 |
+
"resblock_kernel_sizes": [3,7,11],
|
45 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
46 |
+
"upsample_rates": [8,8,2,2],
|
47 |
+
"upsample_initial_channel": 512,
|
48 |
+
"upsample_kernel_sizes": [16,16,4,4],
|
49 |
+
"n_layers_q": 3,
|
50 |
+
"use_spectral_norm": false,
|
51 |
+
"gin_channels": 256
|
52 |
+
},
|
53 |
+
"symbols": ["_", ",", ".", "!", "?", "-", "~", "\u2026", "N", "Q", "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "s", "t", "u", "v", "w", "x", "y", "z", "\u0251", "\u00e6", "\u0283", "\u0291", "\u00e7", "\u026f", "\u026a", "\u0254", "\u025b", "\u0279", "\u00f0", "\u0259", "\u026b", "\u0265", "\u0278", "\u028a", "\u027e", "\u0292", "\u03b8", "\u03b2", "\u014b", "\u0266", "\u207c", "\u02b0", "`", "^", "#", "*", "=", "\u02c8", "\u02cc", "\u2192", "\u2193", "\u2191", " "]
|
54 |
+
}
|
inference/G_latest.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:44f9141fcac34c950376594d08a288d9159a32d6add851155b6fd0ecee242419
|
3 |
+
size 158887401
|
inference/ONNXVITS_inference.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
logging.getLogger('numba').setLevel(logging.WARNING)
|
3 |
+
import IPython.display as ipd
|
4 |
+
import torch
|
5 |
+
import commons
|
6 |
+
import utils
|
7 |
+
import ONNXVITS_infer
|
8 |
+
from text import text_to_sequence
|
9 |
+
|
10 |
+
def get_text(text, hps):
|
11 |
+
text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
|
12 |
+
if hps.data.add_blank:
|
13 |
+
text_norm = commons.intersperse(text_norm, 0)
|
14 |
+
text_norm = torch.LongTensor(text_norm)
|
15 |
+
return text_norm
|
16 |
+
|
17 |
+
hps = utils.get_hparams_from_file("../vits/pretrained_models/uma87.json")
|
18 |
+
|
19 |
+
net_g = ONNXVITS_infer.SynthesizerTrn(
|
20 |
+
len(hps.symbols),
|
21 |
+
hps.data.filter_length // 2 + 1,
|
22 |
+
hps.train.segment_size // hps.data.hop_length,
|
23 |
+
n_speakers=hps.data.n_speakers,
|
24 |
+
**hps.model)
|
25 |
+
_ = net_g.eval()
|
26 |
+
|
27 |
+
_ = utils.load_checkpoint("../vits/pretrained_models/uma_1153000.pth", net_g)
|
28 |
+
|
29 |
+
text1 = get_text("おはようございます。", hps)
|
30 |
+
stn_tst = text1
|
31 |
+
with torch.no_grad():
|
32 |
+
x_tst = stn_tst.unsqueeze(0)
|
33 |
+
x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
|
34 |
+
sid = torch.LongTensor([0])
|
35 |
+
audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()
|
36 |
+
print(audio)
|
inference/VC_inference.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import no_grad, LongTensor
|
5 |
+
import argparse
|
6 |
+
import commons
|
7 |
+
from mel_processing import spectrogram_torch
|
8 |
+
import utils
|
9 |
+
from models import SynthesizerTrn
|
10 |
+
import gradio as gr
|
11 |
+
import librosa
|
12 |
+
import webbrowser
|
13 |
+
|
14 |
+
from text import text_to_sequence, _clean_text
|
15 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
16 |
+
language_marks = {
|
17 |
+
"Japanese": "",
|
18 |
+
"日本語": "[JA]",
|
19 |
+
"简体中文": "[ZH]",
|
20 |
+
"English": "[EN]",
|
21 |
+
"Mix": "",
|
22 |
+
}
|
23 |
+
lang = ['日本語', '简体中文', 'English', 'Mix']
|
24 |
+
def get_text(text, hps, is_symbol):
|
25 |
+
text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
|
26 |
+
if hps.data.add_blank:
|
27 |
+
text_norm = commons.intersperse(text_norm, 0)
|
28 |
+
text_norm = LongTensor(text_norm)
|
29 |
+
return text_norm
|
30 |
+
|
31 |
+
def create_tts_fn(model, hps, speaker_ids):
|
32 |
+
def tts_fn(text, speaker, language, speed):
|
33 |
+
if language is not None:
|
34 |
+
text = language_marks[language] + text + language_marks[language]
|
35 |
+
speaker_id = speaker_ids[speaker]
|
36 |
+
stn_tst = get_text(text, hps, False)
|
37 |
+
with no_grad():
|
38 |
+
x_tst = stn_tst.unsqueeze(0).to(device)
|
39 |
+
x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
|
40 |
+
sid = LongTensor([speaker_id]).to(device)
|
41 |
+
audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
|
42 |
+
length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
|
43 |
+
del stn_tst, x_tst, x_tst_lengths, sid
|
44 |
+
return "Success", (hps.data.sampling_rate, audio)
|
45 |
+
|
46 |
+
return tts_fn
|
47 |
+
|
48 |
+
def create_vc_fn(model, hps, speaker_ids):
|
49 |
+
def vc_fn(original_speaker, target_speaker, record_audio, upload_audio):
|
50 |
+
input_audio = record_audio if record_audio is not None else upload_audio
|
51 |
+
if input_audio is None:
|
52 |
+
return "You need to record or upload an audio", None
|
53 |
+
sampling_rate, audio = input_audio
|
54 |
+
original_speaker_id = speaker_ids[original_speaker]
|
55 |
+
target_speaker_id = speaker_ids[target_speaker]
|
56 |
+
|
57 |
+
audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
|
58 |
+
if len(audio.shape) > 1:
|
59 |
+
audio = librosa.to_mono(audio.transpose(1, 0))
|
60 |
+
if sampling_rate != hps.data.sampling_rate:
|
61 |
+
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=hps.data.sampling_rate)
|
62 |
+
with no_grad():
|
63 |
+
y = torch.FloatTensor(audio)
|
64 |
+
y = y / max(-y.min(), y.max()) / 0.99
|
65 |
+
y = y.to(device)
|
66 |
+
y = y.unsqueeze(0)
|
67 |
+
spec = spectrogram_torch(y, hps.data.filter_length,
|
68 |
+
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
|
69 |
+
center=False).to(device)
|
70 |
+
spec_lengths = LongTensor([spec.size(-1)]).to(device)
|
71 |
+
sid_src = LongTensor([original_speaker_id]).to(device)
|
72 |
+
sid_tgt = LongTensor([target_speaker_id]).to(device)
|
73 |
+
audio = model.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt)[0][
|
74 |
+
0, 0].data.cpu().float().numpy()
|
75 |
+
del y, spec, spec_lengths, sid_src, sid_tgt
|
76 |
+
return "Success", (hps.data.sampling_rate, audio)
|
77 |
+
|
78 |
+
return vc_fn
|
79 |
+
if __name__ == "__main__":
|
80 |
+
parser = argparse.ArgumentParser()
|
81 |
+
parser.add_argument("--model_dir", default="./G_latest.pth", help="directory to your fine-tuned model")
|
82 |
+
parser.add_argument("--config_dir", default="./finetune_speaker.json", help="directory to your model config file")
|
83 |
+
parser.add_argument("--share", default=False, help="make link public (used in colab)")
|
84 |
+
|
85 |
+
args = parser.parse_args()
|
86 |
+
hps = utils.get_hparams_from_file(args.config_dir)
|
87 |
+
|
88 |
+
|
89 |
+
net_g = SynthesizerTrn(
|
90 |
+
len(hps.symbols),
|
91 |
+
hps.data.filter_length // 2 + 1,
|
92 |
+
hps.train.segment_size // hps.data.hop_length,
|
93 |
+
n_speakers=hps.data.n_speakers,
|
94 |
+
**hps.model).to(device)
|
95 |
+
_ = net_g.eval()
|
96 |
+
|
97 |
+
_ = utils.load_checkpoint(args.model_dir, net_g, None)
|
98 |
+
speaker_ids = hps.speakers
|
99 |
+
speakers = list(hps.speakers.keys())
|
100 |
+
tts_fn = create_tts_fn(net_g, hps, speaker_ids)
|
101 |
+
vc_fn = create_vc_fn(net_g, hps, speaker_ids)
|
102 |
+
app = gr.Blocks()
|
103 |
+
with app:
|
104 |
+
with gr.Tab("Text-to-Speech"):
|
105 |
+
with gr.Row():
|
106 |
+
with gr.Column():
|
107 |
+
textbox = gr.TextArea(label="Text",
|
108 |
+
placeholder="Type your sentence here",
|
109 |
+
value="こんにちわ。", elem_id=f"tts-input")
|
110 |
+
# select character
|
111 |
+
char_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label='character')
|
112 |
+
language_dropdown = gr.Dropdown(choices=lang, value=lang[0], label='language')
|
113 |
+
duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
|
114 |
+
label='速度 Speed')
|
115 |
+
with gr.Column():
|
116 |
+
text_output = gr.Textbox(label="Message")
|
117 |
+
audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
|
118 |
+
btn = gr.Button("Generate!")
|
119 |
+
btn.click(tts_fn,
|
120 |
+
inputs=[textbox, char_dropdown, language_dropdown, duration_slider,],
|
121 |
+
outputs=[text_output, audio_output])
|
122 |
+
with gr.Tab("Voice Conversion"):
|
123 |
+
gr.Markdown("""
|
124 |
+
录制或上传声音,并选择要转换的音色。
|
125 |
+
""")
|
126 |
+
with gr.Column():
|
127 |
+
record_audio = gr.Audio(label="record your voice", source="microphone")
|
128 |
+
upload_audio = gr.Audio(label="or upload audio here", source="upload")
|
129 |
+
source_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="source speaker")
|
130 |
+
target_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="target speaker")
|
131 |
+
with gr.Column():
|
132 |
+
message_box = gr.Textbox(label="Message")
|
133 |
+
converted_audio = gr.Audio(label='converted audio')
|
134 |
+
btn = gr.Button("Convert!")
|
135 |
+
btn.click(vc_fn, inputs=[source_speaker, target_speaker, record_audio, upload_audio],
|
136 |
+
outputs=[message_box, converted_audio])
|
137 |
+
webbrowser.open("http://127.0.0.1:7860")
|
138 |
+
app.launch(share=args.share)
|
139 |
+
|
inference/finetune_speaker.json
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 100,
|
4 |
+
"eval_interval": 1000,
|
5 |
+
"seed": 1234,
|
6 |
+
"epochs": 10000,
|
7 |
+
"learning_rate": 0.0002,
|
8 |
+
"betas": [
|
9 |
+
0.8,
|
10 |
+
0.99
|
11 |
+
],
|
12 |
+
"eps": 1e-09,
|
13 |
+
"batch_size": 16,
|
14 |
+
"fp16_run": true,
|
15 |
+
"lr_decay": 0.999875,
|
16 |
+
"segment_size": 8192,
|
17 |
+
"init_lr_ratio": 1,
|
18 |
+
"warmup_epochs": 0,
|
19 |
+
"c_mel": 45,
|
20 |
+
"c_kl": 1.0
|
21 |
+
},
|
22 |
+
"data": {
|
23 |
+
"training_files": "final_annotation_train.txt",
|
24 |
+
"validation_files": "final_annotation_val.txt",
|
25 |
+
"text_cleaners": [
|
26 |
+
"zh_ja_mixture_cleaners"
|
27 |
+
],
|
28 |
+
"max_wav_value": 32768.0,
|
29 |
+
"sampling_rate": 22050,
|
30 |
+
"filter_length": 1024,
|
31 |
+
"hop_length": 256,
|
32 |
+
"win_length": 1024,
|
33 |
+
"n_mel_channels": 80,
|
34 |
+
"mel_fmin": 0.0,
|
35 |
+
"mel_fmax": null,
|
36 |
+
"add_blank": true,
|
37 |
+
"n_speakers": 3,
|
38 |
+
"cleaned_text": true
|
39 |
+
},
|
40 |
+
"model": {
|
41 |
+
"inter_channels": 192,
|
42 |
+
"hidden_channels": 192,
|
43 |
+
"filter_channels": 768,
|
44 |
+
"n_heads": 2,
|
45 |
+
"n_layers": 6,
|
46 |
+
"kernel_size": 3,
|
47 |
+
"p_dropout": 0.1,
|
48 |
+
"resblock": "1",
|
49 |
+
"resblock_kernel_sizes": [
|
50 |
+
3,
|
51 |
+
7,
|
52 |
+
11
|
53 |
+
],
|
54 |
+
"resblock_dilation_sizes": [
|
55 |
+
[
|
56 |
+
1,
|
57 |
+
3,
|
58 |
+
5
|
59 |
+
],
|
60 |
+
[
|
61 |
+
1,
|
62 |
+
3,
|
63 |
+
5
|
64 |
+
],
|
65 |
+
[
|
66 |
+
1,
|
67 |
+
3,
|
68 |
+
5
|
69 |
+
]
|
70 |
+
],
|
71 |
+
"upsample_rates": [
|
72 |
+
8,
|
73 |
+
8,
|
74 |
+
2,
|
75 |
+
2
|
76 |
+
],
|
77 |
+
"upsample_initial_channel": 512,
|
78 |
+
"upsample_kernel_sizes": [
|
79 |
+
16,
|
80 |
+
16,
|
81 |
+
4,
|
82 |
+
4
|
83 |
+
],
|
84 |
+
"n_layers_q": 3,
|
85 |
+
"use_spectral_norm": false,
|
86 |
+
"gin_channels": 256
|
87 |
+
},
|
88 |
+
"speakers": {
|
89 |
+
"Hana": 0,
|
90 |
+
"specialweek": 1,
|
91 |
+
"zhongli": 2
|
92 |
+
},
|
93 |
+
"symbols": [
|
94 |
+
"_",
|
95 |
+
",",
|
96 |
+
".",
|
97 |
+
"!",
|
98 |
+
"?",
|
99 |
+
"-",
|
100 |
+
"~",
|
101 |
+
"\u2026",
|
102 |
+
"A",
|
103 |
+
"E",
|
104 |
+
"I",
|
105 |
+
"N",
|
106 |
+
"O",
|
107 |
+
"Q",
|
108 |
+
"U",
|
109 |
+
"a",
|
110 |
+
"b",
|
111 |
+
"d",
|
112 |
+
"e",
|
113 |
+
"f",
|
114 |
+
"g",
|
115 |
+
"h",
|
116 |
+
"i",
|
117 |
+
"j",
|
118 |
+
"k",
|
119 |
+
"l",
|
120 |
+
"m",
|
121 |
+
"n",
|
122 |
+
"o",
|
123 |
+
"p",
|
124 |
+
"r",
|
125 |
+
"s",
|
126 |
+
"t",
|
127 |
+
"u",
|
128 |
+
"v",
|
129 |
+
"w",
|
130 |
+
"y",
|
131 |
+
"z",
|
132 |
+
"\u0283",
|
133 |
+
"\u02a7",
|
134 |
+
"\u02a6",
|
135 |
+
"\u026f",
|
136 |
+
"\u0279",
|
137 |
+
"\u0259",
|
138 |
+
"\u0265",
|
139 |
+
"\u207c",
|
140 |
+
"\u02b0",
|
141 |
+
"`",
|
142 |
+
"\u2192",
|
143 |
+
"\u2193",
|
144 |
+
"\u2191",
|
145 |
+
" "
|
146 |
+
]
|
147 |
+
}
|
monotonic_align/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from .monotonic_align.core import maximum_path_c
|
4 |
+
|
5 |
+
|
6 |
+
def maximum_path(neg_cent, mask):
|
7 |
+
""" Cython optimized version.
|
8 |
+
neg_cent: [b, t_t, t_s]
|
9 |
+
mask: [b, t_t, t_s]
|
10 |
+
"""
|
11 |
+
device = neg_cent.device
|
12 |
+
dtype = neg_cent.dtype
|
13 |
+
neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
|
14 |
+
path = np.zeros(neg_cent.shape, dtype=np.int32)
|
15 |
+
|
16 |
+
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
|
17 |
+
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
|
18 |
+
maximum_path_c(path, neg_cent, t_t_max, t_s_max)
|
19 |
+
return torch.from_numpy(path).to(device=device, dtype=dtype)
|
monotonic_align/core.pyx
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cimport cython
|
2 |
+
from cython.parallel import prange
|
3 |
+
|
4 |
+
|
5 |
+
@cython.boundscheck(False)
|
6 |
+
@cython.wraparound(False)
|
7 |
+
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
|
8 |
+
cdef int x
|
9 |
+
cdef int y
|
10 |
+
cdef float v_prev
|
11 |
+
cdef float v_cur
|
12 |
+
cdef float tmp
|
13 |
+
cdef int index = t_x - 1
|
14 |
+
|
15 |
+
for y in range(t_y):
|
16 |
+
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
17 |
+
if x == y:
|
18 |
+
v_cur = max_neg_val
|
19 |
+
else:
|
20 |
+
v_cur = value[y-1, x]
|
21 |
+
if x == 0:
|
22 |
+
if y == 0:
|
23 |
+
v_prev = 0.
|
24 |
+
else:
|
25 |
+
v_prev = max_neg_val
|
26 |
+
else:
|
27 |
+
v_prev = value[y-1, x-1]
|
28 |
+
value[y, x] += max(v_prev, v_cur)
|
29 |
+
|
30 |
+
for y in range(t_y - 1, -1, -1):
|
31 |
+
path[y, index] = 1
|
32 |
+
if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
|
33 |
+
index = index - 1
|
34 |
+
|
35 |
+
|
36 |
+
@cython.boundscheck(False)
|
37 |
+
@cython.wraparound(False)
|
38 |
+
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
|
39 |
+
cdef int b = paths.shape[0]
|
40 |
+
cdef int i
|
41 |
+
for i in prange(b, nogil=True):
|
42 |
+
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
|
monotonic_align/setup.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from distutils.core import setup
|
2 |
+
from Cython.Build import cythonize
|
3 |
+
import numpy
|
4 |
+
|
5 |
+
setup(
|
6 |
+
name = 'monotonic_align',
|
7 |
+
ext_modules = cythonize("core.pyx"),
|
8 |
+
include_dirs=[numpy.get_include()]
|
9 |
+
)
|