diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..66c2a4cafb77b81f9d8f7e65a485b841a6a347a9
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,76 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at mikelei@mobvoi.com. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/LICENSE b/LICENSE
index e4b1f2274866bc8aafc6a53ba8a2478419bcda8d..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,201 @@
-MIT License
-
-Copyright (c) 2024 FunAudioLLM
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 01bdafb5ce4555638738c18550638cf4782d557c..7c9fe7ca41c11c1940d3707fcaad7ad8c9cd3af0 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,145 @@
-# CosyVoice
\ No newline at end of file
+# CosyVoice
+
+For `CosyVoice`, visit the [CosyVoice repo](https://github.com/FunAudioLLM/CosyVoice) and [CosyVoice space](https://www.modelscope.cn/studios/iic/CosyVoice-300M).
+
+For `SenseVoice`, visit the [SenseVoice repo](https://github.com/FunAudioLLM/SenseVoice) and [SenseVoice space](https://www.modelscope.cn/studios/iic/SenseVoice).
+
+## Install
+
+**Clone and install**
+
+- Clone the repo
+``` sh
+git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
+# If you failed to clone the submodules due to network failures, please run the following command until it succeeds
+cd CosyVoice
+git submodule update --init --recursive
+```
+
+- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
+- Create Conda env:
+
+``` sh
+conda create -n cosyvoice python=3.8
+conda activate cosyvoice
+pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+
+# If you encounter sox compatibility issues
+# ubuntu
+sudo apt-get install sox libsox-dev
+# centos
+sudo yum install sox sox-devel
+```
+
+**Model download**
+
+We strongly recommend that you download our pretrained `CosyVoice-300M`, `CosyVoice-300M-SFT`, and `CosyVoice-300M-Instruct` models and the `speech_kantts_ttsfrd` resource.
+
+If you are an expert in this field and are only interested in training your own CosyVoice model from scratch, you can skip this step.
+
+``` python
+# Download the models via the ModelScope SDK
+from modelscope import snapshot_download
+snapshot_download('speech_tts/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
+snapshot_download('speech_tts/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
+snapshot_download('speech_tts/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
+snapshot_download('speech_tts/speech_kantts_ttsfrd', local_dir='pretrained_models/speech_kantts_ttsfrd')
+```
+
+``` sh
+# Download the models via git; make sure git lfs is installed
+mkdir -p pretrained_models
+git clone https://www.modelscope.cn/speech_tts/CosyVoice-300M.git pretrained_models/CosyVoice-300M
+git clone https://www.modelscope.cn/speech_tts/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
+git clone https://www.modelscope.cn/speech_tts/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
+git clone https://www.modelscope.cn/speech_tts/speech_kantts_ttsfrd.git pretrained_models/speech_kantts_ttsfrd
+```
+
+Unzip the `ttsfrd` resource and install the `ttsfrd` package
+``` sh
+cd pretrained_models/speech_kantts_ttsfrd/
+unzip resource.zip -d .
+pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
+```
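+
+To sanity-check the installation, you can try importing the frontend engine class that `cosyvoice/cli/frontend.py` relies on (a hypothetical quick check, assuming the wheel matches your Python 3.8 environment):
+
+``` sh
+# Should print the TtsFrontendEngine class without raising ImportError
+python3 -c "import ttsfrd; print(ttsfrd.TtsFrontendEngine)"
+```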
+
+**Basic Usage**
+
+For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
+For sft inference, please use `CosyVoice-300M-SFT` model.
+For instruct inference, please use `CosyVoice-300M-Instruct` model.
+First, add `third_party/AcademiCodec` and `third_party/Matcha-TTS` to your `PYTHONPATH`.
+
+``` sh
+export PYTHONPATH=third_party/AcademiCodec:third_party/Matcha-TTS
+```
+
+``` python
+from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.utils.file_utils import load_wav
+import torchaudio
+
+cosyvoice = CosyVoice('speech_tts/CosyVoice-300M-SFT')
+# sft usage
+print(cosyvoice.list_avaliable_spks())
+output = cosyvoice.inference_sft('你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?', '中文女')
+torchaudio.save('sft.wav', output['tts_speech'], 22050)
+
+cosyvoice = CosyVoice('speech_tts/CosyVoice-300M')
+# zero_shot usage
+prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
+output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
+torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
+# cross_lingual usage
+prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
+output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
+torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
+
+cosyvoice = CosyVoice('speech_tts/CosyVoice-300M-Instruct')
+# instruct usage
+output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的勇气与智慧。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
+torchaudio.save('instruct.wav', output['tts_speech'], 22050)
+```
+
+**Start web demo**
+
+You can use our web demo page to get familiar with CosyVoice quickly.
+We support sft/zero_shot/cross_lingual/instruct inference in the web demo.
+
+Please see the demo website for details.
+
+``` sh
+# change --model_dir to speech_tts/CosyVoice-300M-SFT for sft inference, or to speech_tts/CosyVoice-300M-Instruct for instruct inference
+python3 webui.py --port 50000 --model_dir speech_tts/CosyVoice-300M
+```
+
+**Advanced Usage**
+
+For advanced users, we have provided training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
+You can get familiar with CosyVoice by following this recipe; a sketch of a direct training launch is shown below.
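+
+The recipe's training stage ultimately drives `cosyvoice/bin/train.py`. Below is a minimal sketch of launching it directly for the `llm` module on a single GPU, assuming a torchrun-style launcher (since `train.py` initializes torch.distributed) and that the earlier recipe stages have already produced the parquet data lists and a `cosyvoice.yaml` config; the paths here are placeholders rather than the recipe's exact ones.
+
+``` sh
+# Minimal sketch: single-GPU LLM training (config and data paths are placeholders)
+export PYTHONPATH=third_party/AcademiCodec:third_party/Matcha-TTS
+torchrun --nnodes=1 --nproc_per_node=1 cosyvoice/bin/train.py \
+  --train_engine torch_ddp \
+  --config conf/cosyvoice.yaml \
+  --train_data data/train.data.list \
+  --cv_data data/dev.data.list \
+  --model llm \
+  --model_dir exp/cosyvoice/llm \
+  --tensorboard_dir tensorboard/cosyvoice/llm
+```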
+
+**Build for deployment**
+
+Optionally, if you want to use grpc for service deployment,
+you can run the following steps. Otherwise, you can skip this section.
+
+``` sh
+cd runtime/python
+docker build -t cosyvoice:v1.0 .
+# change speech_tts/CosyVoice-300M to speech_tts/CosyVoice-300M-Instruct if you want to use instruct inference
+docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python && python3 server.py --port 50000 --max_conc 4 --model_dir speech_tts/CosyVoice-300M && sleep infinity"
+python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+```
+
+## Discussion & Communication
+
+You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
+
+You can also scan the QR code to join our official Dingding chat group.
+
+![Dingding chat group QR code](./asset/dingding.png)
+
+## Acknowledgements
+
+1. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
+2. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
+3. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
\ No newline at end of file
diff --git a/asset/dingding.png b/asset/dingding.png
new file mode 100644
index 0000000000000000000000000000000000000000..9a644005c7b38fd64597c1eadfc6c708973e9a94
Binary files /dev/null and b/asset/dingding.png differ
diff --git a/cosyvoice/__init__.py b/cosyvoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b777fa1cba925f9786db60b7efa15dcd189adeb
--- /dev/null
+++ b/cosyvoice/bin/inference.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from cosyvoice.cli.model import CosyVoiceModel
+
+from cosyvoice.dataset.dataset import Dataset
+
+def get_args():
+ parser = argparse.ArgumentParser(description='inference with your model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
+    parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data mapping file')
+ parser.add_argument('--tts_text', required=True, help='tts input file')
+ parser.add_argument('--llm_model', required=True, help='llm model file')
+ parser.add_argument('--flow_model', required=True, help='flow model file')
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+ parser.add_argument('--gpu',
+ type=int,
+ default=-1,
+ help='gpu id for this rank, -1 for cpu')
+ parser.add_argument('--mode',
+ default='sft',
+ choices=['sft', 'zero_shot'],
+ help='inference mode')
+    parser.add_argument('--result_dir', required=True, help='directory to save tts results')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+ # Init cosyvoice models from configs
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+ device = torch.device('cuda' if use_cuda else 'cpu')
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f)
+
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+ del configs
+ os.makedirs(args.result_dir, exist_ok=True)
+ fn = os.path.join(args.result_dir, 'wav.scp')
+ f = open(fn, 'w')
+ with torch.no_grad():
+ for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+ utts = batch["utts"]
+            assert len(utts) == 1, "inference mode only supports batch size 1"
+ text = batch["text"]
+ text_token = batch["text_token"].to(device)
+ text_token_len = batch["text_token_len"].to(device)
+ tts_text = batch["tts_text"]
+ tts_index = batch["tts_index"]
+ tts_text_token = batch["tts_text_token"].to(device)
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
+ speech_token = batch["speech_token"].to(device)
+ speech_token_len = batch["speech_token_len"].to(device)
+ speech_feat = batch["speech_feat"].to(device)
+ speech_feat_len = batch["speech_feat_len"].to(device)
+ utt_embedding = batch["utt_embedding"].to(device)
+ spk_embedding = batch["spk_embedding"].to(device)
+ if args.mode == 'sft':
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+ else:
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+ model_output = model.inference(**model_input)
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+ torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+ f.write('{} {}\n'.format(tts_key, tts_fn))
+ f.flush()
+ f.close()
+ logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fdc4a285ff10416e6ff39d0b8d70339d7237d7a
--- /dev/null
+++ b/cosyvoice/bin/train.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import torch
+import torch.distributed as dist
+import deepspeed
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.executor import Executor
+from cosyvoice.utils.train_utils import (
+ init_distributed,
+ init_dataset_and_dataloader,
+ init_optimizer_and_scheduler,
+ init_summarywriter, save_model,
+ wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='training your network')
+ parser.add_argument('--train_engine',
+ default='torch_ddp',
+ choices=['torch_ddp', 'deepspeed'],
+                        help='Engine for parallel training')
+ parser.add_argument('--model', required=True, help='model which will be trained')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--train_data', required=True, help='train data file')
+ parser.add_argument('--cv_data', required=True, help='cv data file')
+ parser.add_argument('--checkpoint', help='checkpoint model')
+ parser.add_argument('--model_dir', required=True, help='save model dir')
+ parser.add_argument('--tensorboard_dir',
+ default='tensorboard',
+ help='tensorboard log dir')
+ parser.add_argument('--ddp.dist_backend',
+ dest='dist_backend',
+ default='nccl',
+ choices=['nccl', 'gloo'],
+ help='distributed backend')
+ parser.add_argument('--num_workers',
+ default=0,
+ type=int,
+ help='num of subprocess workers for reading')
+ parser.add_argument('--prefetch',
+ default=100,
+ type=int,
+ help='prefetch number')
+ parser.add_argument('--pin_memory',
+ action='store_true',
+ default=False,
+                        help='Use pinned memory buffers for reading')
+ parser.add_argument('--deepspeed.save_states',
+ dest='save_states',
+ default='model_only',
+ choices=['model_only', 'model+optimizer'],
+ help='save model/optimizer states')
+ parser.add_argument('--timeout',
+ default=30,
+ type=int,
+ help='timeout (in seconds) of cosyvoice_join. ' +
+ '30s for aishell & 300s for wenetspeech')
+ parser = deepspeed.add_config_arguments(parser)
+ args = parser.parse_args()
+ return args
+
+
+@record
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides=override_dict)
+ configs['train_conf'].update(vars(args))
+
+ # Init env for ddp
+ init_distributed(args)
+
+ # Get dataset & dataloader
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+ init_dataset_and_dataloader(args, configs)
+
+    # Do some sanity checks and save config to args.model_dir
+ configs = check_modify_and_save_config(args, configs)
+
+ # Tensorboard summary
+ writer = init_summarywriter(args)
+
+ # load checkpoint
+ model = configs[args.model]
+ if args.checkpoint is not None:
+ model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+
+ # Dispatch model from cpu to gpu
+ model = wrap_cuda_model(args, model)
+
+ # Get optimizer & scheduler
+ model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+
+ # Save init checkpoints
+ info_dict = deepcopy(configs['train_conf'])
+ save_model(model, 'init', info_dict)
+
+ # Get executor
+ executor = Executor()
+
+ # Start training loop
+ for epoch in range(info_dict['max_epoch']):
+ executor.epoch = epoch
+ train_dataset.set_epoch(epoch)
+ dist.barrier()
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+ dist.destroy_process_group(group_join)
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/cli/__init__.py b/cosyvoice/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8c4482891a62df6cbac39faa88972c81f5412f
--- /dev/null
+++ b/cosyvoice/cli/cosyvoice.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import torch
+from hyperpyyaml import load_hyperpyyaml
+from modelscope import snapshot_download
+from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+from cosyvoice.cli.model import CosyVoiceModel
+
+class CosyVoice:
+
+ def __init__(self, model_dir):
+        instruct = '-Instruct' in model_dir
+ self.model_dir = model_dir
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir)
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+ configs = load_hyperpyyaml(f)
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+ configs['feat_extractor'],
+ '{}/campplus.onnx'.format(model_dir),
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+ '{}/spk2info.pt'.format(model_dir),
+ instruct,
+ configs['allowed_special'])
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+ self.model.load('{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/hift.pt'.format(model_dir))
+ del configs
+
+ def list_avaliable_spks(self):
+ spks = list(self.frontend.spk2info.keys())
+ return spks
+
+ def inference_sft(self, tts_text, spk_id):
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_sft(i, spk_id)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+ if self.frontend.instruct is True:
+            raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+ def inference_instruct(self, tts_text, spk_id, instruct_text):
+ if self.frontend.instruct is False:
+            raise ValueError('{} does not support instruct inference'.format(self.model_dir))
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1066e8cb45ed56c8a6210404296475bd9db297e
--- /dev/null
+++ b/cosyvoice/cli/frontend.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import inflect
+import ttsfrd
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+
+
+class CosyVoiceFrontEnd:
+
+ def __init__(self,
+ get_tokenizer: Callable,
+ feat_extractor: Callable,
+ campplus_model: str,
+ speech_tokenizer_model: str,
+ spk2info: str = '',
+ instruct: bool = False,
+ allowed_special: str = 'all'):
+ self.tokenizer = get_tokenizer()
+ self.feat_extractor = feat_extractor
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"])
+ if os.path.exists(spk2info):
+ self.spk2info = torch.load(spk2info, map_location=self.device)
+ self.instruct = instruct
+ self.allowed_special = allowed_special
+ self.inflect_parser = inflect.engine()
+ self.frd = ttsfrd.TtsFrontendEngine()
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ assert self.frd.initialize('{}/../../pretrained_models/speech_kantts_ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
+ self.frd.set_lang_type('pinyin')
+ self.frd.enable_pinyin_mix(True)
+ self.frd.set_breakmodel_index(1)
+
+ def _extract_text_token(self, text):
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+ return text_token, text_token_len
+
+ def _extract_speech_token(self, speech):
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+ speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_token, speech_token_len
+
+ def _extract_spk_embedding(self, speech):
+ feat = kaldi.fbank(speech,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=16000)
+ feat = feat - feat.mean(dim=0, keepdim=True)
+ embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+ embedding = torch.tensor([embedding]).to(self.device)
+ return embedding
+
+ def _extract_speech_feat(self, speech):
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+ speech_feat = speech_feat.unsqueeze(dim=0)
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_feat, speech_feat_len
+
+ def text_normalize(self, text, split=True):
+ text = text.strip()
+ if contains_chinese(text):
+ text = self.frd.get_frd_extra_info(text, 'input').replace("\n", "")
+ text = replace_blank(text)
+ text = replace_corner_mark(text)
+ text = text.replace(".", "、")
+ text = text.replace(" - ", ",")
+ text = remove_bracket(text)
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+ token_min_n=60, merge_len=20,
+ comma_split=False)]
+ else:
+ text = spell_out_number(text, self.inflect_parser)
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+ token_min_n=60, merge_len=20,
+ comma_split=False)]
+ if split is False:
+ return text
+ return texts
+
+ def frontend_sft(self, tts_text, spk_id):
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ embedding = self.spk2info[spk_id]['embedding']
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+ # in cross lingual mode, we remove prompt in llm
+ del model_input['prompt_text']
+ del model_input['prompt_text_len']
+ del model_input['llm_prompt_speech_token']
+ del model_input['llm_prompt_speech_token_len']
+ return model_input
+
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
+ model_input = self.frontend_sft(tts_text, spk_id)
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
+ del model_input['llm_embedding']
+        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+ model_input['prompt_text'] = instruct_text_token
+ model_input['prompt_text_len'] = instruct_text_token_len
+ return model_input
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..98f19b2659fc5680e18e57c66aef3a78dc5de5ed
--- /dev/null
+++ b/cosyvoice/cli/model.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+class CosyVoiceModel:
+
+ def __init__(self,
+ llm: torch.nn.Module,
+ flow: torch.nn.Module,
+ hift: torch.nn.Module):
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.llm = llm
+ self.flow = flow
+ self.hift = hift
+
+ def load(self, llm_model, flow_model, hift_model):
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+ self.llm.to(self.device).eval()
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+ self.flow.to(self.device).eval()
+ self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
+ self.hift.to(self.device).eval()
+
+ def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
+ tts_speech_token = self.llm.inference(text=text.to(self.device),
+ text_len=text_len.to(self.device),
+ prompt_text=prompt_text.to(self.device),
+ prompt_text_len=prompt_text_len.to(self.device),
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
+ prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+ embedding=llm_embedding.to(self.device),
+ beam_size=1,
+ sampling=25,
+ max_token_text_ratio=30,
+ min_token_text_ratio=3)
+ tts_mel = self.flow.inference(token=tts_speech_token,
+ token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+ prompt_token=flow_prompt_speech_token.to(self.device),
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+ prompt_feat=prompt_speech_feat.to(self.device),
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
+ embedding=flow_embedding.to(self.device))
+ tts_speech = self.hift.inference(mel=tts_mel).cpu()
+ return {'tts_speech': tts_speech}
diff --git a/cosyvoice/dataset/__init__.py b/cosyvoice/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/dataset/dataset.py b/cosyvoice/dataset/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..431fae124debeabfbb7c7742317bddcf7984e91e
--- /dev/null
+++ b/cosyvoice/dataset/dataset.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import json
+import math
+from functools import partial
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import IterableDataset
+from cosyvoice.utils.file_utils import read_lists, read_json_lists
+
+
+class Processor(IterableDataset):
+
+ def __init__(self, source, f, *args, **kw):
+ assert callable(f)
+ self.source = source
+ self.f = f
+ self.args = args
+ self.kw = kw
+
+ def set_epoch(self, epoch):
+ self.source.set_epoch(epoch)
+
+ def __iter__(self):
+ """ Return an iterator over the source dataset processed by the
+ given processor.
+ """
+ assert self.source is not None
+ assert callable(self.f)
+ return self.f(iter(self.source), *self.args, **self.kw)
+
+ def apply(self, f):
+ assert callable(f)
+ return Processor(self, f, *self.args, **self.kw)
+
+
+class DistributedSampler:
+
+ def __init__(self, shuffle=True, partition=True):
+ self.epoch = -1
+ self.update()
+ self.shuffle = shuffle
+ self.partition = partition
+
+ def update(self):
+ assert dist.is_available()
+ if dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ else:
+ self.rank = 0
+ self.world_size = 1
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is None:
+ self.worker_id = 0
+ self.num_workers = 1
+ else:
+ self.worker_id = worker_info.id
+ self.num_workers = worker_info.num_workers
+ return dict(rank=self.rank,
+ world_size=self.world_size,
+ worker_id=self.worker_id,
+ num_workers=self.num_workers)
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def sample(self, data):
+ """ Sample data according to rank/world_size/num_workers
+
+ Args:
+ data(List): input data list
+
+ Returns:
+ List: data list after sample
+ """
+ data = list(range(len(data)))
+ # force datalist even
+ if self.partition:
+ if self.shuffle:
+ random.Random(self.epoch).shuffle(data)
+ if len(data) < self.world_size:
+ data = data * math.ceil(self.world_size / len(data))
+ data = data[:self.world_size]
+ data = data[self.rank::self.world_size]
+ if len(data) < self.num_workers:
+ data = data * math.ceil(self.num_workers / len(data))
+ data = data[:self.num_workers]
+ data = data[self.worker_id::self.num_workers]
+ return data
+
+
+class DataList(IterableDataset):
+
+ def __init__(self, lists, shuffle=True, partition=True):
+ self.lists = lists
+ self.sampler = DistributedSampler(shuffle, partition)
+
+ def set_epoch(self, epoch):
+ self.sampler.set_epoch(epoch)
+
+ def __iter__(self):
+ sampler_info = self.sampler.update()
+ indexes = self.sampler.sample(self.lists)
+ for index in indexes:
+ data = dict(src=self.lists[index])
+ data.update(sampler_info)
+ yield data
+
+
+def Dataset(data_list_file,
+ data_pipeline,
+ mode='train',
+ shuffle=True,
+ partition=True,
+ tts_file='',
+ prompt_utt2data=''):
+ """ Construct dataset from arguments
+
+        We have two shuffle stages in the Dataset. The first is a global
+        shuffle at the shards tar/raw file level. The second is a buffered
+        shuffle at the training samples level.
+
+        Args:
+            data_list_file(str): file listing the parquet data shards
+            data_pipeline(list): processing functions applied to each sample
+            mode(str): train/inference
+            shuffle(bool): whether to shuffle the data list
+            partition(bool): whether to do data partition in terms of rank
+ """
+ assert mode in ['train', 'inference']
+ lists = read_lists(data_list_file)
+ if mode == 'inference':
+ with open(tts_file) as f:
+ tts_data = json.load(f)
+ utt2lists = read_json_lists(prompt_utt2data)
+ # filter unnecessary file in inference mode
+ lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
+ dataset = DataList(lists,
+ shuffle=shuffle,
+ partition=partition)
+ if mode == 'inference':
+ # map partial arg tts_data in inference mode
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+ for func in data_pipeline:
+ dataset = Processor(dataset, func, mode=mode)
+ return dataset
diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb99f3cfcac2c8e00c57e2ccda505e86d7802167
--- /dev/null
+++ b/cosyvoice/dataset/processor.py
@@ -0,0 +1,366 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import random
+
+import pyarrow.parquet as pq
+from io import BytesIO
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+import torch.nn.functional as F
+
+torchaudio.set_audio_backend('soundfile')
+torchaudio.utils.sox_utils.set_buffer_size(16500)
+
+AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
+
+
+def parquet_opener(data, mode='train', tts_data={}):
+ """ Give url or local file, return file descriptor
+ Inplace operation.
+
+ Args:
+ data(Iterable[str]): url or local file list
+
+ Returns:
+ Iterable[{src, stream}]
+ """
+ for sample in data:
+ assert 'src' in sample
+ url = sample['src']
+ try:
+ df = pq.read_table(url).to_pandas()
+ for i in range(len(df)):
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
+ continue
+ sample.update(dict(df.loc[i]))
+ if mode == 'train':
+ # NOTE do not return sample directly, must initialize a new dict
+ yield {**sample}
+ else:
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+ yield {**sample, 'tts_index': index, 'tts_text': text}
+ except Exception as ex:
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
+
+def filter(data,
+ max_length=10240,
+ min_length=10,
+ token_max_length=200,
+ token_min_length=1,
+ min_output_input_ratio=0.0005,
+ max_output_input_ratio=1,
+ mode='train'):
+ """ Filter sample according to feature and label length
+ Inplace operation.
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+            max_length: drop utterance which is greater than max_length(10ms)
+            min_length: drop utterance which is less than min_length(10ms)
+            token_max_length: drop utterance which is greater than
+                token_max_length, especially when using char units for
+                English modeling
+            token_min_length: drop utterance which is
+                less than token_min_length
+            min_output_input_ratio: minimal ratio of
+                token_length / feats_length(10ms)
+            max_output_input_ratio: maximum ratio of
+                token_length / feats_length(10ms)
+
+ Returns:
+ Iterable[{key, wav, label, sample_rate}]
+ """
+ for sample in data:
+ sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+ del sample['audio_data']
+        # sample['speech'] is a torch.Tensor; we have 100 frames every second
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
+ if num_frames < min_length:
+ continue
+ if num_frames > max_length:
+ continue
+ if len(sample['text_token']) < token_min_length:
+ continue
+ if len(sample['text_token']) > token_max_length:
+ continue
+ if len(sample['speech_token']) == 0:
+ continue
+ if num_frames != 0:
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
+ continue
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
+ continue
+ yield sample
+
+
+def resample(data, resample_rate=22050, mode='train'):
+ """ Resample data.
+ Inplace operation.
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+ resample_rate: target resample rate
+
+ Returns:
+ Iterable[{key, wav, label, sample_rate}]
+ """
+ for sample in data:
+ assert 'sample_rate' in sample
+ assert 'speech' in sample
+ sample_rate = sample['sample_rate']
+ waveform = sample['speech']
+ if sample_rate != resample_rate:
+ if sample_rate < resample_rate:
+ continue
+ sample['sample_rate'] = resample_rate
+ sample['speech'] = torchaudio.transforms.Resample(
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
+ max_val = sample['speech'].abs().max()
+ if max_val > 1:
+ sample['speech'] /= max_val
+ yield sample
+
+
+def compute_fbank(data,
+ feat_extractor,
+ mode='train'):
+ """ Extract fbank
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+ for sample in data:
+ assert 'sample_rate' in sample
+ assert 'speech' in sample
+ assert 'utt' in sample
+ assert 'text_token' in sample
+ waveform = sample['speech']
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+ sample['speech_feat'] = mat
+ del sample['speech']
+ yield sample
+
+
+def parse_embedding(data, normalize, mode='train'):
+ """ Parse utt_embedding/spk_embedding
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+ for sample in data:
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
+ sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
+ if normalize:
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
+ yield sample
+
+
+def tokenize(data, get_tokenizer, allowed_special, mode='train'):
+ """ Decode text to chars or BPE
+ Inplace operation
+
+ Args:
+ data: Iterable[{key, wav, txt, sample_rate}]
+
+ Returns:
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
+ """
+ tokenizer = get_tokenizer()
+ for sample in data:
+ assert 'text' in sample
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
+ if mode == 'inference':
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
+ yield sample
+
+
+def shuffle(data, shuffle_size=10000, mode='train'):
+ """ Local shuffle the data
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ shuffle_size: buffer size for shuffle
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+ buf = []
+ for sample in data:
+ buf.append(sample)
+ if len(buf) >= shuffle_size:
+ random.shuffle(buf)
+ for x in buf:
+ yield x
+ buf = []
+ # The samples left over
+ random.shuffle(buf)
+ for x in buf:
+ yield x
+
+
+def sort(data, sort_size=500, mode='train'):
+ """ Sort the data by feature length.
+ Sort is used after shuffle and before batch, so we can group
+ utts with similar lengths into a batch, and `sort_size` should
+ be less than `shuffle_size`
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ sort_size: buffer size for sort
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+
+ buf = []
+ for sample in data:
+ buf.append(sample)
+ if len(buf) >= sort_size:
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
+ for x in buf:
+ yield x
+ buf = []
+ # The samples left over
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
+ for x in buf:
+ yield x
+
+
+def static_batch(data, batch_size=16):
+ """ Static batch the data by `batch_size`
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ batch_size: batch size
+
+ Returns:
+ Iterable[List[{key, feat, label}]]
+ """
+ buf = []
+ for sample in data:
+ buf.append(sample)
+ if len(buf) >= batch_size:
+ yield buf
+ buf = []
+ if len(buf) > 0:
+ yield buf
+
+
+def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
+ """ Dynamic batch the data until the total frames in batch
+ reach `max_frames_in_batch`
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ max_frames_in_batch: max_frames in one batch
+
+ Returns:
+ Iterable[List[{key, feat, label}]]
+ """
+ buf = []
+ longest_frames = 0
+ for sample in data:
+ assert 'speech_feat' in sample
+ assert isinstance(sample['speech_feat'], torch.Tensor)
+ new_sample_frames = sample['speech_feat'].size(0)
+ longest_frames = max(longest_frames, new_sample_frames)
+ frames_after_padding = longest_frames * (len(buf) + 1)
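+ # NOTE: the batch will be padded to its longest utterance, so adding this sample
+ # costs longest_frames * (len(buf) + 1) frames; the buffer is flushed once that
+ # padded total would exceed max_frames_in_batch.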
+ if frames_after_padding > max_frames_in_batch:
+ yield buf
+ buf = [sample]
+ longest_frames = new_sample_frames
+ else:
+ buf.append(sample)
+ if len(buf) > 0:
+ yield buf
+
+
+def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
+ """ Wrapper for static/dynamic batch
+ """
+ if mode == 'inference':
+ return static_batch(data, 1)
+ else:
+ if batch_type == 'static':
+ return static_batch(data, batch_size)
+ elif batch_type == 'dynamic':
+ return dynamic_batch(data, max_frames_in_batch)
+ else:
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
+
+
+def padding(data, mode='train'):
+ """ Padding the data into training data
+
+ Args:
+ data: Iterable[List[{key, feat, label}]]
+
+ Returns:
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
+ """
+ for sample in data:
+ assert isinstance(sample, list)
+ speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
+ dtype=torch.int32)
+ order = torch.argsort(speech_feat_len, descending=True)
+
+ utts = [sample[i]['utt'] for i in order]
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
+ speech_token = pad_sequence(speech_token,
+ batch_first=True,
+ padding_value=0)
+ speech_feat = [sample[i]['speech_feat'] for i in order]
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
+ speech_feat = pad_sequence(speech_feat,
+ batch_first=True,
+ padding_value=0)
+ text = [sample[i]['text'] for i in order]
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
+ batch = {
+ "utts": utts,
+ "speech_token": speech_token,
+ "speech_token_len": speech_token_len,
+ "speech_feat": speech_feat,
+ "speech_feat_len": speech_feat_len,
+ "text": text,
+ "text_token": text_token,
+ "text_token_len": text_token_len,
+ "utt_embedding": utt_embedding,
+ "spk_embedding": spk_embedding,
+ }
+ if mode == 'inference':
+ tts_text = [sample[i]['tts_text'] for i in order]
+ tts_index = [sample[i]['tts_index'] for i in order]
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
+ batch.update({'tts_text': tts_text,
+ 'tts_index': tts_index,
+ 'tts_text_token': tts_text_token,
+ 'tts_text_token_len': tts_text_token_len})
+ yield batch
diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..43492799390b44a2843bc53604603842754799f9
--- /dev/null
+++ b/cosyvoice/flow/decoder.py
@@ -0,0 +1,222 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+from einops import pack, rearrange, repeat
+from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
+from matcha.models.components.transformer import BasicTransformerBlock
+
+
+class ConditionalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ channels=(256, 256),
+ dropout=0.05,
+ attention_head_dim=64,
+ n_blocks=1,
+ num_mid_blocks=2,
+ num_heads=4,
+ act_fn="snake",
+ ):
+ """
+ This decoder requires an input with the same shape of the target. So, if your text content
+ is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+ """
+ super().__init__()
+ channels = tuple(channels)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
+ time_embed_dim = channels[0] * 4
+ self.time_mlp = TimestepEmbedding(
+ in_channels=in_channels,
+ time_embed_dim=time_embed_dim,
+ act_fn="silu",
+ )
+ self.down_blocks = nn.ModuleList([])
+ self.mid_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ output_channel = in_channels
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_last = i == len(channels) - 1
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ downsample = (
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+ for i in range(num_mid_blocks):
+ input_channel = channels[-1]
+ output_channel = channels[-1]
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+ channels = channels[::-1] + (channels[0],)
+ for i in range(len(channels) - 1):
+ input_channel = channels[i] * 2
+ output_channel = channels[i + 1]
+ is_last = i == len(channels) - 2
+ resnet = ResnetBlock1D(
+ dim=input_channel,
+ dim_out=output_channel,
+ time_emb_dim=time_embed_dim,
+ )
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ upsample = (
+ Upsample1D(output_channel, use_conv_transpose=True)
+ if not is_last
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+ self.final_block = Block1D(channels[-1], channels[-1])
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
+ """Forward pass of the UNet1DConditional model.
+
+ Args:
+ x (torch.Tensor): shape (batch_size, in_channels, time)
+ mask (torch.Tensor): shape (batch_size, 1, time)
+ mu (torch.Tensor): encoder output, shape (batch_size, n_feats, time)
+ t (torch.Tensor): shape (batch_size)
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
+ cond (torch.Tensor, optional): conditioning features packed onto the input channels. Defaults to None.
+
+ Returns:
+ torch.Tensor: output of shape (batch_size, out_channels, time)
+ """
+
+ t = self.time_embeddings(t)
+ t = self.time_mlp(t)
+
+ x = pack([x, mu], "b * t")[0]
+
+ if spks is not None:
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+ x = pack([x, spks], "b * t")[0]
+ if cond is not None:
+ x = pack([x, cond], "b * t")[0]
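+ # NOTE: einops.pack with pattern "b * t" concatenates along the channel axis, so x now
+ # stacks [noisy sample, encoder output mu, broadcast speaker embedding, cond] on dim 1;
+ # this combined channel count must match the in_channels the decoder was built with.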
+
+ hiddens = []
+ masks = [mask]
+ for resnet, transformer_blocks, downsample in self.down_blocks:
+ mask_down = masks[-1]
+ x = resnet(x, mask_down, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ hiddens.append(x) # Save hidden states for skip connections
+ x = downsample(x * mask_down)
+ masks.append(mask_down[:, :, ::2])
+ masks = masks[:-1]
+ mask_mid = masks[-1]
+
+ for resnet, transformer_blocks in self.mid_blocks:
+ x = resnet(x, mask_mid, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+
+ for resnet, transformer_blocks, upsample in self.up_blocks:
+ mask_up = masks.pop()
+ skip = hiddens.pop()
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+ x = resnet(x, mask_up, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ x = upsample(x * mask_up)
+ x = self.final_block(x, mask_up)
+ output = self.final_proj(x * mask_up)
+ return output * mask
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0dbcd0477784deb0ab812f95626eb635049a8ab
--- /dev/null
+++ b/cosyvoice/flow/flow.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from omegaconf import DictConfig
+from cosyvoice.utils.mask import make_pad_mask
+
+
+class MaskedDiffWithXvec(torch.nn.Module):
+ def __init__(self,
+ input_size: int = 512,
+ output_size: int = 80,
+ spk_embed_dim: int = 192,
+ output_type: str = "mel",
+ vocab_size: int = 4096,
+ input_frame_rate: int = 50,
+ only_mask_loss: bool = True,
+ encoder: torch.nn.Module = None,
+ length_regulator: torch.nn.Module = None,
+ decoder: torch.nn.Module = None,
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.decoder_conf = decoder_conf
+ self.mel_feat_conf = mel_feat_conf
+ self.vocab_size = vocab_size
+ self.output_type = output_type
+ self.input_frame_rate = input_frame_rate
+ logging.info(f"input frame rate={self.input_frame_rate}")
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+ self.encoder = encoder
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+ self.decoder = decoder
+ self.length_regulator = length_regulator
+ self.only_mask_loss = only_mask_loss
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ token = batch['speech_token'].to(device)
+ token_len = batch['speech_token_len'].to(device)
+ feat = batch['speech_feat'].to(device)
+ feat_len = batch['speech_feat_len'].to(device)
+ embedding = batch['utt_embedding'].to(device)
+
+ # xvec projection
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+
+ # concat text and prompt_text
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+ # text encode
+ h, h_lengths = self.encoder(token, token_len)
+ h = self.encoder_proj(h)
+ h, h_lengths = self.length_regulator(h, feat_len)
+
+ # get conditions
+ conds = torch.zeros(feat.shape, device=token.device)
+ conds = conds.transpose(1, 2)
+
+ mask = (~make_pad_mask(feat_len)).to(h)
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
+ loss, _ = self.decoder.compute_loss(
+ feat.transpose(1, 2).contiguous(),
+ mask.unsqueeze(1),
+ h.transpose(1, 2).contiguous(),
+ embedding,
+ cond=conds
+ )
+ return {'loss': loss}
+
+ @torch.inference_mode()
+ def inference(self,
+ token,
+ token_len,
+ prompt_token,
+ prompt_token_len,
+ prompt_feat,
+ prompt_feat_len,
+ embedding):
+ assert token.shape[0] == 1
+ # xvec projection
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+
+ # concat text and prompt_text
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+ # text encode
+ h, h_lengths = self.encoder(token, token_len)
+ h = self.encoder_proj(h)
+ feat_len = (token_len / 50 * 22050 / 256).int()
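+ # NOTE: speech tokens arrive at 50 tokens/s (input_frame_rate), while mel frames run at
+ # sampling_rate / hop_size = 22050 / 256 ≈ 86 frames/s; this converts token lengths into
+ # the corresponding number of mel frames before length regulation.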
+ h, h_lengths = self.length_regulator(h, feat_len)
+
+ # get conditions
+ conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
+ if prompt_feat.shape[1] != 0:
+ for i, j in enumerate(prompt_feat_len):
+ conds[i, :j] = prompt_feat[i]
+ conds = conds.transpose(1, 2)
+
+ mask = (~make_pad_mask(feat_len)).to(h)
+ feat = self.decoder(
+ mu=h.transpose(1, 2).contiguous(),
+ mask=mask.unsqueeze(1),
+ spks=embedding,
+ cond=conds,
+ n_timesteps=10
+ )
+ if prompt_feat.shape[1] != 0:
+ feat = feat[:, :, prompt_feat.shape[1]:]
+ return feat
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
new file mode 100755
index 0000000000000000000000000000000000000000..6a8985ff34dff289c4193c1778db8286a5a6928f
--- /dev/null
+++ b/cosyvoice/flow/flow_matching.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+from matcha.models.components.flow_matching import BASECFM
+
+class ConditionalCFM(BASECFM):
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+ super().__init__(
+ n_feats=in_channels,
+ cfm_params=cfm_params,
+ n_spks=n_spks,
+ spk_emb_dim=spk_emb_dim,
+ )
+ self.t_scheduler = cfm_params.t_scheduler
+ self.training_cfg_rate = cfm_params.training_cfg_rate
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
+ # Just change the architecture of the estimator here
+ self.estimator = estimator
+
+ @torch.inference_mode()
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+ """Forward diffusion
+
+ Args:
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ n_timesteps (int): number of diffusion steps
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: conditioning features forwarded to the estimator (e.g. prompt mel features)
+
+ Returns:
+ sample: generated mel-spectrogram
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ z = torch.randn_like(mu) * temperature
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+ if self.t_scheduler == 'cosine':
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
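+ # NOTE: 1 - cos(pi * t / 2) has near-zero slope at t=0 and slope pi/2 at t=1, so the
+ # cosine scheduler takes smaller integration steps early and larger steps towards t=1.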
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
+ """
+ Fixed-step Euler solver for the flow ODE.
+ Args:
+ x (torch.Tensor): random noise
+ t_span (torch.Tensor): n_timesteps interpolated
+ shape: (n_timesteps + 1,)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: conditioning features forwarded to the estimator (e.g. prompt mel features)
+ """
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+ # Store intermediate solutions so they can be inspected (e.g. from a debugger);
+ # a return_all_steps flag could expose them in the future
+ sol = []
+
+ for step in range(1, len(t_span)):
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
+ # Classifier-Free Guidance inference introduced in VoiceBox
+ if self.inference_cfg_rate > 0:
+ cfg_dphi_dt = self.estimator(
+ x, mask,
+ torch.zeros_like(mu), t,
+ torch.zeros_like(spks) if spks is not None else None,
+ torch.zeros_like(cond)
+ )
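+ # NOTE: classifier-free guidance combines the two estimates below as
+ # (1 + rate) * conditional - rate * unconditional, extrapolating away from the
+ # unconditional velocity field by inference_cfg_rate.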
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
+ self.inference_cfg_rate * cfg_dphi_dt)
+ x = x + dt * dphi_dt
+ t = t + dt
+ sol.append(x)
+ if step < len(t_span) - 1:
+ dt = t_span[step + 1] - t
+
+ return sol[-1]
+
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+ """Computes diffusion loss
+
+ Args:
+ x1 (torch.Tensor): Target
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): target mask
+ shape: (batch_size, 1, mel_timesteps)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+
+ Returns:
+ loss: conditional flow matching loss
+ y: conditional flow
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ b, _, t = mu.shape
+
+ # random timestep
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+ if self.t_scheduler == 'cosine':
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
+ # sample noise p(x_0)
+ z = torch.randn_like(x1)
+
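+ # NOTE: optimal-transport conditional flow matching (as in Matcha-TTS):
+ # y_t = (1 - (1 - sigma_min) * t) * z + t * x1 is the conditional path from noise to
+ # target, and u = x1 - (1 - sigma_min) * z is its constant velocity d y_t / d t,
+ # which the estimator is trained to predict.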
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+ u = x1 - (1 - self.sigma_min) * z
+
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
+ return loss, y
diff --git a/cosyvoice/flow/length_regulator.py b/cosyvoice/flow/length_regulator.py
new file mode 100755
index 0000000000000000000000000000000000000000..622f29aaccc44d8e8cce23ecab7b086ebb853fde
--- /dev/null
+++ b/cosyvoice/flow/length_regulator.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+import torch.nn as nn
+from torch.nn import functional as F
+from cosyvoice.utils.mask import make_pad_mask
+
+
+class InterpolateRegulator(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ sampling_ratios: Tuple,
+ out_channels: int = None,
+ groups: int = 1,
+ ):
+ super().__init__()
+ self.sampling_ratios = sampling_ratios
+ out_channels = out_channels or channels
+ model = nn.ModuleList([])
+ if len(sampling_ratios) > 0:
+ for _ in sampling_ratios:
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
+ norm = nn.GroupNorm(groups, channels)
+ act = nn.Mish()
+ model.extend([module, norm, act])
+ model.append(
+ nn.Conv1d(channels, out_channels, 1, 1)
+ )
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x, ylens=None):
+ # x in (B, T, D)
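+ # NOTE: nearest-neighbour interpolation stretches or compresses the encoder output along
+ # time to the target mel length ylens.max(); the conv stack then smooths the result and
+ # padded positions are zeroed out by the mask.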
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+ out = self.model(x).transpose(1, 2).contiguous()
+ olens = ylens
+ return out * mask, olens
diff --git a/cosyvoice/hifigan/f0_predictor.py b/cosyvoice/hifigan/f0_predictor.py
new file mode 100755
index 0000000000000000000000000000000000000000..36b85f4ed90c3a412cb179f49ccb471132a86550
--- /dev/null
+++ b/cosyvoice/hifigan/f0_predictor.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+
+
+class ConvRNNF0Predictor(nn.Module):
+ def __init__(self,
+ num_class: int = 1,
+ in_channels: int = 80,
+ cond_channels: int = 512
+ ):
+ super().__init__()
+
+ self.num_class = num_class
+ self.condnet = nn.Sequential(
+ weight_norm(
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ )
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.condnet(x)
+ x = x.transpose(1, 2)
+ return torch.abs(self.classifier(x).squeeze(-1))
diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa8c7ee2669b6e76e6ea4f3b531946f45019fba7
--- /dev/null
+++ b/cosyvoice/hifigan/generator.py
@@ -0,0 +1,391 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""HIFI-GAN"""
+
+import typing as tp
+import numpy as np
+from scipy.signal import get_window
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn import ConvTranspose1d
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils import weight_norm
+from torch.distributions.uniform import Uniform
+
+from cosyvoice.transformer.activation import Snake
+from academicodec.utils import get_padding
+from academicodec.utils import init_weights
+
+
+"""hifigan based generator implementation.
+
+This code is modified from https://github.com/jik876/hifi-gan
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
+ https://github.com/NVIDIA/BigVGAN
+
+"""
+class ResBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN/BigVGAN."""
+ def __init__(
+ self,
+ channels: int = 512,
+ kernel_size: int = 3,
+ dilations: tp.List[int] = [1, 3, 5],
+ ):
+ super(ResBlock, self).__init__()
+ self.convs1 = nn.ModuleList()
+ self.convs2 = nn.ModuleList()
+
+ for dilation in dilations:
+ self.convs1.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ padding=get_padding(kernel_size, dilation)
+ )
+ )
+ )
+ self.convs2.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1)
+ )
+ )
+ )
+ self.convs1.apply(init_weights)
+ self.convs2.apply(init_weights)
+ self.activations1 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs1))
+ ])
+ self.activations2 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs2))
+ ])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for idx in range(len(self.convs1)):
+ xt = self.activations1[idx](x)
+ xt = self.convs1[idx](xt)
+ xt = self.activations2[idx](xt)
+ xt = self.convs2[idx](xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for idx in range(len(self.convs1)):
+ remove_weight_norm(self.convs1[idx])
+ remove_weight_norm(self.convs2[idx])
+
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine waveform (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_threshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ @torch.no_grad()
+ def forward(self, f0):
+ """
+ :param f0: [B, 1, sample_len], Hz
+ :return: [B, 1, sample_len]
+ """
+
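+ # NOTE: F_mat holds per-harmonic instantaneous frequencies normalized by the sampling
+ # rate; the cumulative sum (mod 1) below integrates them into phase, and a random
+ # initial phase is added to every harmonic except the fundamental.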
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
+ for i in range(self.harmonic_num + 1):
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
+
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
+ u_dist = Uniform(low=-np.pi, high=np.pi)
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
+ phase_vec[:, 0, :] = 0
+
+ # generate sine waveforms
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
+
+ # generate uv signal
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length, 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length, 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
+ sine_wavs = sine_wavs.transpose(1, 2)
+ uv = uv.transpose(1, 2)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+class HiFTGenerator(nn.Module):
+ """
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
+ https://arxiv.org/abs/2309.09493
+ """
+ def __init__(
+ self,
+ in_channels: int = 80,
+ base_channels: int = 512,
+ nb_harmonics: int = 8,
+ sampling_rate: int = 22050,
+ nsf_alpha: float = 0.1,
+ nsf_sigma: float = 0.003,
+ nsf_voiced_threshold: float = 10,
+ upsample_rates: tp.List[int] = [8, 8],
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
+ lrelu_slope: float = 0.1,
+ audio_limit: float = 0.99,
+ f0_predictor: torch.nn.Module = None,
+ ):
+ super(HiFTGenerator, self).__init__()
+
+ self.out_channels = 1
+ self.nb_harmonics = nb_harmonics
+ self.sampling_rate = sampling_rate
+ self.istft_params = istft_params
+ self.lrelu_slope = lrelu_slope
+ self.audio_limit = audio_limit
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=sampling_rate,
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+ harmonic_num=nb_harmonics,
+ sine_amp=nsf_alpha,
+ add_noise_std=nsf_sigma,
+ voiced_threshod=nsf_voiced_threshold)
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
+
+ self.conv_pre = weight_norm(
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
+ )
+
+ # Up
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ base_channels // (2**i),
+ base_channels // (2**(i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ # Down
+ self.source_downs = nn.ModuleList()
+ self.source_resblocks = nn.ModuleList()
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
+ downsample_cum_rates = np.cumprod(downsample_rates)
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
+ source_resblock_dilation_sizes)):
+ if u == 1:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
+ )
+ else:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
+ )
+
+ self.source_resblocks.append(
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = base_channels // (2**(i + 1))
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(ResBlock(ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+ self.f0_predictor = f0_predictor
+
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+
+ har_source, _, _ = self.m_source(f0)
+ return har_source.transpose(1, 2)
+
+ def _stft(self, x):
+ spec = torch.stft(
+ x,
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
+ return_complex=True)
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
+ return spec[..., 0], spec[..., 1]
+
+ def _istft(self, magnitude, phase):
+ magnitude = torch.clip(magnitude, max=1e2)
+ real = magnitude * torch.cos(phase)
+ img = magnitude * torch.sin(phase)
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+ return inverse_transform
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ f0 = self.f0_predictor(x)
+ s = self._f02source(f0)
+
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = self.ups[i](x)
+
+ if i == self.num_upsamples - 1:
+ x = self.reflection_pad(x)
+
+ # fusion
+ si = self.source_downs[i](s_stft)
+ si = self.source_resblocks[i](si)
+ x = x + si
+
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
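+ # NOTE: conv_post emits n_fft + 2 channels, split below into n_fft // 2 + 1 log-magnitude
+ # bins (exponentiated) and n_fft // 2 + 1 phase bins that feed the iSTFT head.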
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # note: applying sin here is technically redundant
+
+ x = self._istft(magnitude, phase)
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+ # NOTE: self.m_source contains no weight-normalized layers, so there is nothing to remove for it
+ for l in self.source_downs:
+ remove_weight_norm(l)
+ for l in self.source_resblocks:
+ l.remove_weight_norm()
+
+ @torch.inference_mode()
+ def inference(self, mel: torch.Tensor) -> torch.Tensor:
+ return self.forward(x=mel)
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c22effc89765e11304eca5acbd87eb55f76634
--- /dev/null
+++ b/cosyvoice/llm/llm.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Union
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence, unpad_sequence
+from cosyvoice.utils.common import IGNORE_ID
+from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
+from cosyvoice.utils.common import th_accuracy
+
+
+class TransformerLM(torch.nn.Module):
+ def __init__(
+ self,
+ text_encoder_input_size: int,
+ llm_input_size: int,
+ llm_output_size: int,
+ text_token_size: int,
+ speech_token_size: int,
+ text_encoder: torch.nn.Module,
+ llm: torch.nn.Module,
+ length_normalized_loss: bool = True,
+ lsm_weight: float = 0.0,
+ spk_embed_dim: int = 192,
+ ):
+ super().__init__()
+ self.llm_input_size = llm_input_size
+ self.speech_token_size = speech_token_size
+ # 1. build text token inputs related modules
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
+ self.text_encoder = text_encoder
+ self.text_encoder_affine_layer = nn.Linear(
+ self.text_encoder.output_size(),
+ llm_input_size
+ )
+
+ # 2. build speech token language model related modules
+ self.sos_eos = 0
+ self.task_id = 1
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
+ self.llm = llm
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
+ self.criterion_ce = LabelSmoothingLoss(
+ size=speech_token_size + 1,
+ padding_idx=IGNORE_ID,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+
+ # 3. [Optional] build speech token related modules
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
+
+ def encode(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ):
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
+ return encoder_out, encoder_out_lens
+
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
+ return lm_input, lm_input_len
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Args:
+ text: (B, L, D)
+ text_lengths: (B,)
+ audio: (B, T, N) or (B, T)
+ audio_lengths: (B,)
+ """
+ text_token = batch['text_token'].to(device)
+ text_token_len = batch['text_token_len'].to(device)
+ speech_token = batch['speech_token'].to(device)
+ speech_token_len = batch['speech_token_len'].to(device)
+ embedding = batch['utt_embedding'].to(device)
+
+ # 1. prepare llm_target
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
+
+ # 1. encode text_token
+ text_token = self.text_embedding(text_token)
+ text_token, text_token_len = self.encode(text_token, text_token_len)
+
+ # 2. embedding projection
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+ embedding = embedding.unsqueeze(1)
+
+ # 3. eos and task_id
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+ # 4. encode speech_token
+ speech_token = self.speech_embedding(speech_token)
+
+ # 5. unpad and pad
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
+
+ # 6. run lm forward
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+ logits = self.llm_decoder(lm_output)
+ loss = self.criterion_ce(logits, lm_target)
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
+ return {'loss': loss, 'acc': acc}
+
+ def sampling_ids(
+ self,
+ weighted_scores: torch.Tensor,
+ sampling: Union[bool, int, float] = True,
+ beam_size: int = 1,
+ ignore_eos: bool = True,
+ ):
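+ # NOTE: top-k sampling sketch: keep the `sampling` most probable tokens, draw
+ # `beam_size` ids from them with replacement, and, while ignore_eos is True,
+ # resample whenever the EOS speech token (id == speech_token_size) is drawn.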
+ while True:
+ prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
+ top_ids = prob.multinomial(beam_size, replacement=True)
+ top_ids = indices[top_ids]
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
+ break
+ return top_ids
+
+ @torch.inference_mode()
+ def inference(
+ self,
+ text: torch.Tensor,
+ text_len: torch.Tensor,
+ prompt_text: torch.Tensor,
+ prompt_text_len: torch.Tensor,
+ prompt_speech_token: torch.Tensor,
+ prompt_speech_token_len: torch.Tensor,
+ embedding: torch.Tensor,
+ beam_size: int = 1,
+ sampling: int = 25,
+ max_token_text_ratio: float = 20,
+ min_token_text_ratio: float = 2,
+ ) -> torch.Tensor:
+ device = text.device
+ text = torch.concat([prompt_text, text], dim=1)
+ text_len += prompt_text_len
+ text = self.text_embedding(text)
+
+ # 1. encode text
+ text, text_len = self.encode(text, text_len)
+
+ # 2. encode embedding
+ if embedding.shape[0] != 0:
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+ embedding = embedding.unsqueeze(dim=1)
+ else:
+ embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
+
+ # 3. concat llm_input
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+ if prompt_speech_token_len != 0:
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+ else:
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+ # 4. cal min/max_length
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+ # 5. step by step decode
+ out_tokens = []
+ offset = 0
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
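+ # NOTE: autoregressive decoding: the attention cache accumulates keys/values for all
+ # tokens fed so far, so after the first step only the embedding of the newly sampled
+ # speech token is passed as lm_input.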
+ for i in range(max_len):
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+ if top_ids == self.speech_token_size:
+ break
+ out_tokens.append(top_ids)
+ offset += lm_input.size(1)
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+ return torch.tensor([out_tokens], dtype=torch.int64, device=device)
diff --git a/cosyvoice/transformer/__init__.py b/cosyvoice/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/transformer/activation.py b/cosyvoice/transformer/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cea54816385d3b6585ccc2417bc71630d578177
--- /dev/null
+++ b/cosyvoice/transformer/activation.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
+# 2020 Northwestern Polytechnical University (Pengcheng Guo)
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Swish() activation function for Conformer."""
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Swish(torch.nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return Swish activation function."""
+ return x * torch.sigmoid(x)
+
+
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake := x + 1/a * sin^2(a * x)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
diff --git a/cosyvoice/transformer/attention.py b/cosyvoice/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..78a33fe60111a6941dc419d17f5109fbd58e0851
--- /dev/null
+++ b/cosyvoice/transformer/attention.py
@@ -0,0 +1,326 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multi-Head Attention layer definition."""
+
+import math
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True):
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor, size
+ (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+ # 1st chunk to ease the onnx export.]
+ # 2. pytorch training
+ if mask.size(2) > 0: # time2 > 0
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ # For last chunk, time2 might be larger than scores.size(-1)
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -float('inf'))
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0) # (batch, head, time1, time2)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+ # 1. onnx(16/-1, -1/-1, 16/0)
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+ self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ 1.When applying cross attention between decoder and encoder,
+ the batch padding mask for input is in (#batch, 1, T) shape.
+ 2.When applying self attention of encoder,
+ the mask is in (#batch, T, T) shape.
+ 3.When applying self attention of decoder,
+ the mask is in (#batch, L, L) shape.
+ 4.If the different position in decoder see different block
+ of the encoder, such as Mocha, the passed in mask could be
+ in (#batch, L, T) shape. But there is no such case in current
+ Wenet.
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will always be `True`
+ # and we will always do splitting and
+ # concatenation (this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask), new_cache
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x):
+ """Compute relative positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ torch.Tensor: Output tensor.
+
+ """
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)[
+ :, :, :, : x.size(-1) // 2 + 1
+ ] # only keep the positions from 0 to time2
+ return x
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, time2, size).
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will always be `True`
+ # and we will always do splitting and
+ # concatenation (this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors (see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
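+ # Together, matrix_ac covers terms (a) + (c) and matrix_bd covers
+ # terms (b) + (d) of the attention score decomposition in
+ # Transformer-XL (https://arxiv.org/abs/1901.02860, Section 3.3).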
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
+ if matrix_ac.shape != matrix_bd.shape:
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask), new_cache
diff --git a/cosyvoice/transformer/convolution.py b/cosyvoice/transformer/convolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5d96149154776000991a681a666fbe55e562fe
--- /dev/null
+++ b/cosyvoice/transformer/convolution.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""ConvolutionModule definition."""
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model."""
+
+ def __init__(self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True):
+ """Construct an ConvolutionModule object.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (bool): Whether to use causal convolution or not
+ """
+ super().__init__()
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution,
+ # if self.lorder > 0: it's a causal convolution, the input will be
+ # padded with self.lorder frames on the left in forward.
+ # else: it's a symmetrical convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for non-causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+
+ assert norm in ['batch_norm', 'layer_norm']
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = nn.LayerNorm(channels)
+
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+ (0, 0, 0) means fake mask.
+ cache (torch.Tensor): left context cache, it is only
+ used in causal convolution (#batch, channels, cache_t),
+ (0, 0, 0) means fake cache.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2) # (#batch, channels, time)
+
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ if self.lorder > 0:
+ if cache.size(2) == 0: # cache_t == 0
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+ else:
+ assert cache.size(0) == x.size(0) # equal batch
+ assert cache.size(1) == x.size(1) # equal channel
+ x = torch.cat((cache, x), dim=2)
+ assert (x.size(2) > self.lorder)
+ new_cache = x[:, :, -self.lorder:]
+ else:
+ # It's better we just return None if no cache is required,
+ # However, for JIT export, here we just fake one tensor instead of
+ # None.
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.activation(self.norm(x))
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ return x.transpose(1, 2), new_cache
diff --git a/cosyvoice/transformer/decoder.py b/cosyvoice/transformer/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..961c875eab519f7a9e8a6e56720dc878b7852372
--- /dev/null
+++ b/cosyvoice/transformer/decoder.py
@@ -0,0 +1,396 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Decoder definition."""
+from typing import Tuple, List, Optional
+
+import torch
+import torch.utils.checkpoint as ckpt
+import logging
+
+from cosyvoice.transformer.decoder_layer import DecoderLayer
+from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from cosyvoice.utils.class_utils import (
+ COSYVOICE_EMB_CLASSES,
+ COSYVOICE_ATTENTION_CLASSES,
+ COSYVOICE_ACTIVATION_CLASSES,
+)
+from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
+
+
+class TransformerDecoder(torch.nn.Module):
+ """Base class of Transfomer decoder module.
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the hidden units number of position-wise feedforward
+ num_blocks: the number of decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before:
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ src_attention: if false, encoder-decoder cross attention is not
+ applied, such as CIF model
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ gradient_checkpointing: rerunning a forward-pass segment for each
+ checkpointed segment during backward.
+ tie_word_embedding: Tie or clone module weights depending on whether we are
+ using TorchScript or not
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ normalize_before: bool = True,
+ src_attention: bool = True,
+ key_bias: bool = True,
+ activation_type: str = "relu",
+ gradient_checkpointing: bool = False,
+ tie_word_embedding: bool = False,
+ ):
+ super().__init__()
+ attention_dim = encoder_output_size
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
+
+ self.embed = torch.nn.Sequential(
+ torch.nn.Identity() if input_layer == "no_pos" else
+ torch.nn.Embedding(vocab_size, attention_dim),
+ COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
+ positional_dropout_rate),
+ )
+
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
+ self.use_output_layer = use_output_layer
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = torch.nn.Identity()
+ self.num_blocks = num_blocks
+ self.decoders = torch.nn.ModuleList([
+ DecoderLayer(
+ attention_dim,
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
+ attention_heads, attention_dim,
+ self_attention_dropout_rate, key_bias),
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
+ attention_heads, attention_dim, src_attention_dropout_rate,
+ key_bias) if src_attention else None,
+ PositionwiseFeedForward(attention_dim, linear_units,
+ dropout_rate, activation),
+ dropout_rate,
+ normalize_before,
+ ) for _ in range(self.num_blocks)
+ ])
+
+ self.gradient_checkpointing = gradient_checkpointing
+ self.tie_word_embedding = tie_word_embedding
+
+ def forward(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
+ reverse_weight: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+ ys_in_lens: input lengths of this batch (batch)
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
+ with bidirectional decoder
+ reverse_weight: not used in transformer decoder, in order to unify
+ api with bidirectional decoder
+ Returns:
+ (tuple): tuple containing:
+ x: decoded token score before softmax (batch, maxlen_out,
+ vocab_size) if use_output_layer is True,
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
+ olens: (batch, )
+ NOTE(xcsong):
+ We pass the `__call__` method of the modules instead of `forward` to the
+ checkpointing API because `__call__` attaches all the hooks of the module.
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+ """
+ tgt = ys_in_pad
+ maxlen = tgt.size(1)
+ # tgt_mask: (B, 1, L)
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
+ tgt_mask = tgt_mask.to(tgt.device)
+ # m: (1, L, L)
+ m = subsequent_mask(tgt_mask.size(-1),
+ device=tgt_mask.device).unsqueeze(0)
+ # tgt_mask: (B, L, L)
+ tgt_mask = tgt_mask & m
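+ # e.g. subsequent_mask(3) is a lower-triangular boolean matrix: position 0
+ # sees {0}, position 1 sees {0, 1}, position 2 sees {0, 1, 2}. Combined
+ # with the padding mask above, each target token can only attend to
+ # itself and earlier non-padded tokens.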
+ x, _ = self.embed(tgt)
+ if self.gradient_checkpointing and self.training:
+ x = self.forward_layers_checkpointed(x, tgt_mask, memory,
+ memory_mask)
+ else:
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask)
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.use_output_layer:
+ x = self.output_layer(x)
+ olens = tgt_mask.sum(1)
+ return x, torch.tensor(0.0), olens
+
+ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor) -> torch.Tensor:
+ for layer in self.decoders:
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
+ memory_mask)
+ return x
+
+ @torch.jit.ignore(drop=True)
+ def forward_layers_checkpointed(self, x: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor) -> torch.Tensor:
+ for layer in self.decoders:
+ x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
+ layer.__call__, x, tgt_mask, memory, memory_mask)
+ return x
+
+ def forward_one_step(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+ This is only used for decoding.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ `y.shape` is (batch, token)
+ """
+ x, _ = self.embed(tgt)
+ new_cache = []
+ for i, decoder in enumerate(self.decoders):
+ if cache is None:
+ c = None
+ else:
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask = decoder(x,
+ tgt_mask,
+ memory,
+ memory_mask,
+ cache=c)
+ new_cache.append(x)
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.use_output_layer:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+ return y, new_cache
+
+ def tie_or_clone_weights(self, jit_mode: bool = True):
+ """Tie or clone module weights (between word_emb and output_layer)
+ depending on whether we are using TorchScript or not"""
+ if not self.use_output_layer:
+ return
+ if jit_mode:
+ logging.info("clone emb.weight to output.weight")
+ self.output_layer.weight = torch.nn.Parameter(
+ self.embed[0].weight.clone())
+ else:
+ logging.info("tie emb.weight with output.weight")
+ self.output_layer.weight = self.embed[0].weight
+
+ if getattr(self.output_layer, "bias", None) is not None:
+ self.output_layer.bias.data = torch.nn.functional.pad(
+ self.output_layer.bias.data,
+ (
+ 0,
+ self.output_layer.weight.shape[0] -
+ self.output_layer.bias.shape[0],
+ ),
+ "constant",
+ 0,
+ )
+
+
+class BiTransformerDecoder(torch.nn.Module):
+ """Base class of Transfomer decoder module.
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the hidden units number of position-wise feedforward
+ num_blocks: the number of decoder blocks
+ r_num_blocks: the number of right to left decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before:
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ r_num_blocks: int = 0,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ normalize_before: bool = True,
+ key_bias: bool = True,
+ gradient_checkpointing: bool = False,
+ tie_word_embedding: bool = False,
+ ):
+
+ super().__init__()
+ self.tie_word_embedding = tie_word_embedding
+ self.left_decoder = TransformerDecoder(
+ vocab_size,
+ encoder_output_size,
+ attention_heads,
+ linear_units,
+ num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ input_layer,
+ use_output_layer,
+ normalize_before,
+ key_bias=key_bias,
+ gradient_checkpointing=gradient_checkpointing,
+ tie_word_embedding=tie_word_embedding)
+
+ self.right_decoder = TransformerDecoder(
+ vocab_size,
+ encoder_output_size,
+ attention_heads,
+ linear_units,
+ r_num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ input_layer,
+ use_output_layer,
+ normalize_before,
+ key_bias=key_bias,
+ gradient_checkpointing=gradient_checkpointing,
+ tie_word_embedding=tie_word_embedding)
+
+ def forward(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ r_ys_in_pad: torch.Tensor,
+ reverse_weight: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+ ys_in_lens: input lengths of this batch (batch)
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
+ used for right to left decoder
+ reverse_weight: used for right to left decoder
+ Returns:
+ (tuple): tuple containing:
+ x: decoded token score before softmax (batch, maxlen_out,
+ vocab_size) if use_output_layer is True,
+ r_x: decoded token score (right to left decoder)
+ before softmax (batch, maxlen_out, vocab_size)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
+ ys_in_lens)
+ r_x = torch.tensor(0.0)
+ if reverse_weight > 0.0:
+ r_x, _, olens = self.right_decoder(memory, memory_mask,
+ r_ys_in_pad, ys_in_lens)
+ return l_x, r_x, olens
+
+ def forward_one_step(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+ This is only used for decoding.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ `y.shape` is (batch, token)
+ """
+ return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
+ tgt_mask, cache)
+
+ def tie_or_clone_weights(self, jit_mode: bool = True):
+ """Tie or clone module weights (between word_emb and output_layer)
+ depending on whether we are using TorchScript or not"""
+ self.left_decoder.tie_or_clone_weights(jit_mode)
+ self.right_decoder.tie_or_clone_weights(jit_mode)
diff --git a/cosyvoice/transformer/decoder_layer.py b/cosyvoice/transformer/decoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..91c7c5d7fb2a8e79cea7705646e5381016f73466
--- /dev/null
+++ b/cosyvoice/transformer/decoder_layer.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Decoder self-attention layer definition."""
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class DecoderLayer(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ src_attn (torch.nn.Module): Inter-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ If `None` is passed, Inter-attention is not used, such as
+ CIF, GPT, and other decoder only model.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: to use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: nn.Module,
+ src_attn: Optional[nn.Module],
+ feed_forward: nn.Module,
+ dropout_rate: float,
+ normalize_before: bool = True,
+ ):
+ """Construct an DecoderLayer object."""
+ super().__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ cache: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor
+ (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory
+ (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask
+ (#batch, maxlen_in).
+ cache (torch.Tensor): cached tensors.
+ (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+ # compute only the last frame query keeping dim: max_time_out -> 1
+ assert cache.shape == (
+ tgt.shape[0],
+ tgt.shape[1] - 1,
+ self.size,
+ ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ x = residual + self.dropout(
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(
+ self.src_attn(x, memory, memory, memory_mask)[0])
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
diff --git a/cosyvoice/transformer/embedding.py b/cosyvoice/transformer/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..46130a503f72f103e09d3392077ed352368ce54f
--- /dev/null
+++ b/cosyvoice/transformer/embedding.py
@@ -0,0 +1,293 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Positonal Encoding Module."""
+
+import math
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+
+ :param int d_model: embedding dim
+ :param float dropout_rate: dropout rate
+ :param int max_len: maximum input length
+
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
+ """
+
+ def __init__(self,
+ d_model: int,
+ dropout_rate: float,
+ max_len: int = 5000,
+ reverse: bool = False):
+ """Construct an PositionalEncoding object."""
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ self.pe = torch.zeros(self.max_len, self.d_model)
+ position = torch.arange(0, self.max_len,
+ dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ self.pe[:, 0::2] = torch.sin(position * div_term)
+ self.pe[:, 1::2] = torch.cos(position * div_term)
+ self.pe = self.pe.unsqueeze(0)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
+ offset (int, torch.tensor): position offset
+
+ Returns:
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+ torch.Tensor: for compatibility to RelPositionalEncoding
+ """
+
+ self.pe = self.pe.to(x.device)
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ x = x * self.xscale + pos_emb
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(self,
+ offset: Union[int, torch.Tensor],
+ size: int,
+ apply_dropout: bool = True) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
+
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a
+ non-streaming way, but this function will be called several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ # How to subscript a Union type:
+ # https://github.com/pytorch/pytorch/issues/69434
+ if isinstance(offset, int):
+ assert offset + size <= self.max_len
+ pos_emb = self.pe[:, offset:offset + size]
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
+ assert offset + size <= self.max_len
+ pos_emb = self.pe[:, offset:offset + size]
+ else: # for batched streaming decoding on GPU
+ assert torch.max(offset) + size <= self.max_len
+ index = offset.unsqueeze(1) + \
+ torch.arange(0, size).to(offset.device) # B X T
+ flag = index > 0
+ # remove negative offset
+ index = index * flag
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
+
+ if apply_dropout:
+ pos_emb = self.dropout(pos_emb)
+ return pos_emb
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.pe = self.pe.to(x.device)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class WhisperPositionalEncoding(PositionalEncoding):
+ """ Sinusoids position encoding used in openai-whisper.encoder
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
+ super().__init__(d_model, dropout_rate, max_len)
+ self.xscale = 1.0
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment *
+ torch.arange(d_model // 2))
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
+ inv_timescales[np.newaxis, :]
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+ delattr(self, "pe")
+ self.register_buffer("pe", pe.unsqueeze(0))
+
+
+class LearnablePositionalEncoding(PositionalEncoding):
+ """ Learnable position encoding used in openai-whisper.decoder
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
+ super().__init__(d_model, dropout_rate, max_len)
+ # NOTE(xcsong): overwrite self.pe & self.xscale
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
+ self.xscale = 1.0
+
+
+class NoPositionalEncoding(torch.nn.Module):
+ """ No position encoding
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float):
+ super().__init__()
+ self.d_model = d_model
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """ Just return zero vector for interface compatibility
+ """
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
+ return self.dropout(x), pos_emb
+
+ def position_encoding(self, offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ return torch.zeros(1, size, self.d_model)
+
+
+class EspnetRelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module (new implementation).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Construct an PositionalEncoding object."""
+ super(EspnetRelPositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` means to the position of query vector and `j` means the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reverse the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in https://arxiv.org/abs/1901.02860
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, 2*time-1, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(self,
+ offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
+
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a
+ non-streaming way, but this function will be called several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
+ ]
+ return pos_emb
diff --git a/cosyvoice/transformer/encoder.py b/cosyvoice/transformer/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e8bd230b28c7f55c809a75b6f722a92d7ea089f
--- /dev/null
+++ b/cosyvoice/transformer/encoder.py
@@ -0,0 +1,472 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder definition."""
+from typing import Tuple
+
+import torch
+import torch.utils.checkpoint as ckpt
+
+from cosyvoice.transformer.convolution import ConvolutionModule
+from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
+from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
+from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from cosyvoice.utils.class_utils import (
+ COSYVOICE_EMB_CLASSES,
+ COSYVOICE_SUBSAMPLE_CLASSES,
+ COSYVOICE_ATTENTION_CLASSES,
+ COSYVOICE_ACTIVATION_CLASSES,
+)
+from cosyvoice.utils.mask import make_pad_mask
+from cosyvoice.utils.mask import add_optional_chunk_mask
+
+
+class BaseEncoder(torch.nn.Module):
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "abs_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ gradient_checkpointing: bool = False,
+ ):
+ """
+ Args:
+ input_size (int): input dim
+ output_size (int): dimension of attention
+ attention_heads (int): the number of heads of multi head attention
+ linear_units (int): the hidden units number of position-wise feed
+ forward
+ num_blocks (int): the number of decoder blocks
+ dropout_rate (float): dropout rate
+ attention_dropout_rate (float): dropout rate in attention
+ positional_dropout_rate (float): dropout rate after adding
+ positional encoding
+ input_layer (str): input layer type.
+ optional [linear, conv2d, conv2d6, conv2d8]
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
+ optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+ normalize_before (bool):
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ static_chunk_size (int): chunk size for static chunk training and
+ decoding
+ use_dynamic_chunk (bool): whether to use dynamic chunk size for
+ training or not. You can only use a fixed chunk (chunk_size > 0)
+ or a dynamic chunk size (use_dynamic_chunk = True)
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
+ dynamic chunk training
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ gradient_checkpointing: rerunning a forward-pass segment for each
+ checkpointed segment during backward.
+ """
+ super().__init__()
+ self._output_size = output_size
+
+ self.global_cmvn = global_cmvn
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+ input_size,
+ output_size,
+ dropout_rate,
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+ positional_dropout_rate),
+ )
+
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+ self.gradient_checkpointing = gradient_checkpointing
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ decoding_chunk_size: int = 0,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Embed positions in tensor.
+
+ Args:
+ xs: padded input tensor (B, T, D)
+ xs_lens: input length (B)
+ decoding_chunk_size: decoding chunk size for dynamic chunk
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ Returns:
+ encoder output tensor xs, and subsampled masks
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+ masks: torch.Tensor batch padding mask after subsample
+ (B, 1, T' ~= T/subsample_rate)
+ NOTE(xcsong):
+ We pass the `__call__` method of the modules instead of `forward` to the
+ checkpointing API because `__call__` attaches all the hooks of the module.
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+ """
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ xs, pos_emb, masks = self.embed(xs, masks)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(xs, masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size,
+ num_decoding_left_chunks)
+ if self.gradient_checkpointing and self.training:
+ xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
+ mask_pad)
+ else:
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+ # Here we assume the mask is not changed in encoder layers, so just
+ # return the masks before encoder layers, and the masks will be used
+ # for cross attention with decoder later
+ return xs, masks
+
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ for layer in self.encoders:
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ return xs
+
+ @torch.jit.ignore(drop=True)
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
+ chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ for layer in self.encoders:
+ xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
+ chunk_masks, pos_emb,
+ mask_pad)
+ return xs
+
+ def forward_chunk(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """ Forward just one chunk
+
+ Args:
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
+ where `time == (chunk_size - 1) * subsample_rate + \
+ subsample.right_context + 1`
+ offset (int): current offset in encoder output time stamp
+ required_cache_size (int): cache size required for next chunk
+ computation
+ >=0: actual cache size
+ <0: means all history cache is required
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+ transformer/conformer attention, with shape
+ (elayers, head, cache_t1, d_k * 2), where
+ `head * d_k == hidden-dim` and
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+ (elayers, b=1, hidden-dim, cache_t2), where
+ `cache_t2 == cnn.lorder - 1`
+
+ Returns:
+ torch.Tensor: output of current input xs,
+ with shape (b=1, chunk_size, hidden-dim).
+ torch.Tensor: new attention cache required for next chunk, with
+ dynamic shape (elayers, head, ?, d_k * 2)
+ depending on required_cache_size.
+ torch.Tensor: new conformer cnn cache required for next chunk, with
+ same shape as the original cnn_cache.
+
+ """
+ assert xs.size(0) == 1
+ # tmp_masks is just for interface compatibility
+ tmp_masks = torch.ones(1,
+ xs.size(1),
+ device=xs.device,
+ dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
+ chunk_size = xs.size(1)
+ attention_key_size = cache_t1 + chunk_size
+ pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
+ size=attention_key_size)
+ if required_cache_size < 0:
+ next_cache_start = 0
+ elif required_cache_size == 0:
+ next_cache_start = attention_key_size
+ else:
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
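+ # e.g. with required_cache_size == 64 and attention_key_size == 80,
+ # next_cache_start == 16, i.e. only the newest 64 key/value frames are
+ # carried over to the next chunk.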
+ r_att_cache = []
+ r_cnn_cache = []
+ for i, layer in enumerate(self.encoders):
+ # NOTE(xcsong): Before layer.forward
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
+ xs, _, new_att_cache, new_cnn_cache = layer(
+ xs,
+ att_mask,
+ pos_emb,
+ att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
+ # NOTE(xcsong): After layer.forward
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
+ # ? may be larger than cache_t1, it depends on required_cache_size
+ r_att_cache = torch.cat(r_att_cache, dim=0)
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
+
+ return (xs, r_att_cache, r_cnn_cache)
+
+ def forward_chunk_by_chunk(
+ self,
+ xs: torch.Tensor,
+ decoding_chunk_size: int,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """ Forward input chunk by chunk with chunk_size like a streaming
+ fashion
+
+ Here we should pay special attention to computation cache in the
+ streaming style forward chunk by chunk. Three things should be taken
+ into account for computation in the current network:
+ 1. transformer/conformer encoder layers output cache
+ 2. convolution in conformer
+ 3. convolution in subsampling
+
+ However, we don't implement subsampling cache for:
+ 1. We can control subsampling module to output the right result by
+ overlapping input instead of cache left context, even though it
+ wastes some computation, but subsampling only takes a very
+ small fraction of computation in the whole model.
+ 2. Typically, there are several convolution layers with subsampling
+ in subsampling module, it is tricky and complicated to do cache
+ with different convolution layers with different subsampling
+ rate.
+ 3. Currently, nn.Sequential is used to stack all the convolution
+ layers in subsampling, we need to rewrite it to make it work
+ with cache, which is not preferred.
+ Args:
+ xs (torch.Tensor): (1, max_len, dim)
+ decoding_chunk_size (int): decoding chunk size
+ """
+ assert decoding_chunk_size > 0
+ # The model is trained by static or dynamic chunk
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
+ subsampling = self.embed.subsampling_rate
+ context = self.embed.right_context + 1 # Add current frame
+ stride = subsampling * decoding_chunk_size
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
+ num_frames = xs.size(1)
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ outputs = []
+ offset = 0
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+
+ # Feed forward overlap input step by step
+ for cur in range(0, num_frames - context + 1, stride):
+ end = min(cur + decoding_window, num_frames)
+ chunk_xs = xs[:, cur:end, :]
+ (y, att_cache,
+ cnn_cache) = self.forward_chunk(chunk_xs, offset,
+ required_cache_size, att_cache,
+ cnn_cache)
+ outputs.append(y)
+ offset += y.size(1)
+ ys = torch.cat(outputs, 1)
+ masks = torch.ones((1, 1, ys.size(1)),
+ device=ys.device,
+ dtype=torch.bool)
+ return ys, masks
+
+
+class TransformerEncoder(BaseEncoder):
+ """Transformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "abs_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ key_bias: bool = True,
+ selfattention_layer_type: str = "selfattn",
+ activation_type: str = "relu",
+ gradient_checkpointing: bool = False,
+ ):
+ """ Construct TransformerEncoder
+
+ See Encoder for the meaning of each parameter.
+ """
+ super().__init__(input_size, output_size, attention_heads,
+ linear_units, num_blocks, dropout_rate,
+ positional_dropout_rate, attention_dropout_rate,
+ input_layer, pos_enc_layer_type, normalize_before,
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
+ use_dynamic_left_chunk, gradient_checkpointing)
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
+ self.encoders = torch.nn.ModuleList([
+ TransformerEncoderLayer(
+ output_size,
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
+ output_size,
+ attention_dropout_rate,
+ key_bias),
+ PositionwiseFeedForward(output_size, linear_units,
+ dropout_rate, activation),
+ dropout_rate, normalize_before) for _ in range(num_blocks)
+ ])
+
+
+class ConformerEncoder(BaseEncoder):
+ """Conformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "rel_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ positionwise_conv_kernel_size: int = 1,
+ macaron_style: bool = True,
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ cnn_module_norm: str = "batch_norm",
+ key_bias: bool = True,
+ gradient_checkpointing: bool = False,
+ ):
+ """Construct ConformerEncoder
+
+ Args:
+ input_size to use_dynamic_chunk, see in BaseEncoder
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
+ conv1d layer.
+ macaron_style (bool): Whether to use macaron style for
+ positionwise layer.
+ selfattention_layer_type (str): Encoder attention layer type,
+ the parameter has no effect now, it's just for configuration
+ compatibility.
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ causal (bool): whether to use causal convolution or not.
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
+ """
+ super().__init__(input_size, output_size, attention_heads,
+ linear_units, num_blocks, dropout_rate,
+ positional_dropout_rate, attention_dropout_rate,
+ input_layer, pos_enc_layer_type, normalize_before,
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
+ use_dynamic_left_chunk, gradient_checkpointing)
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
+
+ # self-attention module definition
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ key_bias,
+ )
+ # feed-forward module definition
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ # convolution module definition
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
+ cnn_module_norm, causal)
+
+ self.encoders = torch.nn.ModuleList([
+ ConformerEncoderLayer(
+ output_size,
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+ *encoder_selfattn_layer_args),
+ PositionwiseFeedForward(*positionwise_layer_args),
+ PositionwiseFeedForward(
+ *positionwise_layer_args) if macaron_style else None,
+ ConvolutionModule(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ normalize_before,
+ ) for _ in range(num_blocks)
+ ])
diff --git a/cosyvoice/transformer/encoder_layer.py b/cosyvoice/transformer/encoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfd758bc1cc7780aa4f6a322a264c879b74a6cfe
--- /dev/null
+++ b/cosyvoice/transformer/encoder_layer.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder self-attention layer definition."""
+
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class TransformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: to use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ dropout_rate: float,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): just for interface compatibility
+ to ConformerEncoderLayer
+ mask_pad (torch.Tensor): not used in the transformer layer,
+ just for unified api with conformer.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2), not used here, it's for interface
+ compatibility to ConformerEncoderLayer.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
+
+ """
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ return x, mask, new_att_cache, fake_cnn_cache
+
+
+class ConformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvolutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: Optional[nn.Module] = None,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ dropout_rate: float = 0.1,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
+ self.norm_final = nn.LayerNorm(
+ size, eps=1e-5) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): positional encoding, must not be None
+ for ConformerEncoderLayer.
+ mask_pad (torch.Tensor): batch padding mask used for the conv module,
+ (#batch, 1, time), (0, 0, 0) means fake mask.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+ """
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
+ att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ # Fake new cnn cache here, and then change it in conv_module
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+ x = residual + self.dropout(x)
+
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ return x, mask, new_att_cache, new_cnn_cache
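
As a quick interface check for the layers defined above, the sketch below drives a single TransformerEncoderLayer end to end. The self-attention here is a stub that only honours the (query, key, value, mask, pos_emb, cache) -> (output, new_cache) contract the layer expects; it is not the project's MultiHeadedAttention, and the imports assume the cosyvoice package introduced by this diff is installed.

```python
import torch
from torch import nn

from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward


class StubSelfAttention(nn.Module):
    """Stand-in that mimics the (q, k, v, mask, pos_emb, cache) -> (out, cache) contract."""

    def __init__(self, size: int):
        super().__init__()
        self.proj = nn.Linear(size, size)

    def forward(self, q, k, v, mask, pos_emb=torch.empty(0), cache=torch.zeros((0, 0, 0, 0))):
        # A real attention module would also return an updated KV cache here.
        return self.proj(q), torch.zeros((0, 0, 0, 0))


size = 8
layer = TransformerEncoderLayer(
    size=size,
    self_attn=StubSelfAttention(size),
    feed_forward=PositionwiseFeedForward(size, 16, 0.1),
    dropout_rate=0.1,
)
x = torch.randn(2, 5, size)                   # (batch, time, size)
mask = torch.ones(2, 5, 5, dtype=torch.bool)  # full (non-streaming) attention
out, out_mask, att_cache, cnn_cache = layer(x, mask, pos_emb=torch.empty(0))
print(out.shape)        # torch.Size([2, 5, 8])
print(cnn_cache.shape)  # torch.Size([0, 0, 0]) -- fake cache, no conv in this layer
```
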
diff --git a/cosyvoice/transformer/label_smoothing_loss.py b/cosyvoice/transformer/label_smoothing_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..feacabf09609ee6eb047c89ce18d372256c72c71
--- /dev/null
+++ b/cosyvoice/transformer/label_smoothing_loss.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Label smoothing module."""
+
+import torch
+from torch import nn
+
+
+class LabelSmoothingLoss(nn.Module):
+ """Label-smoothing loss.
+
+ In a standard CE loss, the label's data distribution is:
+ [0,1,2] ->
+ [
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0],
+ ]
+
+ In the label-smoothed version of the CE loss, some probability
+ mass is taken from the true label (1.0) and divided
+ among the other labels.
+
+ e.g.
+ smoothing=0.1
+ [0,1,2] ->
+ [
+ [0.9, 0.05, 0.05],
+ [0.05, 0.9, 0.05],
+ [0.05, 0.05, 0.9],
+ ]
+
+ Args:
+ size (int): the number of classes
+ padding_idx (int): padding class id, which is ignored when computing the loss
+ smoothing (float): smoothing rate (0.0 means the conventional CE)
+ normalize_length (bool):
+ normalize loss by sequence length if True
+ normalize loss by batch size if False
+ """
+
+ def __init__(self,
+ size: int,
+ padding_idx: int,
+ smoothing: float,
+ normalize_length: bool = False):
+ """Construct an LabelSmoothingLoss object."""
+ super(LabelSmoothingLoss, self).__init__()
+ self.criterion = nn.KLDivLoss(reduction="none")
+ self.padding_idx = padding_idx
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.size = size
+ self.normalize_length = normalize_length
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Compute loss between x and target.
+
+ The model output and label tensors are flattened to
+ (batch*seqlen, class) shape, and a mask is applied to the
+ padding positions so that they do not contribute to the loss.
+
+ Args:
+ x (torch.Tensor): prediction (batch, seqlen, class)
+ target (torch.Tensor):
+ target signal masked with self.padding_id (batch, seqlen)
+ Returns:
+ loss (torch.Tensor) : The KL loss, scalar float value
+ """
+ assert x.size(2) == self.size
+ batch_size = x.size(0)
+ x = x.view(-1, self.size)
+ target = target.view(-1)
+ # use zeros_like instead of torch.no_grad() for true_dist,
+ # since no_grad() can not be exported by JIT
+ true_dist = torch.zeros_like(x)
+ true_dist.fill_(self.smoothing / (self.size - 1))
+ ignore = target == self.padding_idx # (B,)
+ total = len(target) - ignore.sum().item()
+ target = target.masked_fill(ignore, 0) # avoid -1 index
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
+ denom = total if self.normalize_length else batch_size
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
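
A minimal numeric sketch of the loss above, assuming the cosyvoice package from this diff is importable: with size=3 and smoothing=0.1 each target row becomes [0.9, 0.05, 0.05] (the off-target mass is smoothing / (size - 1)), and positions labelled with padding_idx contribute nothing.

```python
import torch

from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=3, padding_idx=-1, smoothing=0.1)

# (batch=1, seqlen=4, class=3); the last position is padding and is ignored.
logits = torch.tensor([[[2.0, 0.1, 0.1],
                        [0.1, 2.0, 0.1],
                        [0.1, 0.1, 2.0],
                        [0.0, 0.0, 0.0]]])
target = torch.tensor([[0, 1, 2, -1]])

loss = criterion(logits, target)
print(loss)  # scalar KL loss, normalized by batch size (normalize_length=False)
```
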
diff --git a/cosyvoice/transformer/positionwise_feed_forward.py b/cosyvoice/transformer/positionwise_feed_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a2cf6e7315e3a5ed2794423daff0a59cc5b208
--- /dev/null
+++ b/cosyvoice/transformer/positionwise_feed_forward.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+
+class PositionwiseFeedForward(torch.nn.Module):
+ """Positionwise feed forward layer.
+
+ The feed-forward layer is applied at each position of the sequence.
+ The output dimension is the same as the input dimension.
+
+ Args:
+ idim (int): Input dimension.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ ):
+ """Construct a PositionwiseFeedForward object."""
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
+ self.activation = activation
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+ """
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
+
+
+class MoEFFNLayer(torch.nn.Module):
+ """
+ Mixture-of-experts (MoE) positionwise feed-forward layer.
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
+ The output dimension is the same as the input dimension.
+
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
+ Args:
+ n_expert: number of expert.
+ n_expert_per_token: The actual number of experts used for each frame
+ idim (int): Input dimension.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ n_expert: int,
+ n_expert_per_token: int,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ ):
+ super(MoEFFNLayer, self).__init__()
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
+ self.experts = torch.nn.ModuleList(
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
+ activation) for _ in range(n_expert))
+ self.n_expert_per_token = n_expert_per_token
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Foward function.
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+
+ """
+ B, L, D = xs.size(
+ ) # batch size, sequence length, embedding dimension (idim)
+ xs = xs.view(-1, D) # (B*L, D)
+ router = self.gate(xs) # (B*L, n_expert)
+ logits, indices = torch.topk(
+ router, self.n_expert_per_token
+ ) # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
+ weights = torch.nn.functional.softmax(
+ logits, dim=1,
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
+ output = torch.zeros_like(xs) # (B*L, D)
+ for i, expert in enumerate(self.experts):
+ mask = indices == i
+ batch_idx, ith_expert = torch.where(mask)
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
+ xs[batch_idx])
+ return output.view(B, L, D)
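
The sketch below, again assuming the package is importable, pushes a small batch through both layers: the MoE gate scores all experts, topk keeps n_expert_per_token of them per frame, and each selected expert's output is blended with its softmaxed gate weight. Dropout is set to 0.0 only to keep the example free of extra randomness.

```python
import torch

from cosyvoice.transformer.positionwise_feed_forward import (
    MoEFFNLayer,
    PositionwiseFeedForward,
)

torch.manual_seed(0)
ffn = PositionwiseFeedForward(idim=8, hidden_units=16, dropout_rate=0.0)
moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=8,
                  hidden_units=16, dropout_rate=0.0)

x = torch.randn(2, 5, 8)  # (B, L, D)
print(ffn(x).shape)       # torch.Size([2, 5, 8])
print(moe(x).shape)       # torch.Size([2, 5, 8]) -- 2 of 4 experts used per frame
```
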
diff --git a/cosyvoice/transformer/subsampling.py b/cosyvoice/transformer/subsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..e17c2e324e3afb24e1b619effe29cef07c9c5b3a
--- /dev/null
+++ b/cosyvoice/transformer/subsampling.py
@@ -0,0 +1,383 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Subsampling layer definition."""
+
+from typing import Tuple, Union
+
+import torch
+
+
+class BaseSubsampling(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def position_encoding(self, offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ return self.pos_enc.position_encoding(offset, size)
+
+
+class EmbedinigNoSubsampling(BaseSubsampling):
+ """Embedding input without subsampling
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ super().__init__()
+ self.embed = torch.nn.Embedding(idim, odim)
+ self.pos_enc = pos_enc_class
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time.
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time.
+
+ """
+ x = self.embed(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
+
+
+class LinearNoSubsampling(BaseSubsampling):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.LayerNorm(odim, eps=1e-5),
+ torch.nn.Dropout(dropout_rate),
+ )
+ self.pos_enc = pos_enc_class
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time.
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time.
+
+ """
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
+
+
+class Conv1dSubsampling2(BaseSubsampling):
+ """Convolutional 1D subsampling (to 1/2 length).
+ It is designed for Whisper, ref:
+ https://github.com/openai/whisper/blob/main/whisper/model.py
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv1dSubsampling2 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
+ torch.nn.GELU(),
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
+ torch.nn.GELU(),
+ )
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 2
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
+ self.right_context = 4
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 2.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 2.
+ torch.Tensor: positional encoding
+
+ """
+ time = x.size(1)
+ x = x.transpose(1, 2) # (b, f, t)
+ x = self.conv(x)
+ x = x.transpose(1, 2) # (b, t, f)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
+
+
+class Conv2dSubsampling4(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling4 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 4
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
+ self.right_context = 6
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+ torch.Tensor: positional encoding
+
+ """
+ x = x.unsqueeze(1) # (b, c=1, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
+
+
+class Conv2dSubsampling6(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/6 length).
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling6 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 5, 3),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
+ odim)
+ self.pos_enc = pos_enc_class
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
+ self.subsampling_rate = 6
+ self.right_context = 10
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 6.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 6.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
+
+
+class Conv2dSubsampling8(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/8 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling8 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 8
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
+ self.right_context = 14
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 8.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 8.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
+
+
+class LegacyLinearNoSubsampling(BaseSubsampling):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.LayerNorm(odim, eps=1e-5),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ )
+ self.pos_enc = pos_enc_class
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time.
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time.
+
+ """
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
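
A shape check for Conv2dSubsampling4, assuming the package is importable. The positional-encoding object here is a trivial stand-in that only satisfies the (x, offset) -> (x, pos_emb) call the subsampling layers make; in the real encoder one of the classes from cosyvoice/transformer/embedding.py is passed instead.

```python
import torch
from torch import nn

from cosyvoice.transformer.subsampling import Conv2dSubsampling4


class IdentityPosEnc(nn.Module):
    """Stand-in positional encoding: returns the input unchanged and a zero pos_emb."""

    def forward(self, x: torch.Tensor, offset=0):
        return x, torch.zeros(1, x.size(1), x.size(2))


idim, odim = 80, 64
sub = Conv2dSubsampling4(idim, odim, dropout_rate=0.1, pos_enc_class=IdentityPosEnc())

x = torch.randn(2, 100, idim)  # (batch, time, feat)
x_mask = torch.ones(2, 1, 100, dtype=torch.bool)
y, pos_emb, y_mask = sub(x, x_mask)
print(y.shape)       # torch.Size([2, 24, 64]) -- roughly time // 4
print(y_mask.shape)  # torch.Size([2, 1, 24])
```
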
diff --git a/cosyvoice/utils/__init__.py b/cosyvoice/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/utils/class_utils.py b/cosyvoice/utils/class_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8cc4714586161487c7019153b960bfb2a029e36
--- /dev/null
+++ b/cosyvoice/utils/class_utils.py
@@ -0,0 +1,70 @@
+# Copyright [2023-11-28]
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from cosyvoice.transformer.activation import Swish
+from cosyvoice.transformer.subsampling import (
+ LinearNoSubsampling,
+ EmbedinigNoSubsampling,
+ Conv1dSubsampling2,
+ Conv2dSubsampling4,
+ Conv2dSubsampling6,
+ Conv2dSubsampling8,
+)
+from cosyvoice.transformer.embedding import (PositionalEncoding,
+ RelPositionalEncoding,
+ WhisperPositionalEncoding,
+ LearnablePositionalEncoding,
+ NoPositionalEncoding)
+from cosyvoice.transformer.attention import (MultiHeadedAttention,
+ RelPositionMultiHeadedAttention)
+from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
+from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+
+
+COSYVOICE_ACTIVATION_CLASSES = {
+ "hardtanh": torch.nn.Hardtanh,
+ "tanh": torch.nn.Tanh,
+ "relu": torch.nn.ReLU,
+ "selu": torch.nn.SELU,
+ "swish": getattr(torch.nn, "SiLU", Swish),
+ "gelu": torch.nn.GELU,
+}
+
+COSYVOICE_SUBSAMPLE_CLASSES = {
+ "linear": LinearNoSubsampling,
+ "linear_legacy": LegacyLinearNoSubsampling,
+ "embed": EmbedinigNoSubsampling,
+ "conv1d2": Conv1dSubsampling2,
+ "conv2d": Conv2dSubsampling4,
+ "conv2d6": Conv2dSubsampling6,
+ "conv2d8": Conv2dSubsampling8,
+ 'paraformer_dummy': torch.nn.Identity
+}
+
+COSYVOICE_EMB_CLASSES = {
+ "embed": PositionalEncoding,
+ "abs_pos": PositionalEncoding,
+ "rel_pos": RelPositionalEncoding,
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
+ "no_pos": NoPositionalEncoding,
+ "abs_pos_whisper": WhisperPositionalEncoding,
+ "embed_learnable_pe": LearnablePositionalEncoding,
+}
+
+COSYVOICE_ATTENTION_CLASSES = {
+ "selfattn": MultiHeadedAttention,
+ "rel_selfattn": RelPositionMultiHeadedAttention,
+}
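
These tables let YAML configs pick implementations by name. A small hedged sketch, assuming the rest of the package added in this diff is importable:

```python
import torch

from cosyvoice.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
)

# Resolve an activation by its config name; "swish" maps to torch.nn.SiLU on
# modern PyTorch and falls back to the local Swish implementation otherwise.
activation = COSYVOICE_ACTIVATION_CLASSES["swish"]()
print(activation(torch.tensor([-1.0, 0.0, 1.0])))

# The subsampling table maps an `input_layer` config string to a class.
print(sorted(COSYVOICE_SUBSAMPLE_CLASSES))
```
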
diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..73b438ebb67d2bc552edf4f0d773f2991dcb2493
--- /dev/null
+++ b/cosyvoice/utils/common.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Unility functions for Transformer."""
+
+from typing import List
+
+import torch
+
+IGNORE_ID = -1
+
+
+def pad_list(xs: List[torch.Tensor], pad_value: int):
+ """Perform padding for the list of tensors.
+
+ Args:
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+ pad_value (float): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tmax, `*`).
+
+ Examples:
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+ >>> x
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+ >>> pad_list(x, 0)
+ tensor([[1., 1., 1., 1.],
+ [1., 1., 0., 0.],
+ [1., 0., 0., 0.]])
+
+ """
+ max_len = max([len(item) for item in xs])
+ batchs = len(xs)
+ ndim = xs[0].ndim
+ if ndim == 1:
+ pad_res = torch.zeros(batchs,
+ max_len,
+ dtype=xs[0].dtype,
+ device=xs[0].device)
+ elif ndim == 2:
+ pad_res = torch.zeros(batchs,
+ max_len,
+ xs[0].shape[1],
+ dtype=xs[0].dtype,
+ device=xs[0].device)
+ elif ndim == 3:
+ pad_res = torch.zeros(batchs,
+ max_len,
+ xs[0].shape[1],
+ xs[0].shape[2],
+ dtype=xs[0].dtype,
+ device=xs[0].device)
+ else:
+ raise ValueError(f"Unsupported ndim: {ndim}")
+ pad_res.fill_(pad_value)
+ for i in range(batchs):
+ pad_res[i, :len(xs[i])] = xs[i]
+ return pad_res
+
+
+def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
+ ignore_label: int) -> torch.Tensor:
+ """Calculate accuracy.
+
+ Args:
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
+ ignore_label (int): Ignore label id.
+
+ Returns:
+ torch.Tensor: Accuracy value (0.0 - 1.0).
+
+ """
+ pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
+ pad_outputs.size(1)).argmax(2)
+ mask = pad_targets != ignore_label
+ numerator = torch.sum(
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
+ denominator = torch.sum(mask)
+ return (numerator / denominator).detach()
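
A short worked example of the two helpers, assuming the module is importable: padded targets carry IGNORE_ID, and th_accuracy masks those positions out before comparing argmax predictions with the labels.

```python
import torch

from cosyvoice.utils.common import IGNORE_ID, pad_list, th_accuracy

targets = pad_list([torch.tensor([1, 2, 0]), torch.tensor([2, 1])], IGNORE_ID)
print(targets)
# tensor([[ 1,  2,  0],
#         [ 2,  1, -1]])

# Fake "logits" for 3 classes, flattened to (B * Lmax, D) as th_accuracy expects.
logits = torch.tensor([[0.1, 0.8, 0.1],   # predicts 1 (correct)
                       [0.1, 0.1, 0.8],   # predicts 2 (correct)
                       [0.8, 0.1, 0.1],   # predicts 0 (correct)
                       [0.8, 0.1, 0.1],   # predicts 0 (wrong, target is 2)
                       [0.1, 0.8, 0.1],   # predicts 1 (correct)
                       [0.1, 0.1, 0.8]])  # padding position, ignored
print(th_accuracy(logits, targets, ignore_label=IGNORE_ID))  # tensor(0.8000)
```
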
diff --git a/cosyvoice/utils/executor.py b/cosyvoice/utils/executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c12e52df9f41bf8d8c05a65a069f898aec3ef6ca
--- /dev/null
+++ b/cosyvoice/utils/executor.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import nullcontext
+import os
+
+import torch
+import torch.distributed as dist
+
+from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
+
+
+class Executor:
+
+ def __init__(self):
+ self.step = 0
+ self.epoch = 0
+ self.rank = int(os.environ.get('RANK', 0))
+ self.device = torch.device('cuda:{}'.format(self.rank))
+
+ def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
+ ''' Train one epoch
+ '''
+
+ lr = optimizer.param_groups[0]['lr']
+ logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+ logging.info('using gradient accumulation, effective batch size is {} times'
+ ' larger than before'.format(info_dict['accum_grad']))
+ # A context manager to be used in conjunction with an instance of
+ # torch.nn.parallel.DistributedDataParallel to be able to train
+ # with uneven inputs across participating processes.
+ model.train()
+ model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+ with model_context():
+ for batch_idx, batch_dict in enumerate(train_data_loader):
+ info_dict["tag"] = "TRAIN"
+ info_dict["step"] = self.step
+ info_dict["epoch"] = self.epoch
+ info_dict["batch_idx"] = batch_idx
+ if cosyvoice_join(group_join, info_dict):
+ break
+
+ # Disable gradient synchronizations across DDP processes.
+ # Within this context, gradients will be accumulated on module
+ # variables, which will later be synchronized.
+ if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+ context = model.no_sync
+ # Used for single gpu training and DDP gradient synchronization
+ # processes.
+ else:
+ context = nullcontext
+
+ with context():
+ info_dict = batch_forward(model, batch_dict, info_dict)
+ info_dict = batch_backward(model, info_dict)
+
+ info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+ log_per_step(writer, info_dict)
+ # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+ if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0:
+ dist.barrier()
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+ model.train()
+ if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+ self.step += 1
+ dist.barrier()
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+
+ @torch.inference_mode()
+ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
+ ''' Cross validation on the cv dataset
+ '''
+ logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
+ model.eval()
+ total_num_utts, total_loss_dict = 0, {} # avoid division by 0
+ for batch_idx, batch_dict in enumerate(cv_data_loader):
+ info_dict["tag"] = "CV"
+ info_dict["step"] = self.step
+ info_dict["epoch"] = self.epoch
+ info_dict["batch_idx"] = batch_idx
+
+ num_utts = len(batch_dict["utts"])
+ total_num_utts += num_utts
+
+ info_dict = batch_forward(model, batch_dict, info_dict)
+
+ for k, v in info_dict['loss_dict'].items():
+ if k not in total_loss_dict:
+ total_loss_dict[k] = []
+ total_loss_dict[k].append(v.item() * num_utts)
+ log_per_step(None, info_dict)
+ for k, v in total_loss_dict.items():
+ total_loss_dict[k] = sum(v) / total_num_utts
+ info_dict['loss_dict'] = total_loss_dict
+ log_per_save(writer, info_dict)
+ model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
+ save_model(model, model_name, info_dict)
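
The loop above combines DDP's no_sync() with an accum_grad counter. The standalone sketch below illustrates the same gradient-accumulation pattern on a single process, without DDP or the train_utils helpers; it is an illustration of the idea, not the Executor itself.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
accum_grad = 4  # plays the role of info_dict['accum_grad']

data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(16)]

for batch_idx, (x, y) in enumerate(data):
    loss = loss_fn(model(x), y) / accum_grad  # scale so accumulated grads average out
    loss.backward()
    if (batch_idx + 1) % accum_grad == 0:
        # In the Executor this is where update_parameter_and_lr() clips gradients,
        # steps the optimizer and scheduler, and zeroes the gradients.
        optimizer.step()
        optimizer.zero_grad()
```
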
diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..92c448b9cc0860087747262cc75eb6cf0da722f2
--- /dev/null
+++ b/cosyvoice/utils/file_utils.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import torchaudio
+
+
+def read_lists(list_file):
+ lists = []
+ with open(list_file, 'r', encoding='utf8') as fin:
+ for line in fin:
+ lists.append(line.strip())
+ return lists
+
+def read_json_lists(list_file):
+ lists = read_lists(list_file)
+ results = {}
+ for fn in lists:
+ with open(fn, 'r', encoding='utf8') as fin:
+ results.update(json.load(fin))
+ return results
+
+def load_wav(wav, target_sr):
+ speech, sample_rate = torchaudio.load(wav)
+ speech = speech.mean(dim=0, keepdim=True)
+ if sample_rate != target_sr:
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+ return speech
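
A brief usage sketch; the file name is a placeholder. load_wav downmixes to mono and resamples down to target_sr, and the assert deliberately refuses to upsample.

```python
from cosyvoice.utils.file_utils import load_wav

# 'prompt.wav' is a placeholder path; any wav at >= 16 kHz works here.
speech = load_wav('prompt.wav', target_sr=16000)
print(speech.shape)  # (1, num_samples): mono waveform at 16 kHz
```
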
diff --git a/cosyvoice/utils/frontend_utils.py b/cosyvoice/utils/frontend_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee829ff4c9fb9aff36b594f0e33825b59918617
--- /dev/null
+++ b/cosyvoice/utils/frontend_utils.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
+
+# whether the text contains Chinese characters
+def contains_chinese(text):
+ return bool(chinese_char_pattern.search(text))
+
+
+# replace special superscript symbols
+def replace_corner_mark(text):
+ text = text.replace('²', '平方')
+ text = text.replace('³', '立方')
+ return text
+
+
+# remove meaningless symbols
+def remove_bracket(text):
+ text = text.replace('(', '').replace(')', '')
+ text = text.replace('【', '').replace('】', '')
+ text = text.replace('`', '').replace('`', '')
+ text = text.replace("——", " ")
+ return text
+
+
+# spell out Arabic numerals
+def spell_out_number(text: str, inflect_parser):
+ new_text = []
+ st = None
+ for i, c in enumerate(text):
+ if not c.isdigit():
+ if st is not None:
+ num_str = inflect_parser.number_to_words(text[st: i])
+ new_text.append(num_str)
+ st = None
+ new_text.append(c)
+ else:
+ if st is None:
+ st = i
+ if st is not None and st < len(text):
+ num_str = inflect_parser.number_to_words(text[st:])
+ new_text.append(num_str)
+ return ''.join(new_text)
+
+
+# split paragraph logic:
+# 1. per sentence: max length token_max_n, min length token_min_n; merge the last sentence if its length is less than merge_len
+# 2. calculate sentence length according to lang
+# 3. split sentences according to punctuation
+def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
+ def calc_utt_length(_text: str):
+ if lang == "zh":
+ return len(_text)
+ else:
+ return len(tokenize(_text))
+
+ def should_merge(_text: str):
+ if lang == "zh":
+ return len(_text) < merge_len
+ else:
+ return len(tokenize(_text)) < merge_len
+
+ if lang == "zh":
+ pounc = ['。', '?', '!', ';', ':', '.', '?', '!', ';']
+ else:
+ pounc = ['.', '?', '!', ';', ':']
+ if comma_split:
+ pounc.extend([',', ','])
+ st = 0
+ utts = []
+ for i, c in enumerate(text):
+ if c in pounc:
+ if len(text[st: i]) > 0:
+ utts.append(text[st: i] + c)
+ if i + 1 < len(text) and text[i + 1] in ['"', '”']:
+ tmp = utts.pop(-1)
+ utts.append(tmp + text[i + 1])
+ st = i + 2
+ else:
+ st = i + 1
+ final_utts = []
+ cur_utt = ""
+ for utt in utts:
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
+ final_utts.append(cur_utt)
+ cur_utt = ""
+ cur_utt = cur_utt + utt
+ if len(cur_utt) > 0:
+ if should_merge(cur_utt) and len(final_utts) != 0:
+ final_utts[-1] = final_utts[-1] + cur_utt
+ else:
+ final_utts.append(cur_utt)
+
+ return final_utts
+
+
+# remove blanks between Chinese characters
+def replace_blank(text: str):
+ out_str = []
+ for i, c in enumerate(text):
+ if c == " ":
+ # guard the boundaries so a leading/trailing space cannot index out of range
+ if (0 < i < len(text) - 1 and
+ (text[i + 1].isascii() and text[i + 1] != " ") and
+ (text[i - 1].isascii() and text[i - 1] != " ")):
+ out_str.append(c)
+ else:
+ out_str.append(c)
+ return "".join(out_str)
diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b460bbd5adb4bd61d643ace71400a14fe314236
--- /dev/null
+++ b/cosyvoice/utils/mask.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+'''
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+ This mask is used only in decoder which works in an auto-regressive mode.
+ This means the current step could only do attention with its left steps.
+
+ In the encoder, full attention is used when streaming is not necessary and
+ the sequence is not long. In this case, no attention mask is needed.
+
+ When streaming is needed, chunk-based attention is used in the encoder. See
+ subsequent_chunk_mask for the chunk-based attention mask.
+
+ Args:
+ size (int): size of mask
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
+ return torch.tril(ret)
+'''
+
+
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+ This mask is used only in decoder which works in an auto-regressive mode.
+ This means the current step could only do attention with its left steps.
+
+ In the encoder, full attention is used when streaming is not necessary and
+ the sequence is not long. In this case, no attention mask is needed.
+
+ When streaming is needed, chunk-based attention is used in the encoder. See
+ subsequent_chunk_mask for the chunk-based attention mask.
+
+ Args:
+ size (int): size of mask
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ arange = torch.arange(size, device=device)
+ mask = arange.expand(size, size)
+ arange = arange.unsqueeze(-1)
+ mask = mask <= arange
+ return mask
+
+
+def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size,
+ this is for streaming encoder
+
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_chunk_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+ for i in range(size):
+ if num_left_chunks < 0:
+ start = 0
+ else:
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+ ending = min((i // chunk_size + 1) * chunk_size, size)
+ ret[i, start:ending] = True
+ return ret
+
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+ masks: torch.Tensor,
+ use_dynamic_chunk: bool,
+ use_dynamic_left_chunk: bool,
+ decoding_chunk_size: int,
+ static_chunk_size: int,
+ num_decoding_left_chunks: int,
+ enable_full_context: bool = True):
+ """ Apply optional mask for encoder.
+
+ Args:
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
+ masks (torch.Tensor): mask for xs, (B, 1, L)
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+ training.
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding
+ if it's greater than 0, if use_dynamic_chunk is true,
+ this parameter will be ignored
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ enable_full_context (bool):
+ True: chunk size is either [1, 25] or full context(max_len)
+ False: chunk size ~ U[1, 25]
+
+ Returns:
+ torch.Tensor: chunk mask of the input xs.
+ """
+ # Whether to use chunk mask or not
+ if use_dynamic_chunk:
+ max_len = xs.size(1)
+ if decoding_chunk_size < 0:
+ chunk_size = max_len
+ num_left_chunks = -1
+ elif decoding_chunk_size > 0:
+ chunk_size = decoding_chunk_size
+ num_left_chunks = num_decoding_left_chunks
+ else:
+ # chunk size is either [1, 25] or full context(max_len).
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
+ # delay, the maximum frame is 100 / 4 = 25.
+ chunk_size = torch.randint(1, max_len, (1, )).item()
+ num_left_chunks = -1
+ if chunk_size > max_len // 2 and enable_full_context:
+ chunk_size = max_len
+ else:
+ chunk_size = chunk_size % 25 + 1
+ if use_dynamic_left_chunk:
+ max_left_chunks = (max_len - 1) // chunk_size
+ num_left_chunks = torch.randint(0, max_left_chunks,
+ (1, )).item()
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ elif static_chunk_size > 0:
+ num_left_chunks = num_decoding_left_chunks
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ else:
+ chunk_masks = masks
+ return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """Make mask tensor containing indices of padded part.
+
+ See description of make_non_pad_mask.
+
+ Args:
+ lengths (torch.Tensor): Batch of lengths (B,).
+ Returns:
+ torch.Tensor: Mask tensor containing indices of padded part.
+
+ Examples:
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+ """
+ batch_size = lengths.size(0)
+ max_len = max_len if max_len > 0 else lengths.max().item()
+ seq_range = torch.arange(0,
+ max_len,
+ dtype=torch.int64,
+ device=lengths.device)
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+ seq_length_expand = lengths.unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+ return mask
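
A compact sketch tying the helpers together, assuming the module is importable: make_pad_mask marks padded frames, its negation feeds add_optional_chunk_mask, and a static chunk size of 2 reproduces the block pattern from the subsequent_chunk_mask docstring.

```python
import torch

from cosyvoice.utils.mask import (add_optional_chunk_mask, make_pad_mask,
                                  subsequent_chunk_mask)

lengths = torch.tensor([4, 2])
pad_mask = make_pad_mask(lengths)   # True at padded positions
masks = (~pad_mask).unsqueeze(1)    # (B, 1, L), True at valid frames

print(subsequent_chunk_mask(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]], dtype=torch.int32)

xs = torch.randn(2, 4, 8)
chunk_masks = add_optional_chunk_mask(xs, masks,
                                      use_dynamic_chunk=False,
                                      use_dynamic_left_chunk=False,
                                      decoding_chunk_size=0,
                                      static_chunk_size=2,
                                      num_decoding_left_chunks=-1)
print(chunk_masks.shape)  # torch.Size([2, 4, 4]) -- chunked and padding-aware
```
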
diff --git a/cosyvoice/utils/scheduler.py b/cosyvoice/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed1ea0d1e78c0d9f1188bec02244cbff414f894
--- /dev/null
+++ b/cosyvoice/utils/scheduler.py
@@ -0,0 +1,717 @@
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+# 2022 Ximalaya Inc (Yuguang Yang)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+# NeMo(https://github.com/NVIDIA/NeMo)
+
+from typing import Union
+
+import math
+import warnings
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class WarmupLR(_LRScheduler):
+ """The WarmupLR scheduler
+
+ This scheduler is almost the same as the NoamLR scheduler except for the
+ following difference:
+
+ NoamLR:
+ lr = optimizer.lr * model_size ** -0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+ WarmupLR:
+ lr = optimizer.lr * warmup_step ** 0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+
+ Note that the maximum lr equals optimizer.lr in this scheduler.
+
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: Union[int, float] = 25000,
+ last_epoch: int = -1,
+ ):
+ self.warmup_steps = warmup_steps
+
+ # __init__() must be invoked before setting field
+ # because step() is also invoked in __init__()
+ super().__init__(optimizer, last_epoch)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+
+ def get_lr(self):
+ step_num = self.last_epoch + 1
+ if self.warmup_steps == 0:
+ return [lr * step_num**-0.5 for lr in self.base_lrs]
+ else:
+ return [
+ lr * self.warmup_steps**0.5 *
+ min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
+ for lr in self.base_lrs
+ ]
+
+ def set_step(self, step: int):
+ self.last_epoch = step
+
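
The WarmupLR curve can be eyeballed with a throwaway optimizer; the tiny warmup_steps value is purely illustrative.

```python
import torch

from cosyvoice.utils.scheduler import WarmupLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = WarmupLR(optimizer, warmup_steps=5)

for step in range(10):
    optimizer.step()
    scheduler.step()
    # lr rises for ~5 steps up to 1e-3, then decays proportionally to step**-0.5
    print(step, scheduler.get_last_lr()[0])
```
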
+
+class WarmupPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1):
+ assert not (warmup_steps is not None and warmup_ratio is not None),\
+ "Either use particular number of step or ratio"
+ assert warmup_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ if step <= self.warmup_steps and self.warmup_steps > 0:
+ return self._get_warmup_lr(step)
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_warmup_lr(self, step):
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+class SquareRootConstantPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(self,
+ optimizer,
+ *,
+ constant_steps=None,
+ constant_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1):
+ assert not (constant_steps is not None
+ and constant_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert constant_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if constant_steps is not None:
+ self.constant_steps = constant_steps
+ elif constant_ratio is not None:
+ self.constant_steps = int(constant_ratio * max_steps)
+ else:
+ self.constant_steps = 0
+
+ self.constant_lr = 1 / (constant_steps**0.5)
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ if step <= self.constant_steps:
+ return [self.constant_lr for _ in self.base_lrs]
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+class WarmupHoldPolicy(WarmupPolicy):
+ """Variant of WarmupPolicy which maintains high
+ learning rate for a defined number of steps.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ hold_steps: Number of training steps to
+ hold the learning rate after warm up
+ hold_ratio: Ratio of hold steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ hold_steps=None,
+ hold_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (hold_steps is not None and hold_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert hold_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ self.min_lr = min_lr
+ self._last_warmup_lr = 0.0
+
+ # Necessary to duplicate as class attributes are hidden in inner class
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ if hold_steps is not None:
+ self.hold_steps = hold_steps + self.warmup_steps
+ elif hold_ratio is not None:
+ self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
+ else:
+ self.hold_steps = 0
+
+ super().__init__(
+ optimizer,
+ warmup_steps=warmup_steps,
+ warmup_ratio=warmup_ratio,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ )
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler,"
+ " "
+ "please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ # Warmup phase
+ if step <= self.warmup_steps and self.warmup_steps > 0:
+ return self._get_warmup_lr(step)
+
+ # Hold phase
+ if (step >= self.warmup_steps) and (step < self.hold_steps):
+ return self.base_lrs
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+
+class WarmupAnnealHoldPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ min_lr: Minimum lr to hold the learning rate after decay at.
+ constant_steps: Number of steps to keep lr constant at.
+ constant_ratio: Ratio of steps to keep lr constant.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ constant_steps=None,
+ constant_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (warmup_steps is not None
+ and warmup_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert not (constant_steps is not None
+ and constant_ratio is not None), \
+ "Either use constant_steps or constant_ratio"
+ assert warmup_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ if constant_steps is not None:
+ self.constant_steps = constant_steps
+ elif constant_ratio is not None:
+ self.constant_steps = int(constant_ratio * max_steps)
+ else:
+ self.constant_steps = 0
+
+ self.decay_steps = max_steps - (self.constant_steps +
+ self.warmup_steps)
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ # Warmup steps
+ if self.warmup_steps > 0 and step <= self.warmup_steps:
+ return self._get_warmup_lr(step)
+
+ # Constant steps after warmup and decay
+ if self.constant_steps > 0 and (
+ self.warmup_steps + self.decay_steps) < step <= self.max_steps:
+ return self._get_constant_lr(step)
+
+ # Min lr after max steps of updates
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_warmup_lr(self, step):
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ def _get_constant_lr(self, step):
+ return [self.min_lr for _ in self.base_lrs]
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps)**0.5
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _square_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps)**2
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _cosine_annealing(initial_lr, step, max_steps, min_lr):
+ mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+ out_lr = (initial_lr - min_lr) * mult + min_lr
+ return out_lr
+
+
+def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
+ decay_steps, min_lr):
+ assert max_lr > min_lr
+ # Use linear warmup for the initial part.
+ if warmup_steps > 0 and step <= warmup_steps:
+ return max_lr * float(step) / float(warmup_steps)
+
+ # For any steps larger than `decay_steps`, use `min_lr`.
+ if step > warmup_steps + decay_steps:
+ return min_lr
+
+ # If we are done with the warmup period, use the decay style.
+ num_steps_ = step - warmup_steps
+ decay_steps_ = decay_steps
+ decay_ratio = float(num_steps_) / float(decay_steps_)
+ assert decay_ratio >= 0.0
+ assert decay_ratio <= 1.0
+ delta_lr = max_lr - min_lr
+
+ coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
+
+ return min_lr + coeff * delta_lr
+
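+# Worked example for _linear_warmup_with_cosine_annealing (illustrative
+# numbers, not part of the original code): with max_lr=1e-3, min_lr=1e-5,
+# warmup_steps=1000 and decay_steps=9000, it returns 1e-3 at step 1000
+# (end of warmup), ~5.05e-4 at step 5500 (halfway through the cosine
+# decay), and 1e-5 from step 10000 onwards.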
+
+def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
+ if cycle:
+ multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
+ decay_steps *= multiplier
+ else:
+ step = min(step, decay_steps)
+ p = step / decay_steps
+ lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
+ lr += min_lr
+ return lr
+
+
+def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
+ decay_rate, min_lr):
+ # hold_steps = total number of steps
+ # to hold the LR, not the warmup + hold steps.
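+ # In effect: lr = initial_lr * warmup_steps**decay_rate / (step - hold_steps)**decay_rate
+ # (with both powers floored at 1), clamped from below by min_lr.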
+ T_warmup_decay = max(1, warmup_steps**decay_rate)
+ T_hold_decay = max(1, (step - hold_steps)**decay_rate)
+ lr = (initial_lr * T_warmup_decay) / T_hold_decay
+ lr = max(lr, min_lr)
+ return lr
+
+
+class SquareAnnealing(WarmupPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ min_lr=1e-5,
+ last_epoch=-1,
+ **kwargs):
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _square_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ ) for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class SquareRootAnnealing(WarmupPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ min_lr=0,
+ last_epoch=-1,
+ **kwargs):
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _squareroot_annealing(initial_lr=initial_lr,
+ step=step,
+ max_steps=self.max_steps,
+ min_lr=self.min_lr)
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class CosineAnnealing(WarmupAnnealHoldPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ min_lr=0,
+ last_epoch=-1,
+ **kwargs):
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate "
+ f"that was lower than the minimum learning rate.")
+
+ if self.constant_steps is None or self.constant_steps == 0:
+ new_lrs = [
+ _cosine_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ ) for initial_lr in self.base_lrs
+ ]
+ else:
+ new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
+ return new_lrs
+
+ def _get_warmup_lr(self, step):
+ if self.constant_steps is None or self.constant_steps == 0:
+ return super()._get_warmup_lr(step)
+ else:
+ # Use linear warmup for the initial part.
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
+
+ def _get_constant_lr(self, step):
+ # Only called when `constant_steps` > 0.
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
+
+ def _get_linear_warmup_with_cosine_annealing_lr(self, step):
+ # Cosine Schedule for Megatron LM,
+ # slightly different warmup schedule + constant LR at the end.
+ new_lrs = [
+ _linear_warmup_with_cosine_annealing(
+ max_lr=self.base_lrs[0],
+ warmup_steps=self.warmup_steps,
+ step=step,
+ decay_steps=self.decay_steps,
+ min_lr=self.min_lr,
+ ) for _ in self.base_lrs
+ ]
+ return new_lrs
+
+
+class NoamAnnealing(_LRScheduler):
+
+ def __init__(self,
+ optimizer,
+ *,
+ d_model,
+ warmup_steps=None,
+ warmup_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1):
+ self._normalize = d_model**(-0.5)
+ assert not (warmup_steps is not None
+ and warmup_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert warmup_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as the class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = max(1, self.last_epoch)
+
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate "
+ f"that was lower than the minimum learning rate.")
+
+ new_lrs = [
+ self._noam_annealing(initial_lr=initial_lr, step=step)
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+ def _noam_annealing(self, initial_lr, step):
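+ # With warmup_steps > 0 this is the classic Noam schedule:
+ # lr = initial_lr * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5),
+ # i.e. linear growth during warmup followed by inverse-square-root decay.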
+ if self.warmup_steps > 0:
+ mult = self._normalize * min(step**(-0.5),
+ step * (self.warmup_steps**(-1.5)))
+ else:
+ mult = self._normalize * step**(-0.5)
+
+ out_lr = initial_lr * mult
+ if step > self.warmup_steps:
+ out_lr = max(out_lr, self.min_lr)
+ return out_lr
+
+
+class NoamHoldAnnealing(WarmupHoldPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ decay_rate=0.5,
+ min_lr=0.0,
+ last_epoch=-1,
+ **kwargs):
+ """
+ From NeMo:
+ Implementation of the Noam Hold Annealing policy
+ from the Squeezeformer paper.
+
+ Unlike NoamAnnealing, the peak learning rate
+ can be explicitly set for this scheduler.
+ The schedule first performs linear warmup,
+ then holds the peak LR, then decays with some schedule for
+ the remainder of the steps.
+ Therefore the min_lr is still dependent
+ on the hyperparameters selected.
+
+ Its schedule is determined by three factors:
+
+ Warmup Steps: Initial stage, where linear warmup
+ occurs until the peak LR is reached. Unlike NoamAnnealing,
+ the peak LR is explicitly stated here instead of a scaling factor.
+
+ Hold Steps: Intermediate stage, where the peak LR
+ is maintained for some number of steps. In this region,
+ the high peak LR allows the model to converge faster
+ if training is stable. However, the high LR
+ may also cause instability during training.
+ Should usually be a significant fraction of the
+ total training steps (around 30-40%).
+
+ Decay Steps: Final stage, where the LR rapidly decays
+ with some scaling rate (set by decay rate).
+ To attain Noam decay, use 0.5;
+ for the Squeezeformer-recommended decay, use 1.0.
+ The fast decay after prolonged high LR during
+ hold phase allows for rapid convergence.
+
+ References:
+ - [Squeezeformer:
+ An Efficient Transformer for Automatic Speech Recognition]
+ (https://arxiv.org/abs/2206.00888)
+
+ Args:
+ optimizer: Pytorch compatible Optimizer object.
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ hold_steps: Number of training steps to
+ hold the learning rate after warm up
+ hold_ratio: Ratio of hold steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ decay_rate: Float value describing the polynomial decay
+ after the hold period. Default value
+ of 0.5 corresponds to Noam decay.
+ min_lr: Minimum learning rate.
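+
+ Example (an illustrative sketch, not part of the original code; assumes
+ a standard torch optimizer and a step-based training loop)::
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # peak LR
+ scheduler = NoamHoldAnnealing(optimizer, max_steps=100000,
+ warmup_steps=10000, hold_steps=30000,
+ decay_rate=0.5, min_lr=1e-5)
+ for batch in loader:  # hypothetical dataloader
+ loss = model(batch)  # hypothetical model returning a loss
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()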
+ """
+ self.decay_rate = decay_rate
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ if self.warmup_steps is None or self.warmup_steps == 0:
+ raise ValueError(
+ "Noam scheduler cannot be used without warmup steps")
+
+ if self.hold_steps > 0:
+ hold_steps = self.hold_steps - self.warmup_steps
+ else:
+ hold_steps = 0
+
+ new_lrs = [
+ _noam_hold_annealing(
+ initial_lr,
+ step=step,
+ warmup_steps=self.warmup_steps,
+ hold_steps=hold_steps,
+ decay_rate=self.decay_rate,
+ min_lr=self.min_lr,
+ ) for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+ def set_step(self, step: int):
+ self.last_epoch = step
diff --git a/cosyvoice/utils/train_utils.py b/cosyvoice/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..df3a32102c65f09d10097f6669824e336f1d2066
--- /dev/null
+++ b/cosyvoice/utils/train_utils.py
@@ -0,0 +1,286 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+# 2023 Horizon Inc. (authors: Xingchen Song)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+import logging
+import os
+import torch
+import json
+import re
+import datetime
+import yaml
+
+import deepspeed
+import torch.optim as optim
+import torch.distributed as dist
+
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data import DataLoader
+from torch.nn.utils import clip_grad_norm_
+
+from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
+
+from cosyvoice.dataset.dataset import Dataset
+from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
+
+
+def init_distributed(args):
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ rank = int(os.environ.get('RANK', 0))
+ logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
+ ', rank {}, world_size {}'.format(rank, world_size))
+ if args.train_engine == 'torch_ddp':
+ torch.cuda.set_device(local_rank)
+ dist.init_process_group(args.dist_backend)
+ else:
+ deepspeed.init_distributed(dist_backend=args.dist_backend)
+ return world_size, local_rank, rank
+
+
+def init_dataset_and_dataloader(args, configs):
+ train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
+ cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
+
+ # do not use persistent_workers=True, as the whisper tokenizer opens the tiktoken file each time the for loop starts
+ train_data_loader = DataLoader(train_dataset,
+ batch_size=None,
+ pin_memory=args.pin_memory,
+ num_workers=args.num_workers,
+ prefetch_factor=args.prefetch)
+ cv_data_loader = DataLoader(cv_dataset,
+ batch_size=None,
+ pin_memory=args.pin_memory,
+ num_workers=args.num_workers,
+ prefetch_factor=args.prefetch)
+ return train_dataset, cv_dataset, train_data_loader, cv_data_loader
+
+
+def check_modify_and_save_config(args, configs):
+ if args.train_engine == "torch_ddp":
+ configs['train_conf']["dtype"] = 'fp32'
+ else:
+ with open(args.deepspeed_config, 'r') as fin:
+ ds_configs = json.load(fin)
+ if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
+ configs['train_conf']["dtype"] = "fp16"
+ elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
+ configs['train_conf']["dtype"] = "bf16"
+ else:
+ configs['train_conf']["dtype"] = "fp32"
+ assert ds_configs["train_micro_batch_size_per_gpu"] == 1
+ # if deepspeed is used, override the ddp config
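+ # e.g. (illustrative numbers): save_per_step=1000 and accum_grad=2 in the
+ # yaml with gradient_accumulation_steps=4 in the deepspeed json gives a
+ # save interval of 1000 * 2 / 4 = 500 steps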
+ configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
+ configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
+ configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
+ configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
+ return configs
+
+
+def wrap_cuda_model(args, model):
+ local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
+ if args.train_engine == "torch_ddp": # native pytorch ddp
+ assert (torch.cuda.is_available())
+ model.cuda()
+ model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
+ else:
+ if int(os.environ.get('RANK', 0)) == 0:
+ logging.info("Estimating model states memory needs (zero2)...")
+ estimate_zero2_model_states_mem_needs_all_live(
+ model,
+ num_gpus_per_node=local_world_size,
+ num_nodes=world_size // local_world_size)
+ return model
+
+
+def init_optimizer_and_scheduler(args, configs, model):
+ if configs['train_conf']['optim'] == 'adam':
+ optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+ elif configs['train_conf']['optim'] == 'adamw':
+ optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+ else:
+ raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+ if configs['train_conf']['scheduler'] == 'warmuplr':
+ scheduler_type = WarmupLR
+ scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+ elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+ scheduler_type = NoamHoldAnnealing
+ scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+ else:
+ raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+ # use deepspeed optimizer for speedup
+ if args.train_engine == "deepspeed":
+ def scheduler(opt):
+ return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+ model, optimizer, _, scheduler = deepspeed.initialize(
+ args=args,
+ model=model,
+ optimizer=None,
+ lr_scheduler=scheduler,
+ model_parameters=model.parameters())
+
+ return model, optimizer, scheduler
+
+
+def init_summarywriter(args):
+ writer = None
+ if int(os.environ.get('RANK', 0)) == 0:
+ os.makedirs(args.model_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+ return writer
+
+
+def save_model(model, model_name, info_dict):
+ rank = int(os.environ.get('RANK', 0))
+ model_dir = info_dict["model_dir"]
+ save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
+
+ if info_dict["train_engine"] == "torch_ddp":
+ if rank == 0:
+ torch.save(model.module.state_dict(), save_model_path)
+ else:
+ with torch.no_grad():
+ model.save_checkpoint(save_dir=model_dir,
+ tag=model_name,
+ client_state=info_dict)
+ if rank == 0:
+ info_path = re.sub('.pt$', '.yaml', save_model_path)
+ info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
+ with open(info_path, 'w') as fout:
+ data = yaml.dump(info_dict)
+ fout.write(data)
+ logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
+
+
+def cosyvoice_join(group_join, info_dict):
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ rank = int(os.environ.get('RANK', 0))
+
+ if info_dict["batch_idx"] != 0:
+ # we try to join all ranks in both ddp and deepspeed mode, in case different ranks have different lr
+ try:
+ dist.monitored_barrier(group=group_join,
+ timeout=group_join.options._timeout)
+ return False
+ except RuntimeError as e:
+ logging.info("Detected uneven workload distribution: {}\n".format(e) +
+ "Break current worker to manually join all workers, " +
+ "world_size {}, current rank {}, current local_rank {}\n".
+ format(world_size, rank, local_rank))
+ return True
+ else:
+ return False
+
+
+def batch_forward(model, batch, info_dict):
+ device = int(os.environ.get('LOCAL_RANK', 0))
+
+ dtype = info_dict["dtype"]
+ if dtype == "fp16":
+ dtype = torch.float16
+ elif dtype == "bf16":
+ dtype = torch.bfloat16
+ else: # fp32
+ dtype = torch.float32
+
+ if info_dict['train_engine'] == 'torch_ddp':
+ autocast = nullcontext()
+ else:
+ autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
+
+ with autocast:
+ info_dict['loss_dict'] = model(batch, device)
+ return info_dict
+
+
+def batch_backward(model, info_dict):
+ if info_dict["train_engine"] == "deepspeed":
+ scaled_loss = model.backward(info_dict['loss_dict']['loss'])
+ else:
+ scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
+ scaled_loss.backward()
+
+ info_dict['loss_dict']['loss'] = scaled_loss
+ return info_dict
+
+
+def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
+ grad_norm = 0.0
+ if info_dict['train_engine'] == "deepspeed":
+ info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
+ model.step()
+ grad_norm = model.get_global_grad_norm()
+ elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
+ grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+ if torch.isfinite(grad_norm):
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()
+ info_dict["lr"] = optimizer.param_groups[0]['lr']
+ info_dict["grad_norm"] = grad_norm
+ return info_dict
+
+
+def log_per_step(writer, info_dict):
+ tag = info_dict["tag"]
+ epoch = info_dict.get('epoch', 0)
+ step = info_dict["step"]
+ batch_idx = info_dict["batch_idx"]
+ loss_dict = info_dict['loss_dict']
+ rank = int(os.environ.get('RANK', 0))
+
+ # only rank 0 writes to tensorboard to avoid multi-process writes
+ if writer is not None:
+ if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
+ (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
+ for k in ['epoch', 'lr', 'grad_norm']:
+ writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
+ for k, v in loss_dict.items():
+ writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
+
+ # TRAIN & CV, Shell log (stdout)
+ if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
+ log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
+ for name, value in loss_dict.items():
+ log_str += '{} {:.6f} '.format(name, value)
+ if tag == "TRAIN":
+ log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
+ info_dict["lr"], info_dict['grad_norm'])
+ log_str += ' rank {}'.format(rank)
+ logging.debug(log_str)
+
+
+def log_per_save(writer, info_dict):
+ tag = info_dict["tag"]
+ epoch = info_dict["epoch"]
+ step = info_dict["step"]
+ loss_dict = info_dict["loss_dict"]
+ lr = info_dict['lr']
+ rank = int(os.environ.get('RANK', 0))
+ logging.info(
+ 'Epoch {} Step {} CV info lr {} {} rank {}'.format(
+ epoch, step + 1, lr, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]), rank))
+
+ if writer is not None:
+ for k in ['epoch', 'lr']:
+ writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
+ for k, v in loss_dict.items():
+ writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
diff --git a/cross_lingual_prompt.wav b/cross_lingual_prompt.wav
new file mode 100644
index 0000000000000000000000000000000000000000..35d6d4eed08b24ca379fea873c4925bdb0b58d1e
Binary files /dev/null and b/cross_lingual_prompt.wav differ
diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..10206e69a7a80d9c9bf4594d7bf18c2aaa6ab352
--- /dev/null
+++ b/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml
@@ -0,0 +1,197 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 22050
+text_encoder_input_size: 512
+llm_input_size: 1024
+llm_output_size: 1024
+spk_embed_dim: 192
+
+# model params
+# for all class/function included in this repo, we use !<name> or !<new> for initialization, so that users may find all corresponding class/function according to one single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.TransformerLM
+ text_encoder_input_size: !ref <text_encoder_input_size>
+ llm_input_size: !ref <llm_input_size>
+ llm_output_size: !ref <llm_output_size>
+ text_token_size: 51866
+ speech_token_size: 4096
+ length_normalized_loss: True
+ lsm_weight: 0
+ spk_embed_dim: !ref <spk_embed_dim>
+ text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+ input_size: !ref <text_encoder_input_size>
+ output_size: 1024
+ attention_heads: 8
+ linear_units: 2048
+ num_blocks: 3
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0
+ normalize_before: True
+ input_layer: 'linear'
+ pos_enc_layer_type: 'rel_pos_espnet'
+ selfattention_layer_type: 'rel_selfattn'
+ use_cnn_module: False
+ macaron_style: False
+ use_dynamic_chunk: False
+ use_dynamic_left_chunk: False
+ static_chunk_size: 1
+ llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
+ input_size: !ref <llm_input_size>
+ output_size: !ref <llm_output_size>
+ attention_heads: 8
+ linear_units: 2048
+ num_blocks: 7
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0
+ input_layer: 'linear_legacy'
+ pos_enc_layer_type: 'rel_pos_espnet'
+ selfattention_layer_type: 'rel_selfattn'
+ static_chunk_size: 1
+
+flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
+ input_size: 512
+ output_size: 80
+ spk_embed_dim: !ref <spk_embed_dim>
+ output_type: 'mel'
+ vocab_size: 4096
+ input_frame_rate: 50
+ only_mask_loss: True
+ encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+ output_size: 512
+ attention_heads: 8
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ normalize_before: True
+ input_layer: 'linear'
+ pos_enc_layer_type: 'rel_pos_espnet'
+ selfattention_layer_type: 'rel_selfattn'
+ input_size: 512
+ use_cnn_module: False
+ macaron_style: False
+ length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
+ channels: 80
+ sampling_ratios: [1, 1, 1, 1]
+ decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
+ in_channels: 240
+ n_spks: 1
+ spk_emb_dim: 80
+ cfm_params: !new:omegaconf.DictConfig
+ content:
+ sigma_min: 1e-06
+ solver: 'euler'
+ t_scheduler: 'cosine'
+ training_cfg_rate: 0.2
+ inference_cfg_rate: 0.7
+ reg_loss_type: 'l1'
+ estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
+ in_channels: 320
+ out_channels: 80
+ channels: [256, 256]
+ dropout: 0
+ attention_head_dim: 64
+ n_blocks: 4
+ num_mid_blocks: 12
+ num_heads: 8
+ act_fn: 'gelu'
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+ in_channels: 80
+ base_channels: 512
+ nb_harmonics: 8
+ sampling_rate: !ref <sample_rate>
+ nsf_alpha: 0.1
+ nsf_sigma: 0.003
+ nsf_voiced_threshold: 10
+ upsample_rates: [8, 8]
+ upsample_kernel_sizes: [16, 16]
+ istft_params:
+ n_fft: 16
+ hop_len: 4
+ resblock_kernel_sizes: [3, 7, 11]
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+ source_resblock_kernel_sizes: [7, 11]
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+ lrelu_slope: 0.1
+ audio_limit: 0.99
+ f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+ num_class: 1
+ in_channels: 80
+ cond_channels: 512
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:whisper.tokenizer.get_tokenizer
+ multilingual: True
+ num_languages: 100
+ language: 'en'
+ task: 'transcribe'
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+ get_tokenizer: !ref <get_tokenizer>
+ allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+ max_length: 40960
+ min_length: 0
+ token_max_length: 200
+ token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+ resample_rate: !ref <sample_rate>
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+ n_fft: 1024
+ num_mels: 80
+ sampling_rate: !ref <sample_rate>
+ hop_size: 256
+ win_size: 1024
+ fmin: 0
+ fmax: 8000
+ center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+ feat_extractor: !ref <feat_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+ normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+ shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+ sort_size: 500 # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+ batch_type: 'dynamic'
+ max_frames_in_batch: 12000
+padding: !name:cosyvoice.dataset.processor.padding
+
+# dataset processor pipeline
+data_pipeline: [
+ !ref <parquet_opener>,
+ !ref <tokenize>,
+ !ref <filter>,
+ !ref <resample>,
+ !ref <compute_fbank>,
+ !ref <parse_embedding>,
+ !ref <shuffle>,
+ !ref <sort>,
+ !ref <batch>,
+ !ref <padding>,
+]
+
+# train conf
+train_conf:
+ optim: adam
+ optim_conf:
+ lr: 0.002 # change to 0.001 if you want to train flow from scratch
+ scheduler: warmuplr
+ scheduler_conf:
+ warmup_steps: 25000
+ max_epoch: 200
+ grad_clip: 5
+ accum_grad: 2
+ log_interval: 100
+ save_per_step: -1
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cc5eee088d053314d2054cc6978e6897387692f1
--- /dev/null
+++ b/examples/libritts/cosyvoice/conf/cosyvoice.yaml
@@ -0,0 +1,197 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 22050
+text_encoder_input_size: 512
+llm_input_size: 1024
+llm_output_size: 1024
+spk_embed_dim: 192
+
+# model params
+# for all class/function included in this repo, we use !<name> or !<new> for initialization, so that users may find all corresponding class/function according to one single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.TransformerLM
+ text_encoder_input_size: !ref <text_encoder_input_size>
+ llm_input_size: !ref <llm_input_size>
+ llm_output_size: !ref <llm_output_size>
+ text_token_size: 51866
+ speech_token_size: 4096
+ length_normalized_loss: True
+ lsm_weight: 0
+ spk_embed_dim: !ref <spk_embed_dim>
+ text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+ input_size: !ref <text_encoder_input_size>
+ output_size: 1024
+ attention_heads: 16
+ linear_units: 4096
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0
+ normalize_before: True
+ input_layer: 'linear'
+ pos_enc_layer_type: 'rel_pos_espnet'
+ selfattention_layer_type: 'rel_selfattn'
+ use_cnn_module: False
+ macaron_style: False
+ use_dynamic_chunk: False
+ use_dynamic_left_chunk: False
+ static_chunk_size: 1
+ llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
+ input_size: !ref <llm_input_size>
+ output_size: !ref <llm_output_size>
+ attention_heads: 16
+ linear_units: 4096
+ num_blocks: 14
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0
+ input_layer: 'linear_legacy'
+ pos_enc_layer_type: 'rel_pos_espnet'
+ selfattention_layer_type: 'rel_selfattn'
+ static_chunk_size: 1
+
+flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
+ input_size: 512
+ output_size: 80
+ spk_embed_dim: !ref <spk_embed_dim>
+ output_type: 'mel'
+ vocab_size: 4096
+ input_frame_rate: 50
+ only_mask_loss: True
+ encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+ output_size: 512
+ attention_heads: 8
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ normalize_before: True
+ input_layer: 'linear'
+ pos_enc_layer_type: 'rel_pos_espnet'
+ selfattention_layer_type: 'rel_selfattn'
+ input_size: 512
+ use_cnn_module: False
+ macaron_style: False
+ length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
+ channels: 80
+ sampling_ratios: [1, 1, 1, 1]
+ decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
+ in_channels: 240
+ n_spks: 1
+ spk_emb_dim: 80
+ cfm_params: !new:omegaconf.DictConfig
+ content:
+ sigma_min: 1e-06
+ solver: 'euler'
+ t_scheduler: 'cosine'
+ training_cfg_rate: 0.2
+ inference_cfg_rate: 0.7
+ reg_loss_type: 'l1'
+ estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
+ in_channels: 320
+ out_channels: 80
+ channels: [256, 256]
+ dropout: 0
+ attention_head_dim: 64
+ n_blocks: 4
+ num_mid_blocks: 12
+ num_heads: 8
+ act_fn: 'gelu'
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+ in_channels: 80
+ base_channels: 512
+ nb_harmonics: 8
+ sampling_rate: !ref <sample_rate>
+ nsf_alpha: 0.1
+ nsf_sigma: 0.003
+ nsf_voiced_threshold: 10
+ upsample_rates: [8, 8]
+ upsample_kernel_sizes: [16, 16]
+ istft_params:
+ n_fft: 16
+ hop_len: 4
+ resblock_kernel_sizes: [3, 7, 11]
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+ source_resblock_kernel_sizes: [7, 11]
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+ lrelu_slope: 0.1
+ audio_limit: 0.99
+ f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+ num_class: 1
+ in_channels: 80
+ cond_channels: 512
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:whisper.tokenizer.get_tokenizer
+ multilingual: True
+ num_languages: 100
+ language: 'en'
+ task: 'transcribe'
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+ get_tokenizer: !ref <get_tokenizer>
+ allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+ max_length: 40960
+ min_length: 0
+ token_max_length: 200
+ token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+ resample_rate: !ref <sample_rate>
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+ n_fft: 1024
+ num_mels: 80
+ sampling_rate: !ref <sample_rate>
+ hop_size: 256
+ win_size: 1024
+ fmin: 0
+ fmax: 8000
+ center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+ feat_extractor: !ref <feat_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+ normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+ shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+ sort_size: 500 # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+ batch_type: 'dynamic'
+ max_frames_in_batch: 2000
+padding: !name:cosyvoice.dataset.processor.padding
+
+# dataset processor pipeline
+data_pipeline: [
+ !ref <parquet_opener>,
+ !ref <tokenize>,
+ !ref <filter>,
+ !ref <resample>,
+ !ref <compute_fbank>,
+ !ref <parse_embedding>,
+ !ref <shuffle>,
+ !ref <sort>,
+ !ref <batch>,
+ !ref <padding>,
+]
+
+# train conf
+train_conf:
+ optim: adam
+ optim_conf:
+ lr: 0.001
+ scheduler: warmuplr
+ scheduler_conf:
+ warmup_steps: 2500
+ max_epoch: 200
+ grad_clip: 5
+ accum_grad: 2
+ log_interval: 100
+ save_per_step: -1
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice/conf/ds_stage2.json b/examples/libritts/cosyvoice/conf/ds_stage2.json
new file mode 100644
index 0000000000000000000000000000000000000000..2b2de3df7fe50968269ecafbf1f4ab9eb210f322
--- /dev/null
+++ b/examples/libritts/cosyvoice/conf/ds_stage2.json
@@ -0,0 +1,42 @@
+{
+ "train_micro_batch_size_per_gpu": 1,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 100,
+ "gradient_clipping": 5,
+ "fp16": {
+ "enabled": false,
+ "auto_cast": false,
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 256,
+ "hysteresis": 2,
+ "consecutive_hysteresis": false,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": false
+ },
+ "zero_force_ds_cpu_optimizer": false,
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "allgather_partitions": true,
+ "allgather_bucket_size": 5e8,
+ "overlap_comm": false,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e8,
+ "contiguous_gradients" : true
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": 0.001,
+ "weight_decay": 0.0001,
+ "torch_adam": true,
+ "adam_w_mode": true
+ }
+ }
+}
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice/cosyvoice b/examples/libritts/cosyvoice/cosyvoice
new file mode 120000
index 0000000000000000000000000000000000000000..3903806279edbcc2f0ac36e08e8d208228037871
--- /dev/null
+++ b/examples/libritts/cosyvoice/cosyvoice
@@ -0,0 +1 @@
+../../../cosyvoice
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice/local/download_and_untar.sh b/examples/libritts/cosyvoice/local/download_and_untar.sh
new file mode 100755
index 0000000000000000000000000000000000000000..cd32fb6b989d7229272f1066a75a1688df2bf06e
--- /dev/null
+++ b/examples/libritts/cosyvoice/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] "
+ echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo " can be one of: dev-clean, test-clean, dev-other, test-other,"
+ echo " train-clean-100, train-clean-360, train-other-500."
+ exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected to be one of $list, but got '$part'"
+ exit 1
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0
+fi
+
+
+# sizes of the archive files in bytes. This is some older versions.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new is the archive file sizes of the final release. Some of these sizes are of
+# things we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+ size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tar.gz
+ else
+ echo "$data/$part.tar.gz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+ if ! which wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1
+ fi
+ full_url=$url/$part.tar.gz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ if ! wget -P $data --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1
+ fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+ echo "$0: error un-tarring archive $data/$part.tar.gz"
+ exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+ rm $data/$part.tar.gz
+fi
diff --git a/examples/libritts/cosyvoice/local/prepare_data.py b/examples/libritts/cosyvoice/local/prepare_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..92482261fd26843e4881f987e8c5ff8c66de9163
--- /dev/null
+++ b/examples/libritts/cosyvoice/local/prepare_data.py
@@ -0,0 +1,51 @@
+import argparse
+import logging
+import glob
+import os
+from tqdm import tqdm
+
+
+logger = logging.getLogger()
+
+def main():
+ wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
+
+ utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
+ for wav in tqdm(wavs):
+ txt = wav.replace('.wav', '.normalized.txt')
+ if not os.path.exists(txt):
+ logger.warning('{} does not exist'.format(txt))
+ continue
+ with open(txt) as f:
+ content = ''.join(l.replace('\n', '') for l in f.readlines())
+ utt = os.path.basename(wav).replace('.wav', '')
+ spk = utt.split('_')[0]
+ utt2wav[utt] = wav
+ utt2text[utt] = content
+ utt2spk[utt] = spk
+ if spk not in spk2utt:
+ spk2utt[spk] = []
+ spk2utt[spk].append(utt)
+
+ with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
+ for k, v in utt2wav.items():
+ f.write('{} {}\n'.format(k, v))
+ with open('{}/text'.format(args.des_dir), 'w') as f:
+ for k, v in utt2text.items():
+ f.write('{} {}\n'.format(k, v))
+ with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
+ for k, v in utt2spk.items():
+ f.write('{} {}\n'.format(k, v))
+ with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
+ for k, v in spk2utt.items():
+ f.write('{} {}\n'.format(k, ' '.join(v)))
+ return
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--src_dir',
+ type=str)
+ parser.add_argument('--des_dir',
+ type=str)
+ args = parser.parse_args()
+ main()
diff --git a/examples/libritts/cosyvoice/path.sh b/examples/libritts/cosyvoice/path.sh
new file mode 100644
index 0000000000000000000000000000000000000000..513f4ebdff883299147d347fbbef8f0397ad93f4
--- /dev/null
+++ b/examples/libritts/cosyvoice/path.sh
@@ -0,0 +1,3 @@
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../../:../../../third_party/AcademiCodec:../../../third_party/Matcha-TTS:$PYTHONPATH
diff --git a/examples/libritts/cosyvoice/run.sh b/examples/libritts/cosyvoice/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..96eca9b9026bc23945e70c99ff3c0a759887d883
--- /dev/null
+++ b/examples/libritts/cosyvoice/run.sh
@@ -0,0 +1,105 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+. ./path.sh || exit 1;
+
+stage=-1
+stop_stage=3
+
+data_url=www.openslr.org/resources/60
+data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
+pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "Data Download"
+ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+ local/download_and_untar.sh ${data_dir} ${data_url} ${part}
+ done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
+ for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+ mkdir -p data/$x
+ python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
+ for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+ tools/extract_embedding.py --dir data/$x \
+ --onnx_path $pretrained_model_dir/campplus.onnx
+ done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
+ for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+ tools/extract_speech_token.py --dir data/$x \
+ --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
+ done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
+ for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+ mkdir -p data/$x/parquet
+ tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+ --num_processes 10 \
+ --src_dir data/$x \
+ --des_dir data/$x/parquet
+ done
+fi
+
+# inference
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Run inference. Please make sure utt in tts_text is in prompt_data"
+ for mode in sft zero_shot; do
+ python cosyvoice/bin/inference.py --mode $mode \
+ --gpu 0 \
+ --config conf/cosyvoice.yaml \
+ --prompt_data data/test-clean/parquet/data.list \
+ --prompt_utt2data data/test-clean/parquet/utt2data.list \
+ --tts_text `pwd`/tts_text.json \
+ --llm_model $pretrained_model_dir/llm.pt \
+ --flow_model $pretrained_model_dir/flow.pt \
+ --hifigan_model $pretrained_model_dir/hift.pt \
+ --result_dir `pwd`/exp/cosyvoice/test-clean/$mode
+ done
+fi
+
+# train llm
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+job_id=1986
+dist_backend="nccl"
+num_workers=2
+prefetch=100
+train_engine=torch_ddp
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
+ if [ $train_engine == 'deepspeed' ]; then
+ echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
+ fi
+ cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
+ cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
+ for model in llm; do
+ torchrun --nnodes=1 --nproc_per_node=$num_gpus \
+ --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
+ cosyvoice/bin/train.py \
+ --train_engine $train_engine \
+ --config conf/cosyvoice.yaml \
+ --train_data data/train.data.list \
+ --cv_data data/dev.data.list \
+ --model $model \
+ --checkpoint $pretrained_model_dir/$model.pt \
+ --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
+ --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
+ --ddp.dist_backend $dist_backend \
+ --num_workers ${num_workers} \
+ --prefetch ${prefetch} \
+ --pin_memory \
+ --deepspeed_config ./conf/ds_stage2.json \
+ --deepspeed.save_states model+optimizer
+ done
+fi
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice/tools b/examples/libritts/cosyvoice/tools
new file mode 120000
index 0000000000000000000000000000000000000000..c92f4172dfdf1dffa4f76ab0d6bf93579b092e09
--- /dev/null
+++ b/examples/libritts/cosyvoice/tools
@@ -0,0 +1 @@
+../../../tools
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice/tts_text.json b/examples/libritts/cosyvoice/tts_text.json
new file mode 100644
index 0000000000000000000000000000000000000000..9f3e8d9f7326b7e2eb9b9cbcfc9611d8f5c8bd9d
--- /dev/null
+++ b/examples/libritts/cosyvoice/tts_text.json
@@ -0,0 +1,5 @@
+{
+ "1089_134686_000002_000000": [
+ "hello, my name is Jack. What is your name?"
+ ]
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ddd7c74103eec1ff84a271f66efa15c15bb5ebb8
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,27 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+conformer==0.3.2
+deepspeed==0.14.2
+diffusers==0.27.2
+gdown==5.1.0
+gradio==4.32.2
+grpcio==1.57.0
+grpcio-tools==1.57.0
+hydra-core==1.3.2
+HyperPyYAML==1.2.2
+inflect==6.0.2
+librosa==0.10.2
+lightning==2.2.4
+matplotlib==3.7.5
+modelscope==1.15.0
+networkx==3.1
+omegaconf==2.3.0
+onnxruntime-gpu==1.16.0
+openai-whisper==20231117
+protobuf==4.25
+pydantic==2.7.0
+rich==13.7.1
+soundfile==0.12.1
+tensorboard==2.14.0
+torch==2.0.1
+torchaudio==2.0.2
+wget==3.2
\ No newline at end of file
diff --git a/runtime/python/Dockerfile b/runtime/python/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..70bb24fc8166dd0f7e928d196b0bd6dd0fd8dda7
--- /dev/null
+++ b/runtime/python/Dockerfile
@@ -0,0 +1,15 @@
+FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /opt/CosyVoice
+
+RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
+RUN apt-get update -y
+RUN apt-get -y install python3-dev cmake python3-pip git
+# installing torch takes a long time, cache this layer in case requirements.txt changes
+# RUN git clone --depth 1 https://github.com/FunAudioLLM/CosyVoice.git
+ADD CosyVoice.tar .
+RUN mv CosyVoice_dockerfile CosyVoice
+RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+RUN cd CosyVoice/runtime/python && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
+CMD ["/bin/bash", "-c", "cd /opt/CosyVoice/CosyVoice/runtime/python && . ./path/sh && python3 server.py --port 50000 --max_conc 4 --model_dir speech_tts/CosyVoice-300M && sleep infinity"]
\ No newline at end of file
diff --git a/runtime/python/client.py b/runtime/python/client.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7d30c8d6cdcfac913e61bf3d94f967df41260c3
--- /dev/null
+++ b/runtime/python/client.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/AcademiCodec'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+import logging
+import argparse
+import torchaudio
+import cosyvoice_pb2
+import cosyvoice_pb2_grpc
+import grpc
+import torch
+import numpy as np
+from cosyvoice.utils.file_utils import load_wav
+
+
+def main():
+ with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel:
+ stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
+ request = cosyvoice_pb2.Request()
+ if args.mode == 'sft':
+ logging.info('send sft request')
+ sft_request = cosyvoice_pb2.sftRequest()
+ sft_request.spk_id = args.spk_id
+ sft_request.tts_text = args.tts_text
+ request.sft_request.CopyFrom(sft_request)
+ elif args.mode == 'zero_shot':
+ logging.info('send zero_shot request')
+ zero_shot_request = cosyvoice_pb2.zeroshotRequest()
+ zero_shot_request.tts_text = args.tts_text
+ zero_shot_request.prompt_text = args.prompt_text
+ prompt_speech = load_wav(args.prompt_wav, 16000)
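+ # load_wav returns a float waveform in [-1, 1]; scale it to 16-bit PCM bytes for the proto field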
+ zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
+ request.zero_shot_request.CopyFrom(zero_shot_request)
+ elif args.mode == 'cross_lingual':
+ logging.info('send cross_lingual request')
+ cross_lingual_request = cosyvoice_pb2.crosslingualRequest()
+ cross_lingual_request.tts_text = args.tts_text
+ prompt_speech = load_wav(args.prompt_wav, 16000)
+ cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
+ request.cross_lingual_request.CopyFrom(cross_lingual_request)
+ else:
+ logging.info('send instruct request')
+ instruct_request = cosyvoice_pb2.instructRequest()
+ instruct_request.tts_text = args.tts_text
+ instruct_request.spk_id = args.spk_id
+ instruct_request.instruct_text = args.instruct_text
+ request.instruct_request.CopyFrom(instruct_request)
+
+ response = stub.Inference(request)
+ logging.info('save response to {}'.format(args.tts_wav))
+ tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
+ torchaudio.save(args.tts_wav, tts_speech, target_sr)
+ logging.info('get response')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--host',
+ type=str,
+ default='0.0.0.0')
+ parser.add_argument('--port',
+ type=int,
+ default=50000)
+ parser.add_argument('--mode',
+ default='sft',
+ choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
+ help='request mode')
+ parser.add_argument('--tts_text',
+ type=str,
+ default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
+ parser.add_argument('--spk_id',
+ type=str,
+ default='中文女')
+ parser.add_argument('--prompt_text',
+ type=str,
+ default='希望你以后能够做的比我还好呦。')
+ parser.add_argument('--prompt_wav',
+ type=str,
+ default='../../zero_shot_prompt.wav')
+ parser.add_argument('--instruct_text',
+ type=str,
+ default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
+ parser.add_argument('--tts_wav',
+ type=str,
+ default='demo.wav')
+ args = parser.parse_args()
+ prompt_sr, target_sr = 16000, 22050
+ main()
diff --git a/runtime/python/cosyvoice.proto b/runtime/python/cosyvoice.proto
new file mode 100644
index 0000000000000000000000000000000000000000..cb4a6e0892af35e775e459b970038a4292180fc0
--- /dev/null
+++ b/runtime/python/cosyvoice.proto
@@ -0,0 +1,56 @@
+// Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+syntax = "proto3";
+
+package cosyvoice;
+option go_package = "protos/";
+
+service CosyVoice{
+ rpc Inference(Request) returns (Response) {}
+}
+
+message Request{
+ oneof RequestPayload {
+ sftRequest sft_request = 1;
+ zeroshotRequest zero_shot_request = 2;
+ crosslingualRequest cross_lingual_request = 3;
+ instructRequest instruct_request = 4;
+ }
+}
+
+message sftRequest{
+ string spk_id = 1;
+ string tts_text = 2;
+}
+
+message zeroshotRequest{
+ string tts_text = 1;
+ string prompt_text = 2;
+ bytes prompt_audio = 3;
+}
+
+message crosslingualRequest{
+ string tts_text = 1;
+ bytes prompt_audio = 2;
+}
+
+message instructRequest{
+ string tts_text = 1;
+ string spk_id = 2;
+ string instruct_text = 3;
+}
+
+message Response{
+ bytes tts_audio = 1;
+}
\ No newline at end of file
diff --git a/runtime/python/path.sh b/runtime/python/path.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f83d12bb0b51a1a249e46b0add6f7b4333d38303
--- /dev/null
+++ b/runtime/python/path.sh
@@ -0,0 +1,3 @@
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../:../../third_party/AcademiCodec:../../third_party/Matcha-TTS:$PYTHONPATH
diff --git a/runtime/python/server.py b/runtime/python/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..7641610685c874c294fad5e9fdedc24c3aaf4962
--- /dev/null
+++ b/runtime/python/server.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/AcademiCodec'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from concurrent import futures
+import argparse
+import cosyvoice_pb2
+import cosyvoice_pb2_grpc
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import grpc
+import torch
+import numpy as np
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
+ def __init__(self, args):
+ self.cosyvoice = CosyVoice(args.model_dir)
+ logging.info('grpc service initialized')
+
+ def Inference(self, request, context):
+ if request.HasField('sft_request'):
+ logging.info('get sft inference request')
+ model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
+ elif request.HasField('zero_shot_request'):
+ logging.info('get zero_shot inference request')
+ prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
+ prompt_speech_16k = prompt_speech_16k.float() / (2**15)
+ model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
+ elif request.HasField('cross_lingual_request'):
+ logging.info('get cross_lingual inference request')
+ prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
+ prompt_speech_16k = prompt_speech_16k.float() / (2**15)
+ model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
+ else:
+ logging.info('get instruct inference request')
+ model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
+
+ logging.info('send inference response')
+ response = cosyvoice_pb2.Response()
+ response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
+ return response
+
+def main():
+ grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
+ cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
+ grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
+ grpcServer.start()
+ logging.info("server listening on 0.0.0.0:{}".format(args.port))
+ grpcServer.wait_for_termination()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--port',
+ type=int,
+ default=50000)
+ parser.add_argument('--max_conc',
+ type=int,
+ default=4)
+ parser.add_argument('--model_dir',
+ type=str,
+ required=True,
+ default='speech_tts/CosyVoice-300M',
+ help='local path or modelscope repo id')
+ args = parser.parse_args()
+ main()
diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py
new file mode 100755
index 0000000000000000000000000000000000000000..02fa2f6286309d47b8cfb2335ac463b1dbc49ef0
--- /dev/null
+++ b/tools/extract_embedding.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import torch
+import torchaudio
+from tqdm import tqdm
+import onnxruntime
+import torchaudio.compliance.kaldi as kaldi
+
+
+def main(args):
+ utt2wav, utt2spk = {}, {}
+ with open('{}/wav.scp'.format(args.dir)) as f:
+ for l in f:
+ l = l.replace('\n', '').split()
+ utt2wav[l[0]] = l[1]
+ with open('{}/utt2spk'.format(args.dir)) as f:
+ for l in f:
+ l = l.replace('\n', '').split()
+ utt2spk[l[0]] = l[1]
+
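+    # single-threaded CPU ONNX session for the speaker-embedding model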
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ["CPUExecutionProvider"]
+ ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+
+ utt2embedding, spk2embedding = {}, {}
+ for utt in tqdm(utt2wav.keys()):
+ audio, sample_rate = torchaudio.load(utt2wav[utt])
+ if sample_rate != 16000:
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
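+        # 80-dim fbank features (no dither), mean-normalized per utterance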
+ feat = kaldi.fbank(audio,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=16000)
+ feat = feat - feat.mean(dim=0, keepdim=True)
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+ utt2embedding[utt] = embedding
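+        # also collect each utterance embedding under its speaker id (saved as spk2embedding.pt)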
+ spk = utt2spk[utt]
+ if spk not in spk2embedding:
+ spk2embedding[spk] = []
+ spk2embedding[spk].append(embedding)
+
+ torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
+ torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
+
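+# Example invocation (paths are illustrative):
+#   python tools/extract_embedding.py --dir data/train --onnx_path /path/to/speaker_embedding.onnx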
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dir',
+ type=str)
+ parser.add_argument('--onnx_path',
+ type=str)
+ args = parser.parse_args()
+ main(args)
diff --git a/tools/extract_speech_token.py b/tools/extract_speech_token.py
new file mode 100755
index 0000000000000000000000000000000000000000..fac0b0bc18e273920147be3a4c525b6249e1f21b
--- /dev/null
+++ b/tools/extract_speech_token.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+import torch
+from tqdm import tqdm
+import onnxruntime
+import numpy as np
+import torchaudio
+import whisper
+
+
+def main(args):
+ utt2wav = {}
+ with open('{}/wav.scp'.format(args.dir)) as f:
+ for l in f:
+ l = l.replace('\n', '').split()
+ utt2wav[l[0]] = l[1]
+
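+    # the speech tokenizer ONNX model runs on GPU via CUDAExecutionProvider, one intra-op thread per process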
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ["CUDAExecutionProvider"]
+ ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+
+ utt2speech_token = {}
+ for utt in tqdm(utt2wav.keys()):
+ audio, sample_rate = torchaudio.load(utt2wav[utt])
+ if sample_rate != 16000:
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+ if audio.shape[1] / 16000 > 30:
+            logging.warning('{} is longer than 30s; speech token extraction does not support this, saving an empty token list'.format(utt))
+ speech_token = []
+ else:
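+            # 128-bin whisper log-mel features; the second model input is the number of mel frames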
+ feat = whisper.log_mel_spectrogram(audio, n_mels=128)
+ speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+ ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+ utt2speech_token[utt] = speech_token
+ torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
+
+
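+# Example invocation (paths are illustrative):
+#   python tools/extract_speech_token.py --dir data/train --onnx_path /path/to/speech_tokenizer.onnx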
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dir',
+ type=str)
+ parser.add_argument('--onnx_path',
+ type=str)
+ args = parser.parse_args()
+ main(args)
diff --git a/tools/make_parquet_list.py b/tools/make_parquet_list.py
new file mode 100755
index 0000000000000000000000000000000000000000..399d9a6b57bb24579c674bc2f27f7688c867c1a5
--- /dev/null
+++ b/tools/make_parquet_list.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+import os
+import json
+from tqdm import tqdm
+import pandas as pd
+import multiprocessing
+import time
+import torch
+
+
+def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
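+    # pack one shard of utterances (audio bytes, text, speaker, embeddings, speech tokens) into a parquet file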
+ start_time = time.time()
+ data_list = []
+ for utt in tqdm(utt_list):
+        with open(utt2wav[utt], 'rb') as fin:
+            data_list.append(fin.read())
+ wav_list = [utt2wav[utt] for utt in utt_list]
+ text_list = [utt2text[utt] for utt in utt_list]
+ spk_list = [utt2spk[utt] for utt in utt_list]
+ uttembedding_list = [utt2embedding[utt] for utt in utt_list]
+ spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
+ speech_token_list = [utt2speech_token[utt] for utt in utt_list]
+
+    # write one parquet shard plus the utt->parquet and spk->parquet json mappings
+ df = pd.DataFrame()
+ df['utt'] = utt_list
+ df['wav'] = wav_list
+ df['audio_data'] = data_list
+ df['text'] = text_list
+ df['spk'] = spk_list
+ df['utt_embedding'] = uttembedding_list
+ df['spk_embedding'] = spkembedding_list
+ df['speech_token'] = speech_token_list
+ df.to_parquet(parquet_file)
+ with open(utt2parquet_file, 'w') as f:
+ json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
+ with open(spk2parquet_file, 'w') as f:
+ json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
+ logging.info('spend time {}'.format(time.time() - start_time))
+
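+# Example invocation (paths are illustrative):
+#   python tools/make_parquet_list.py --num_utts_per_parquet 1000 --num_processes 8 --src_dir data/train --des_dir data/train/parquet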
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_utts_per_parquet',
+ type=int,
+ default=1000,
+ help='num utts per parquet')
+ parser.add_argument('--num_processes',
+ type=int,
+ default=1,
+ help='num processes for make parquets')
+ parser.add_argument('--src_dir',
+ type=str)
+ parser.add_argument('--des_dir',
+ type=str)
+ args = parser.parse_args()
+
+ utt2wav, utt2text, utt2spk = {}, {}, {}
+ with open('{}/wav.scp'.format(args.src_dir)) as f:
+ for l in f:
+ l = l.replace('\n', '').split()
+ utt2wav[l[0]] = l[1]
+ with open('{}/text'.format(args.src_dir)) as f:
+ for l in f:
+ l = l.replace('\n', '').split()
+ utt2text[l[0]] = ' '.join(l[1:])
+ with open('{}/utt2spk'.format(args.src_dir)) as f:
+ for l in f:
+ l = l.replace('\n', '').split()
+ utt2spk[l[0]] = l[1]
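+    # embeddings and speech tokens produced by tools/extract_embedding.py and tools/extract_speech_token.py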
+ utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
+ spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
+ utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
+ utts = list(utt2wav.keys())
+
+    # build shards in parallel; each task packs --num_utts_per_parquet utterances into one parquet file
+ pool = multiprocessing.Pool(processes=args.num_processes)
+ parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
+ for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
+ parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
+ utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
+ spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
+ parquet_list.append(parquet_file)
+ utt2parquet_list.append(utt2parquet_file)
+ spk2parquet_list.append(spk2parquet_file)
+ pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
+ pool.close()
+ pool.join()
+
+ with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
+ open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
+ open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
+ for name in parquet_list:
+ f1.write(name + '\n')
+ for name in utt2parquet_list:
+ f2.write(name + '\n')
+ for name in spk2parquet_list:
+ f3.write(name + '\n')
diff --git a/webui.py b/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cbc16d87f8a011b17039dcaff2fa13acd56005a
--- /dev/null
+++ b/webui.py
@@ -0,0 +1,186 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/third_party/AcademiCodec'.format(ROOT_DIR))
+sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
+
+import argparse
+import gradio as gr
+import numpy as np
+import torch
+import torchaudio
+import random
+import librosa
+
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+
+from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.utils.file_utils import load_wav
+
+logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+def generate_seed():
+ seed = random.randint(1, 100000000)
+ return {
+ "__type__": "update",
+ "value": seed
+ }
+
+def set_all_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+max_val = 0.8
+def postprocess(speech, top_db=60, hop_length=220, win_length=440):
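+    # trim leading/trailing silence, peak-normalize to max_val, then pad 0.2s of silence at the end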
+ speech, _ = librosa.effects.trim(
+ speech, top_db=top_db,
+ frame_length=win_length,
+ hop_length=hop_length
+ )
+ if speech.abs().max() > max_val:
+ speech = speech / speech.abs().max() * max_val
+ speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
+ return speech
+
+inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
+instruct_dict = {'预训练音色': '1. 选择预训练音色\n2.点击生成音频按钮',
+ '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3.点击生成音频按钮',
+ '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,若同时提供,优先选择prompt音频文件\n2.点击生成音频按钮',
+ '自然语言控制': '1. 输入instruct文本\n2.点击生成音频按钮'}
+def change_instruction(mode_checkbox_group):
+ return instruct_dict[mode_checkbox_group]
+
+def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed):
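+    # returns (sample_rate, waveform) for gr.Audio; on invalid input a warning is shown and silence is returned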
+ if prompt_wav_upload is not None:
+ prompt_wav = prompt_wav_upload
+ elif prompt_wav_record is not None:
+ prompt_wav = prompt_wav_record
+ else:
+ prompt_wav = None
+    # instruct mode requires the speech_tts/CosyVoice-300M-Instruct model; instruct_text is mandatory and any prompt audio/text is ignored
+ if mode_checkbox_group in ['自然语言控制']:
+ if cosyvoice.frontend.instruct is False:
+ gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用speech_tts/CosyVoice-300M-Instruct模型'.format(args.model_dir))
+ return (target_sr, default_data)
+ if instruct_text == '':
+ gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
+ return (target_sr, default_data)
+ if prompt_wav is not None or prompt_text != '':
+ gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
+    # cross_lingual mode requires the speech_tts/CosyVoice-300M model; tts_text and the prompt should be in different languages
+ if mode_checkbox_group in ['跨语种复刻']:
+ if cosyvoice.frontend.instruct is True:
+ gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用speech_tts/CosyVoice-300M模型'.format(args.model_dir))
+ return (target_sr, default_data)
+ if instruct_text != '':
+ gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
+ if prompt_wav is None:
+ gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
+ return (target_sr, default_data)
+ gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
+    # zero_shot and cross_lingual modes need a prompt_wav sampled at no less than prompt_sr
+ if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
+ if prompt_wav is None:
+ gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
+ return (target_sr, default_data)
+ if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
+ gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
+ return (target_sr, default_data)
+    # sft mode uses only the sft_dropdown voice; prompt and instruct inputs are ignored
+ if mode_checkbox_group in ['预训练音色']:
+ if instruct_text != '' or prompt_wav is not None or prompt_text != '':
+ gr.Info('您正在使用预训练音色模式,prompt文本/prompt音频/instruct文本会被忽略!')
+    # zero_shot mode uses only prompt_wav and prompt_text
+ if mode_checkbox_group in ['3s极速复刻']:
+ if prompt_text == '':
+ gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
+ return (target_sr, default_data)
+ if instruct_text != '':
+ gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
+
+ if mode_checkbox_group == '预训练音色':
+ logging.info('get sft inference request')
+ set_all_random_seed(seed)
+ output = cosyvoice.inference_sft(tts_text, sft_dropdown)
+ elif mode_checkbox_group == '3s极速复刻':
+ logging.info('get zero_shot inference request')
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
+ set_all_random_seed(seed)
+ output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
+ elif mode_checkbox_group == '跨语种复刻':
+ logging.info('get cross_lingual inference request')
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
+ set_all_random_seed(seed)
+ output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
+ else:
+ logging.info('get instruct inference request')
+ set_all_random_seed(seed)
+ output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
+ audio_data = output['tts_speech'].numpy().flatten()
+ return (target_sr, audio_data)
+
+def main():
+ with gr.Blocks() as demo:
+ gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/speech_tts/CosyVoice-300M) [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/speech_tts/CosyVoice-300M-Instruct) [CosyVoice-300M-SFT](https://www.modelscope.cn/models/speech_tts/CosyVoice-300M-SFT)")
+ gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
+
+ tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
+
+ with gr.Row():
+ mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
+ instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
+ sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
+ with gr.Column(scale=0.25):
+ seed_button = gr.Button(value="\U0001F3B2")
+ seed = gr.Number(value=0, label="随机推理种子")
+
+ with gr.Row():
+ prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
+ prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
+ prompt_text = gr.Textbox(label="输入prompt文本", lines=1, placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...", value='')
+ instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.", value='')
+
+ generate_button = gr.Button("生成音频")
+
+ audio_output = gr.Audio(label="合成音频")
+
+ seed_button.click(generate_seed, inputs=[], outputs=seed)
+ generate_button.click(generate_audio,
+ inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed],
+ outputs=[audio_output])
+ mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
+ demo.queue(max_size=4, default_concurrency_limit=2)
+ demo.launch(server_port=args.port)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--port',
+ type=int,
+ default=8000)
+ parser.add_argument('--model_dir',
+ type=str,
+ default='speech_tts/CosyVoice-300M',
+ help='local path or modelscope repo id')
+ args = parser.parse_args()
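+    # load the model once at startup and share it across gradio requests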
+ cosyvoice = CosyVoice(args.model_dir)
+ sft_spk = cosyvoice.list_avaliable_spks()
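+    # prompt audio is consumed at 16 kHz; synthesized audio is returned at 22.05 kHz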
+ prompt_sr, target_sr = 16000, 22050
+ default_data = np.zeros(target_sr)
+ main()
diff --git a/zero_shot_prompt.wav b/zero_shot_prompt.wav
new file mode 100644
index 0000000000000000000000000000000000000000..25fbf592f2a5efa3d966ff64084c75a43be4cc1e
Binary files /dev/null and b/zero_shot_prompt.wav differ