diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..e464d502ddae5dc6ed7e1c13176a3f90fa8c4738 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,7 @@ +.git +.github +results +data +*.filelist +/data_server/target +checkpoints diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..131480d9f9a104a38341852aabdedc351431cc58 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,50 @@ +name: "🕷️ Bug report" +description: Report errors or unexpected behavior +labels: + - bug +body: + - type: checkboxes + attributes: + label: Self Checks + description: "To make sure we get to you in time, please check the following :)" + options: + - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/fishaudio/fish-speech/discussions). + required: true + - label: I have searched for existing issues [search for existing issues](https://github.com/fishaudio/fish-speech/issues), including closed ones. + required: true + - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)). + required: true + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + required: true + - label: "Please do not modify this template :) and fill in all the required fields." + required: true + - type: dropdown + attributes: + label: Cloud or Self Hosted + multiple: true + options: + - Cloud + - Self Hosted (Docker) + - Self Hosted (Source) + validations: + required: true + - type: textarea + attributes: + label: Steps to reproduce + description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks. + placeholder: Having detailed steps helps us reproduce the bug. + validations: + required: true + - type: textarea + attributes: + label: ✔️ Expected Behavior + placeholder: What were you expecting? + validations: + required: false + + - type: textarea + attributes: + label: ❌ Actual Behavior + placeholder: What happened instead? + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..886c238bf64f361460b5eba3906d94de73b092f6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: "\U0001F4E7 Discussions" + url: https://github.com/fishaudio/fish-speech/discussions + about: General discussions and request help from the community diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..4f4eee9090384d28018b842736dc5be246ce7d41 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,40 @@ +name: "⭐ Feature or enhancement request" +description: Propose something new. +labels: + - enhancement +body: + - type: checkboxes + attributes: + label: Self Checks + description: "To make sure we get to you in time, please check the following :)" + options: + - label: I have searched for existing issues [search for existing issues]([https://github.com/langgenius/dify/issues](https://github.com/fishaudio/fish-speech/issues)), including closed ones. + required: true + - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)). + required: true + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + required: true + - label: "Please do not modify this template :) and fill in all the required fields." + required: true + - type: textarea + attributes: + label: 1. Is this request related to a challenge you're experiencing? Tell me about your story. + placeholder: Please describe the specific scenario or problem you're facing as clearly as possible. For instance "I was trying to use [feature] for [specific task], and [what happened]... It was frustrating because...." + validations: + required: true + - type: textarea + attributes: + label: 2. Additional context or comments + placeholder: (Any other information, comments, documentations, links, or screenshots that would provide more clarity. This is the place to add anything else not covered above.) + validations: + required: false + - type: checkboxes + attributes: + label: 3. Can you help us with this feature? + description: Let us know! This is not a commitment, but a starting point for collaboration. + options: + - label: I am interested in contributing to this feature. + required: false + - type: markdown + attributes: + value: Please limit one request per issue. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000000000000000000000000000000000..91c6c22a1806d27511bd8cfc8c31cb3deb4379aa --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,7 @@ +**Is this PR adding new feature or fix a BUG?** + +Add feature / Fix BUG. + +**Is this pull request related to any issue? If yes, please link the issue.** + +#xxx diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml new file mode 100644 index 0000000000000000000000000000000000000000..6a2b7eff210e5c4593d0130633b539f6c7279ac0 --- /dev/null +++ b/.github/workflows/build-docker-image.yml @@ -0,0 +1,70 @@ +name: Build Image + +on: + push: + branches: + - main + tags: + - 'v*' + +jobs: + build: + runs-on: ubuntu-latest-16c64g + steps: + - uses: actions/checkout@v4 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Get Version + run: | + if [[ $GITHUB_REF == refs/tags/v* ]]; then + version=$(basename ${GITHUB_REF}) + else + version=nightly + fi + + echo "version=${version}" >> $GITHUB_ENV + echo "Current version: ${version}" + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_PAT }} + + - name: Build and Push Image + uses: docker/build-push-action@v6 + with: + context: . + file: dockerfile + platforms: linux/amd64 + push: true + tags: | + fishaudio/fish-speech:${{ env.version }} + fishaudio/fish-speech:latest + outputs: type=image,oci-mediatypes=true,compression=zstd,compression-level=3,force-compression=true + cache-from: type=registry,ref=fishaudio/fish-speech:latest + cache-to: type=inline + + - name: Build and Push Dev Image + uses: docker/build-push-action@v6 + with: + context: . + file: dockerfile.dev + platforms: linux/amd64 + push: true + build-args: | + VERSION=${{ env.version }} + BASE_IMAGE=fishaudio/fish-speech:${{ env.version }} + tags: | + fishaudio/fish-speech:${{ env.version }}-dev + fishaudio/fish-speech:latest-dev + outputs: type=image,oci-mediatypes=true,compression=zstd,compression-level=3,force-compression=true + cache-from: type=registry,ref=fishaudio/fish-speech:latest-dev + cache-to: type=inline + + - name: Push README to Dockerhub + uses: peter-evans/dockerhub-description@v4 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_PAT }} + repository: fishaudio/fish-speech diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000000000000000000000000000000000000..0967ec0b00c3e1c0392f1b54d9d6200b18c00b46 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,33 @@ +name: docs +on: + push: + branches: + - main + paths: + - 'docs/**' + - 'mkdocs.yml' + +permissions: + contents: write + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Configure Git Credentials + run: | + git config user.name github-actions[bot] + git config user.email 41898282+github-actions[bot]@users.noreply.github.com + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV + - uses: actions/cache@v4 + with: + key: mkdocs-material-${{ env.cache_id }} + path: .cache + restore-keys: | + mkdocs-material- + - run: pip install -r docs/requirements.txt + - run: mkdocs gh-deploy --force diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000000000000000000000000000000000..47f4405c3309162333698f352d8012e2eee5d48c --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,25 @@ +name: Close inactive issues +on: + schedule: + - cron: "0 0 * * *" + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v9 + with: + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." + close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: 30 + days-before-pr-close: 30 + stale-pr-label: "stale" + stale-pr-message: "This PR is stale because it has been open for 30 days with no activity." + close-pr-message: "This PR was closed because it has been inactive for 30 days since being marked as stale." + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2acb6c2e2d0c64f15a3475c5f266268dd83dc05f --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +.DS_Store +.pgx.* +.pdm-python +/fish_speech.egg-info +__pycache__ +/results +/data +/*.test.sh +*.filelist +filelists +/fish_speech/text/cmudict_cache.pickle +/checkpoints +/.vscode +/data_server/target +/*.npy +/*.wav +/*.mp3 +/*.lab +/results +/data +/.idea +ffmpeg.exe +ffprobe.exe +asr-label* +/.cache +/fishenv +/.locale +/demo-audios +/references +/example +/faster_whisper diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3965156b613d25bf447c0f0370d8d77c45766f41 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +ci: + autoupdate_schedule: monthly + +repos: + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: [--profile=black] + + - repo: https://github.com/psf/black + rev: 24.8.0 + hooks: + - id: black + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: end-of-file-fixer + - id: check-yaml + - id: check-json + - id: mixed-line-ending + args: ['--fix=lf'] + - id: check-added-large-files + args: ['--maxkb=5000'] diff --git a/.project-root b/.project-root new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..015eb5de8569951255b2d66c251ee20fe9153ace --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,19 @@ +# Read the Docs configuration file for MkDocs projects +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.12" + +mkdocs: + configuration: mkdocs.yml + +# Optionally declare the Python requirements required to build your docs +python: + install: + - requirements: docs/requirements.txt diff --git a/API_FLAGS.txt b/API_FLAGS.txt new file mode 100644 index 0000000000000000000000000000000000000000..d0be18de55d30dba96e259da553c5b2cb830fa7b --- /dev/null +++ b/API_FLAGS.txt @@ -0,0 +1,6 @@ +# --infer +# --api +--listen 0.0.0.0:8080 \ +--llama-checkpoint-path "checkpoints/fish-speech-1.4" \ +--decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ +--decoder-config-name firefly_gan_vq diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..067187a7ba0907a789b6efc234f616a3785ba1fd --- /dev/null +++ b/Dockerfile @@ -0,0 +1,44 @@ +FROM python:3.12-slim-bookworm AS stage-1 +ARG TARGETARCH + +ARG HUGGINGFACE_MODEL=fish-speech-1.4 +ARG HF_ENDPOINT=https://huggingface.co + +WORKDIR /opt/fish-speech + +RUN set -ex \ + && pip install huggingface_hub \ + && HF_ENDPOINT=${HF_ENDPOINT} huggingface-cli download --resume-download fishaudio/${HUGGINGFACE_MODEL} --local-dir checkpoints/${HUGGINGFACE_MODEL} + +FROM python:3.12-slim-bookworm +ARG TARGETARCH + +ARG DEPENDENCIES=" \ + ca-certificates \ + libsox-dev \ + ffmpeg" + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + set -ex \ + && rm -f /etc/apt/apt.conf.d/docker-clean \ + && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \ + && apt-get update \ + && apt-get -y install --no-install-recommends ${DEPENDENCIES} \ + && echo "no" | dpkg-reconfigure dash + +WORKDIR /opt/fish-speech + +COPY . . + +RUN --mount=type=cache,target=/root/.cache,sharing=locked \ + set -ex \ + && pip install -e .[stable] + +COPY --from=stage-1 /opt/fish-speech/checkpoints /opt/fish-speech/checkpoints + +ENV GRADIO_SERVER_NAME="0.0.0.0" + +EXPOSE 7860 + +CMD ["./entrypoint.sh"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..cbe5ad1670406e4402217edfb82d2c56af7e8631 --- /dev/null +++ b/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml new file mode 100644 index 0000000000000000000000000000000000000000..3054037de5fd4931b22be279d5c8d505be950519 --- /dev/null +++ b/docker-compose.dev.yml @@ -0,0 +1,16 @@ +version: '3.8' + +services: + fish-speech: + build: . + container_name: fish-speech + volumes: + - ./:/exp + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + command: tail -f /dev/null diff --git a/dockerfile.dev b/dockerfile.dev new file mode 100644 index 0000000000000000000000000000000000000000..2d07296ed428bb638ae2e93bf060d500ed8cdecb --- /dev/null +++ b/dockerfile.dev @@ -0,0 +1,33 @@ +ARG VERSION=dev +ARG BASE_IMAGE=ghcr.io/fishaudio/fish-speech:${VERSION} + +FROM ${BASE_IMAGE} + +ARG TOOLS=" \ + git \ + curl \ + build-essential \ + ffmpeg \ + libsm6 \ + libxext6 \ + libjpeg-dev \ + zlib1g-dev \ + aria2 \ + zsh \ + openssh-server \ + sudo \ + protobuf-compiler \ + cmake" + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + set -ex \ + && apt-get update \ + && apt-get -y install --no-install-recommends ${TOOLS} + +# Install oh-my-zsh so your terminal looks nice +RUN sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended + +# Set zsh as default shell +RUN chsh -s /usr/bin/zsh +ENV SHELL=/usr/bin/zsh diff --git a/docs/CNAME b/docs/CNAME new file mode 100644 index 0000000000000000000000000000000000000000..d506fb8b394fa80f3d329ab8450dfc102e839bd1 --- /dev/null +++ b/docs/CNAME @@ -0,0 +1 @@ +speech.fish.audio diff --git a/docs/assets/figs/VS_1.jpg b/docs/assets/figs/VS_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..41a3f69992edcbbaa85a21695bdc33ff81dc10d6 Binary files /dev/null and b/docs/assets/figs/VS_1.jpg differ diff --git a/docs/assets/figs/VS_1_pt-BR.png b/docs/assets/figs/VS_1_pt-BR.png new file mode 100644 index 0000000000000000000000000000000000000000..d7cf5c85cb1cf98d9c716d03575eb0c74d53d572 Binary files /dev/null and b/docs/assets/figs/VS_1_pt-BR.png differ diff --git a/docs/assets/figs/diagram.png b/docs/assets/figs/diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..254b669c293428926e8d28d47471536d6eb76357 Binary files /dev/null and b/docs/assets/figs/diagram.png differ diff --git a/docs/assets/figs/diagrama.png b/docs/assets/figs/diagrama.png new file mode 100644 index 0000000000000000000000000000000000000000..140f926ad9dc3e3e494872f1ca7b7e3f24994c3b Binary files /dev/null and b/docs/assets/figs/diagrama.png differ diff --git a/docs/en/finetune.md b/docs/en/finetune.md new file mode 100644 index 0000000000000000000000000000000000000000..72cc6effb67e14cc24eec0db979b58d2ca664776 --- /dev/null +++ b/docs/en/finetune.md @@ -0,0 +1,125 @@ +# Fine-tuning + +Obviously, when you opened this page, you were not satisfied with the performance of the few-shot pre-trained model. You want to fine-tune a model to improve its performance on your dataset. + +In current version, you only need to finetune the 'LLAMA' part. + +## Fine-tuning LLAMA +### 1. Prepare the dataset + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 30.1-32.71.lab +│ └── 30.1-32.71.mp3 +└── SPK2 + ├── 38.79-40.85.lab + └── 38.79-40.85.mp3 +``` + +You need to convert your dataset into the above format and place it under `data`. The audio file can have the extensions `.mp3`, `.wav`, or `.flac`, and the annotation file should have the extensions `.lab`. + +!!! warning + It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this. + + ```bash + fap loudness-norm data-raw data --clean + ``` + + +### 2. Batch extraction of semantic tokens + +Make sure you have downloaded the VQGAN weights. If not, run the following command: + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +You can then run the following command to extract semantic tokens: + +```bash +python tools/vqgan/extract_vq.py data \ + --num-workers 1 --batch-size 16 \ + --config-name "firefly_gan_vq" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +!!! note + You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit. + For the VITS format, you can specify a file list using `--filelist xxx.list`. + +This command will create `.npy` files in the `data` directory, as shown below: + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 21.15-26.44.npy +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 27.51-29.98.npy +│ ├── 30.1-32.71.lab +│ ├── 30.1-32.71.mp3 +│ └── 30.1-32.71.npy +└── SPK2 + ├── 38.79-40.85.lab + ├── 38.79-40.85.mp3 + └── 38.79-40.85.npy +``` + +### 3. Pack the dataset into protobuf + +```bash +python tools/llama/build_dataset.py \ + --input "data" \ + --output "data/protos" \ + --text-extension .lab \ + --num-workers 16 +``` + +After the command finishes executing, you should see the `quantized-dataset-ft.protos` file in the `data` directory. + +### 4. Finally, fine-tuning with LoRA + +Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command: + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +Finally, you can start the fine-tuning by running the following command: + +```bash +python fish_speech/train.py --config-name text2semantic_finetune \ + project=$project \ + +lora@model.model.lora_config=r_8_alpha_16 +``` + +!!! note + You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`. + +!!! note + For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues. + +After training is complete, you can refer to the [inference](inference.md) section, and use `--speaker SPK1` to generate speech. + +!!! info + By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability. + If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting. + +After training, you need to convert the LoRA weights to regular weights before performing inference. + +```bash +python tools/llama/merge_lora.py \ + --lora-config r_8_alpha_16 \ + --base-weight checkpoints/fish-speech-1.4 \ + --lora-weight results/$project/checkpoints/step_000000010.ckpt \ + --output checkpoints/fish-speech-1.4-yth-lora/ +``` +!!! note + You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data. diff --git a/docs/en/index.md b/docs/en/index.md new file mode 100644 index 0000000000000000000000000000000000000000..e86c8853b942e8d578017b4e37ba85683d34d4be --- /dev/null +++ b/docs/en/index.md @@ -0,0 +1,133 @@ +# Introduction + +
+ +Discord + + +QQ + + +Docker + +
+ +!!! warning + We assume no responsibility for any illegal use of the codebase. Please refer to the local laws regarding DMCA (Digital Millennium Copyright Act) and other relevant laws in your area.
+ This codebase and all models are released under the CC-BY-NC-SA-4.0 license. + +

+ +

+ +## Requirements + +- GPU Memory: 4GB (for inference), 8GB (for fine-tuning) +- System: Linux, Windows + +## Windows Setup + +Professional Windows users may consider using WSL2 or Docker to run the codebase. + +```bash +# Create a python 3.10 virtual environment, you can also use virtualenv +conda create -n fish-speech python=3.10 +conda activate fish-speech + +# Install pytorch +pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 + +# Install fish-speech +pip3 install -e . + +# (Enable acceleration) Install triton-windows +pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl +``` + +Non-professional Windows users can consider the following basic methods to run the project without a Linux environment (with model compilation capabilities, i.e., `torch.compile`): + +1. Extract the project package. +2. Click `install_env.bat` to install the environment. +3. If you want to enable compilation acceleration, follow this step: + 1. Download the LLVM compiler from the following links: + - [LLVM-17.0.6 (Official Site Download)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) + - [LLVM-17.0.6 (Mirror Site Download)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) + - After downloading `LLVM-17.0.6-win64.exe`, double-click to install, select an appropriate installation location, and most importantly, check the `Add Path to Current User` option to add the environment variable. + - Confirm that the installation is complete. + 2. Download and install the Microsoft Visual C++ Redistributable to solve potential .dll missing issues: + - [MSVC++ 14.40.33810.0 Download](https://aka.ms/vs/17/release/vc_redist.x64.exe) + 3. Download and install Visual Studio Community Edition to get MSVC++ build tools and resolve LLVM's header file dependencies: + - [Visual Studio Download](https://visualstudio.microsoft.com/zh-hans/downloads/) + - After installing Visual Studio Installer, download Visual Studio Community 2022. + - As shown below, click the `Modify` button and find the `Desktop development with C++` option to select and download. + 4. Download and install [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64) +4. Double-click `start.bat` to open the training inference WebUI management interface. If needed, you can modify the `API_FLAGS` as prompted below. + +!!! info "Optional" + + Want to start the inference WebUI? + + Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows: + ``` + --infer + # --api + # --listen ... + ... + ``` + +!!! info "Optional" + + Want to start the API server? + + Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows: + + ``` + # --infer + --api + --listen ... + ... + ``` + +!!! info "Optional" + + Double-click `run_cmd.bat` to enter the conda/python command line environment of this project. + +## Linux Setup + +```bash +# Create a python 3.10 virtual environment, you can also use virtualenv +conda create -n fish-speech python=3.10 +conda activate fish-speech + +# Install pytorch +pip3 install torch torchvision torchaudio + +# Install fish-speech +pip3 install -e .[stable] + +# (Ubuntu / Debian User) Install sox +apt install libsox-dev +``` + +## Changelog + +- 2024/09/10: Updated Fish-Speech to 1.4 version, with an increase in dataset size and a change in the quantizer's n_groups from 4 to 8. +- 2024/07/02: Updated Fish-Speech to 1.2 version, remove VITS Decoder, and greatly enhanced zero-shot ability. +- 2024/05/10: Updated Fish-Speech to 1.1 version, implement VITS decoder to reduce WER and improve timbre similarity. +- 2024/04/22: Finished Fish-Speech 1.0 version, significantly modified VQGAN and LLAMA models. +- 2023/12/28: Added `lora` fine-tuning support. +- 2023/12/27: Add `gradient checkpointing`, `causual sampling`, and `flash-attn` support. +- 2023/12/19: Updated webui and HTTP API. +- 2023/12/18: Updated fine-tuning documentation and related examples. +- 2023/12/17: Updated `text2semantic` model, supporting phoneme-free mode. +- 2023/12/13: Beta version released, includes VQGAN model and a language model based on LLAMA (phoneme support only). + +## Acknowledgements + +- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) +- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) +- [GPT VITS](https://github.com/innnky/gpt-vits) +- [MQTTS](https://github.com/b04901014/MQTTS) +- [GPT Fast](https://github.com/pytorch-labs/gpt-fast) +- [Transformers](https://github.com/huggingface/transformers) +- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) diff --git a/docs/en/inference.md b/docs/en/inference.md new file mode 100644 index 0000000000000000000000000000000000000000..1eb3042f53f191d954a7e92b0cd1399ccb81f26c --- /dev/null +++ b/docs/en/inference.md @@ -0,0 +1,124 @@ +# Inference + +Inference support command line, HTTP API and web UI. + +!!! note + Overall, reasoning consists of several parts: + + 1. Encode a given ~10 seconds of voice using VQGAN. + 2. Input the encoded semantic tokens and the corresponding text into the language model as an example. + 3. Given a new piece of text, let the model generate the corresponding semantic tokens. + 4. Input the generated semantic tokens into VITS / VQGAN to decode and generate the corresponding voice. + +## Command Line Inference + +Download the required `vqgan` and `llama` models from our Hugging Face repository. + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +### 1. Generate prompt from voice: + +!!! note + If you plan to let the model randomly choose a voice timbre, you can skip this step. + +```bash +python tools/vqgan/inference.py \ + -i "paimon.wav" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +You should get a `fake.npy` file. + +### 2. Generate semantic tokens from text: + +```bash +python tools/llama/generate.py \ + --text "The text you want to convert" \ + --prompt-text "Your reference text" \ + --prompt-tokens "fake.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4" \ + --num-samples 2 \ + --compile +``` + +This command will create a `codes_N` file in the working directory, where N is an integer starting from 0. + +!!! note + You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~500 tokens/second). + Correspondingly, if you do not plan to use acceleration, you can comment out the `--compile` parameter. + +!!! info + For GPUs that do not support bf16, you may need to use the `--half` parameter. + +### 3. Generate vocals from semantic tokens: + +#### VQGAN Decoder + +```bash +python tools/vqgan/inference.py \ + -i "codes_0.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +## HTTP API Inference + +We provide a HTTP API for inference. You can use the following command to start the server: + +```bash +python -m tools.api \ + --listen 0.0.0.0:8080 \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` + +If you want to speed up inference, you can add the --compile parameter. + +After that, you can view and test the API at http://127.0.0.1:8080/. + +Below is an example of sending a request using `tools/post_api.py`. + +```bash +python -m tools.post_api \ + --text "Text to be input" \ + --reference_audio "Path to reference audio" \ + --reference_text "Text content of the reference audio" \ + --streaming True +``` + +The above command indicates synthesizing the desired audio according to the reference audio information and returning it in a streaming manner. + +The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command. + +```bash +python -m tools.post_api \ + --text "Text to input" \ + --reference_audio "reference audio path1" "reference audio path2" \ + --reference_text "reference audio text1" "reference audio text2"\ + --streaming False \ + --output "generated" \ + --format "mp3" +``` + +The above command synthesizes the desired `MP3` format audio based on the information from multiple reference audios and saves it as `generated.mp3` in the current directory. + +## GUI Inference +[Download client](https://github.com/AnyaCoder/fish-speech-gui/releases/tag/v0.1.0) + +## WebUI Inference + +You can start the WebUI using the following command: + +```bash +python -m tools.webui \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` + +!!! note + You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI. + +Enjoy! diff --git a/docs/en/samples.md b/docs/en/samples.md new file mode 100644 index 0000000000000000000000000000000000000000..a079c0c3e29ff7a7e1bf1b1c9143903cb3394457 --- /dev/null +++ b/docs/en/samples.md @@ -0,0 +1,223 @@ +# Samples + +v1.2 samples are available on [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/). + +The following samples are from the v1.1 model. + +## Chinese Sentence 1 +``` +人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。 +流入她所注视的世间,也流入她如湖水般澄澈的目光。 +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
SpeakerInput AudioSynthesized Audio
Nahida (Genshin Impact)
Zhongli (Genshin Impact)
Furina (Genshin Impact)
Random Speaker 1 -
Random Speaker 2 -
+ + +## Chinese Sentence 2 +``` +你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么? +我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊? +你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊, +搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗? +一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊? +``` + + + + + + + + + + + + + + + + + + + + + +
SpeakerInput AudioSynthesized Audio
Nahida (Genshin Impact)
Random Speaker -
+ + +## Chinese Sentence 3 +``` +大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练, +我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。 +作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。 +``` + + + + + + + + + + + + + + + + + +
SpeakerInput AudioSynthesized Audio
Random Speaker -
+ +## English Sentence 1 + +``` +In the realm of advanced technology, the evolution of artificial intelligence stands as a +monumental achievement. This dynamic field, constantly pushing the boundaries of what +machines can do, has seen rapid growth and innovation. From deciphering complex data +patterns to driving cars autonomously, AI's applications are vast and diverse. +``` + + + + + + + + + + + + + + + + + + + + + +
SpeakerInput AudioSynthesized Audio
Random Speaker 1 -
Random Speaker 2 -
+ +## English Sentence 2 +``` +Hello everyone, I am an open-source text-to-speech model developed by +Fish Audio. After training with 150,000 hours of data, I have become proficient +in Chinese, Japanese, and English, and my language processing abilities +are close to human level. My voice is capable of a wide range of expressions. +As a model with only hundreds of millions of parameters, I believe community +members can easily run and fine-tune me on their personal devices, allowing +me to serve as your personal voice assistant. +``` + + + + + + + + + + + + + + + + +
SpeakerInput AudioSynthesized Audio
Random Speaker -
+ +## Japanese Sentence 1 + +``` +先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を +押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か +ら自動運転車の操縦まで、AIの応用は広範囲に及びます。 +``` + + + + + + + + + + + + + + + + + + + + + + +
SpeakerInput AudioSynthesized Audio
Random Speaker 1 -
Random Speaker 2 -
+ +## Japanese Sentence 2 +``` +皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ +キストから音声への変換モデルです。15万時間のデータトレーニングを経て、 +中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。 +声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ +のメンバーが個人のデバイスで簡単に実行し、微調整することができると +信じています。これにより、私を個人の音声アシスタントとして活用できます。 +``` + + + + + + + + + + + + + + + + +
SpeakerInput AudioSynthesized Audio
Random Speaker -
diff --git a/docs/ja/finetune.md b/docs/ja/finetune.md new file mode 100644 index 0000000000000000000000000000000000000000..2a0381de8a5f83988cd058b3cef2a7982b443a94 --- /dev/null +++ b/docs/ja/finetune.md @@ -0,0 +1,125 @@ +# 微調整 + +明らかに、このページを開いたとき、few-shot 事前トレーニングモデルのパフォーマンスに満足していなかったことでしょう。データセット上でのパフォーマンスを向上させるためにモデルを微調整したいと考えています。 + +現在のバージョンでは、「LLAMA」部分のみを微調整する必要があります。 + +## LLAMAの微調整 +### 1. データセットの準備 + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 30.1-32.71.lab +│ └── 30.1-32.71.mp3 +└── SPK2 + ├── 38.79-40.85.lab + └── 38.79-40.85.mp3 +``` + +データセットを上記の形式に変換し、「data」ディレクトリに配置する必要があります。音声ファイルの拡張子は「.mp3」、「.wav」、または「.flac」にすることができ、注釈ファイルの拡張子は「.lab」にする必要があります。 + +!!! warning + データセットにラウドネス正規化を適用することをお勧めします。これを行うには、[fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。 + + ```bash + fap loudness-norm data-raw data --clean + ``` + + +### 2. セマンティックトークンのバッチ抽出 + +VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。 + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +次に、次のコマンドを実行してセマンティックトークンを抽出できます。 + +```bash +python tools/vqgan/extract_vq.py data \ + --num-workers 1 --batch-size 16 \ + --config-name "firefly_gan_vq" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +!!! note + `--num-workers` と `--batch-size` を調整して抽出速度を上げることができますが、GPUメモリの制限を超えないようにしてください。 + VITS形式の場合、`--filelist xxx.list` を使用してファイルリストを指定できます。 + +このコマンドは、`data`ディレクトリに`.npy`ファイルを作成します。以下のように表示されます。 + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 21.15-26.44.npy +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 27.51-29.98.npy +│ ├── 30.1-32.71.lab +│ ├── 30.1-32.71.mp3 +│ └── 30.1-32.71.npy +└── SPK2 + ├── 38.79-40.85.lab + ├── 38.79-40.85.mp3 + └── 38.79-40.85.npy +``` + +### 3. データセットをprotobufにパックする + +```bash +python tools/llama/build_dataset.py \ + --input "data" \ + --output "data/protos" \ + --text-extension .lab \ + --num-workers 16 +``` + +コマンドの実行が完了すると、`data`ディレクトリに`quantized-dataset-ft.protos`ファイルが表示されます。 + +### 4. 最後に、LoRAを使用して微調整する + +同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。 + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +最後に、次のコマンドを実行して微調整を開始できます。 + +```bash +python fish_speech/train.py --config-name text2semantic_finetune \ + project=$project \ + +lora@model.model.lora_config=r_8_alpha_16 +``` + +!!! note + `fish_speech/configs/text2semantic_finetune.yaml` を変更して、`batch_size`、`gradient_accumulation_steps` などのトレーニングパラメータを変更し、GPUメモリに適合させることができます。 + +!!! note + Windowsユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。 + +トレーニングが完了したら、[推論](inference.md)セクションを参照し、`--speaker SPK1` を使用して音声を生成します。 + +!!! info + デフォルトでは、モデルは話者の発話パターンのみを学習し、音色は学習しません。音色の安定性を確保するためにプロンプトを使用する必要があります。 + 音色を学習したい場合は、トレーニングステップ数を増やすことができますが、これにより過学習が発生する可能性があります。 + +トレーニングが完了したら、推論を行う前にLoRAの重みを通常の重みに変換する必要があります。 + +```bash +python tools/llama/merge_lora.py \ + --lora-config r_8_alpha_16 \ + --base-weight checkpoints/fish-speech-1.4 \ + --lora-weight results/$project/checkpoints/step_000000010.ckpt \ + --output checkpoints/fish-speech-1.4-yth-lora/ +``` +!!! note + 他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。 diff --git a/docs/ja/index.md b/docs/ja/index.md new file mode 100644 index 0000000000000000000000000000000000000000..a7932f962ec454e7065de18a9be6d91748fef392 --- /dev/null +++ b/docs/ja/index.md @@ -0,0 +1,128 @@ +# Fish Speech の紹介 + +
+ +Discord + + +QQ + + +Docker + +
+ +!!! warning + 私たちは、コードベースの違法な使用について一切の責任を負いません。お住まいの地域の DMCA(デジタルミレニアム著作権法)およびその他の関連法を参照してください。
+ このコードベースとモデルは、CC-BY-NC-SA-4.0 ライセンス下でリリースされています。 + +

+ +

+ +## 要件 + +- GPU メモリ: 4GB(推論用)、8GB(ファインチューニング用) +- システム: Linux、Windows + +## Windows セットアップ + +Window にて開発を行っている方へ: 本コードベースを実行するのに WSL2 または Docker を利用することができます。 + +あまり詳しくない人は、Linux 環境なしでコードベースを実行するために以下の手順に従ってください。(モデルコンパイル機能`torch.compile`を利用できます。): + +
    +
  1. プロジェクトの圧縮ファイルをダウンロードし、展開
  2. +
  3. install_env.batを開いて実行に必要な環境を整えます。 + +
  4. +
  5. ステップ2でUSE_MIRROR=previewの場合、オプション、コンパイルモデル環境を有効にするたに以下のステップを実行してください。: +
      +
    1. 以下のリンクからLLVMコンパイラをダウンロードします: +
        +
      • LLVM-17.0.6(オリジナルサイト)
      • +
      • LLVM-17.0.6(ミラーサイト)
      • +
      • LLVM-17.0.6-win64.exeをダウンロードした後、ダブルクリックしてインストールし、適当な場所にインストールしてください。必ずAdd Path to Current Userをチェックして環境変数に追加することです。
      • +
      • インストールが完了したことを確認してください。
      • +
      +
    2. +
    3. Microsoft Visual C++ 再頒布可能パッケージをダウンロードしてインストールし、dllの欠落問題を解決します。 + +
    4. +
    5. Visual Studio Community Editionをダウンロードしてインストールし、MSVC++ビルドツールを取得し、LLVMのヘッダーファイル依存関係を解決します。 +
        +
      • Visual Studio ダウンロード
      • +
      • Visual Studio Installerをインストールした後、Visual Studio Community 2022をダウンロードします。
      • +
      • 以下のスクリーンショットのようにModifyボタンをクリックし、Desktop development with C++オプションにチェックをつけてダウンロードします。
      • +

        + +

        +
      +
    6. +
    7. インストール CUDA Toolkit 12
    8. +
    +
  6. +
  7. start.batを実行し、Fish-Speechのトレーニング/推論設定WebUIを開いてください。。 + +
  8. +
  9. (オプション)run_cmd.batをダブルクリックして、このプロジェクトの仮想環境を有効化できます。
  10. +
+ +## Linux セットアップ + +```bash +# python 3.10の仮想環境を作成します。virtualenvも使用できます。 +conda create -n fish-speech python=3.10 +conda activate fish-speech + +# pytorchをインストールします。 +pip3 install torch torchvision torchaudio + +# fish-speechをインストールします。 +pip3 install -e .[stable] + +# (Ubuntu / Debianユーザー) soxをインストールします。 +apt install libsox-dev +``` + +## 変更履歴 + +- 2024/07/02: Fish-Speech を Ver.1.2 に更新し、VITS デコーダーを削除し、ゼロショット能力を大幅に強化しました。 +- 2024/05/10: Fish-Speech を Ver.1.1 に更新し、VITS デコーダーを実装して WER を減少させ、音色の類似性を向上させました。 +- 2024/04/22: Fish-Speech Ver.1.0 を完成させ、VQGAN および LLAMA モデルを大幅に修正しました。 +- 2023/12/28: `lora`微調整サポートを追加しました。 +- 2023/12/27: `gradient checkpointing`、`causual sampling`、および`flash-attn`サポートを追加しました。 +- 2023/12/19: webui および HTTP API を更新しました。 +- 2023/12/18: 微調整ドキュメントおよび関連例を更新しました。 +- 2023/12/17: `text2semantic`モデルを更新し、自由音素モードをサポートしました。 +- 2023/12/13: ベータ版をリリースし、VQGAN モデルおよび LLAMA に基づく言語モデル(音素のみサポート)を含みます。 + +## 謝辞 + +- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) +- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) +- [GPT VITS](https://github.com/innnky/gpt-vits) +- [MQTTS](https://github.com/b04901014/MQTTS) +- [GPT Fast](https://github.com/pytorch-labs/gpt-fast) +- [Transformers](https://github.com/huggingface/transformers) +- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) diff --git a/docs/ja/inference.md b/docs/ja/inference.md new file mode 100644 index 0000000000000000000000000000000000000000..4ca5576835258571e2e58506e43bd312b1cf5b1a --- /dev/null +++ b/docs/ja/inference.md @@ -0,0 +1,157 @@ +# 推論 + +推論は、コマンドライン、HTTP API、および Web UI をサポートしています。 + +!!! note + 全体として、推論は次のいくつかの部分で構成されています: + + 1. VQGANを使用して、与えられた約10秒の音声をエンコードします。 + 2. エンコードされたセマンティックトークンと対応するテキストを例として言語モデルに入力します。 + 3. 新しいテキストが与えられた場合、モデルに対応するセマンティックトークンを生成させます。 + 4. 生成されたセマンティックトークンをVITS / VQGANに入力してデコードし、対応する音声を生成します。 + +## コマンドライン推論 + +必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。 + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +### 1. 音声からプロンプトを生成する: + +!!! note + モデルにランダムに音声の音色を選ばせる場合、このステップをスキップできます。 + +```bash +python tools/vqgan/inference.py \ + -i "paimon.wav" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +`fake.npy`ファイルが生成されるはずです。 + +### 2. テキストからセマンティックトークンを生成する: + +```bash +python tools/llama/generate.py \ + --text "変換したいテキスト" \ + --prompt-text "参照テキスト" \ + --prompt-tokens "fake.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4" \ + --num-samples 2 \ + --compile +``` + +このコマンドは、作業ディレクトリに`codes_N`ファイルを作成します。ここで、N は 0 から始まる整数です。 + +!!! note + `--compile`を使用して CUDA カーネルを融合し、より高速な推論を実現することができます(約 30 トークン/秒 -> 約 500 トークン/秒)。 + それに対応して、加速を使用しない場合は、`--compile`パラメータをコメントアウトできます。 + +!!! info + bf16 をサポートしていない GPU の場合、`--half`パラメータを使用する必要があるかもしれません。 + +### 3. セマンティックトークンから音声を生成する: + +#### VQGAN デコーダー + +```bash +python tools/vqgan/inference.py \ + -i "codes_0.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +## HTTP API 推論 + +推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます: + +```bash +python -m tools.api \ + --listen 0.0.0.0:8080 \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` + +推論を高速化したい場合は、--compile パラメータを追加できます。 + +その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。 + +以下は、`tools/post_api.py` を使用してリクエストを送信する例です。 + +```bash +python -m tools.post_api \ + --text "入力するテキスト" \ + --reference_audio "参照音声へのパス" \ + --reference_text "参照音声テキスト" \ + --streaming True +``` + +上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。 + +`{SPEAKER}`と`{EMOTION}`に基づいて参照音声をランダムに選択する必要がある場合は、以下の手順に従って設定します: + +### 1. プロジェクトのルートディレクトリに`ref_data`フォルダを作成します。 + +### 2. `ref_data`フォルダ内に次のような構造のディレクトリを作成します。 + +``` +. +├── SPEAKER1 +│ ├──EMOTION1 +│ │ ├── 21.15-26.44.lab +│ │ ├── 21.15-26.44.wav +│ │ ├── 27.51-29.98.lab +│ │ ├── 27.51-29.98.wav +│ │ ├── 30.1-32.71.lab +│ │ └── 30.1-32.71.flac +│ └──EMOTION2 +│ ├── 30.1-32.71.lab +│ └── 30.1-32.71.mp3 +└── SPEAKER2 + └─── EMOTION3 + ├── 30.1-32.71.lab + └── 30.1-32.71.mp3 + +``` + +つまり、まず`ref_data`に`{SPEAKER}`フォルダを配置し、各スピーカーの下に`{EMOTION}`フォルダを配置し、各感情フォルダの下に任意の数の音声-テキストペアを配置します + +### 3. 仮想環境で以下のコマンドを入力します. + +```bash +python tools/gen_ref.py + +``` + +参照ディレクトリを生成します。 + +### 4. API を呼び出します。 + +```bash +python -m tools.post_api \ + --text "入力するテキスト" \ + --speaker "${SPEAKER1}" \ + --emotion "${EMOTION1}" \ + --streaming True + +``` + +上記の例はテスト目的のみです。 + +## WebUI 推論 + +次のコマンドを使用して WebUI を起動できます: + +```bash +python -m tools.webui \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` + +!!! note + Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME`など)を使用して WebUI を構成できます。 + +お楽しみください! diff --git a/docs/ja/samples.md b/docs/ja/samples.md new file mode 100644 index 0000000000000000000000000000000000000000..f94e83d02d7f1d670cbeb75499c1abcfef71bac1 --- /dev/null +++ b/docs/ja/samples.md @@ -0,0 +1,223 @@ +# サンプル + +v1.2のサンプルは[Bilibili](https://www.bilibili.com/video/BV1wz421B71D/)で利用可能です。 + +以下のサンプルはv1.1モデルからのものです。 + +## 中国語の文1 +``` +人間灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。 +流入她所注视的世间,也流入她如湖水般澄澈的目光。 +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
話者入力音声合成音声
ナヒーダ (原神)
鍾離 (原神)
フリナ (原神)
ランダム話者1 -
ランダム話者2 -
+ + +## 中国語の文2 +``` +你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么? +我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊? +你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊, +搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗? +一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊? +``` + + + + + + + + + + + + + + + + + + + + + +
話者入力音声合成音声
ナヒーダ (原神)
ランダム話者 -
+ + +## 中国語の文3 +``` +大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练, +我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。 +作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。 +``` + + + + + + + + + + + + + + + + + +
話者入力音声合成音声
ランダム話者 -
+ +## 英語の文1 + +``` +In the realm of advanced technology, the evolution of artificial intelligence stands as a +monumental achievement. This dynamic field, constantly pushing the boundaries of what +machines can do, has seen rapid growth and innovation. From deciphering complex data +patterns to driving cars autonomously, AI's applications are vast and diverse. +``` + + + + + + + + + + + + + + + + + + + + + +
話者入力音声合成音声
ランダム話者1 -
ランダム話者2 -
+ +## 英語の文2 +``` +Hello everyone, I am an open-source text-to-speech model developed by +Fish Audio. After training with 150,000 hours of data, I have become proficient +in Chinese, Japanese, and English, and my language processing abilities +are close to human level. My voice is capable of a wide range of expressions. +As a model with only hundreds of millions of parameters, I believe community +members can easily run and fine-tune me on their personal devices, allowing +me to serve as your personal voice assistant. +``` + + + + + + + + + + + + + + + + +
話者入力音声合成音声
ランダム話者 -
+ +## 日本語の文1 + +``` +先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を +押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か +ら自動運転車の操縦まで、AIの応用は広範囲に及びます。 +``` + + + + + + + + + + + + + + + + + + + + + + +
話者入力音声合成音声
ランダム話者1 -
ランダム話者2 -
+ +## 日本語の文2 +``` +皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ +キストから音声への変換モデルです。15万時間のデータトレーニングを経て、 +中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。 +声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ +のメンバーが個人のデバイスで簡単に実行し、微調整することができると +信じています。これにより、私を個人の音声アシスタントとして活用できます。 +``` + + + + + + + + + + + + + + + + +
話者入力音声合成音声
ランダム話者 -
diff --git a/docs/pt/finetune.md b/docs/pt/finetune.md new file mode 100644 index 0000000000000000000000000000000000000000..a8e52aec24da2ddfb1025d152004f4fe32588616 --- /dev/null +++ b/docs/pt/finetune.md @@ -0,0 +1,125 @@ +# Ajuste Fino + +É óbvio que ao abrir esta página, você não deve estar muito satisfeito com o desempenho do modelo pré-treinado com poucos exemplos. Você pode querer ajustar o modelo para melhorar seu desempenho em seu conjunto de dados. + +Na atual versão, a única coisa que você precisa ajustar é a parte do 'LLAMA'. + +## Ajuste Fino do LLAMA +### 1. Preparando o conjunto de dados + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 30.1-32.71.lab +│ └── 30.1-32.71.mp3 +└── SPK2 + ├── 38.79-40.85.lab + └── 38.79-40.85.mp3 +``` + +Você precisa converter seu conjunto de dados para o formato acima e colocá-lo em `data`. O arquivo de áudio pode ter as extensões `.mp3`, `.wav` ou `.flac`, e o arquivo de anotação deve ter a extensão `.lab`. + +!!! warning + É recomendado aplicar normalização de volume ao conjunto de dados. Você pode usar o [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) para fazer isso. + + ```bash + fap loudness-norm data-raw data --clean + ``` + + +### 2. Extração em lote de tokens semânticos + +Certifique-se de ter baixado os pesos do VQGAN. Se não, execute o seguinte comando: + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos: + +```bash +python tools/vqgan/extract_vq.py data \ + --num-workers 1 --batch-size 16 \ + --config-name "firefly_gan_vq" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +!!! note + Você pode ajustar `--num-workers` e `--batch-size` para aumentar a velocidade de extração, mas certifique-se de não exceder o limite de memória da sua GPU.   + Para o formato VITS, você pode especificar uma lista de arquivos usando `--filelist xxx.list`. + +Este comando criará arquivos `.npy` no diretório `data`, como mostrado abaixo: + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 21.15-26.44.npy +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 27.51-29.98.npy +│ ├── 30.1-32.71.lab +│ ├── 30.1-32.71.mp3 +│ └── 30.1-32.71.npy +└── SPK2 + ├── 38.79-40.85.lab + ├── 38.79-40.85.mp3 + └── 38.79-40.85.npy +``` + +### 3. Empacotar o conjunto de dados em protobuf + +```bash +python tools/llama/build_dataset.py \ + --input "data" \ + --output "data/protos" \ + --text-extension .lab \ + --num-workers 16 +``` + +Após executar o comando, você deverá ver o arquivo `quantized-dataset-ft.protos` no diretório `data`. + +### 4. E finalmente, chegamos ao ajuste fino com LoRA + +Da mesma forma, certifique-se de ter baixado os pesos do `LLAMA`. Se não, execute o seguinte comando: + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +E então, execute o seguinte comando para iniciar o ajuste fino: + +```bash +python fish_speech/train.py --config-name text2semantic_finetune \ + project=$project \ + +lora@model.model.lora_config=r_8_alpha_16 +``` + +!!! note + Se quiser, você pode modificar os parâmetros de treinamento, como `batch_size`, `gradient_accumulation_steps`, etc., para se ajustar à memória da sua GPU, modificando `fish_speech/configs/text2semantic_finetune.yaml`. + +!!! note + Para usuários do Windows, é recomendado usar `trainer.strategy.process_group_backend=gloo` para evitar problemas com `nccl`. + +Após concluir o treinamento, consulte a seção [inferência](inference.md), e use `--speaker SPK1` para gerar fala. + +!!! info + Por padrão, o modelo aprenderá apenas os padrões de fala do orador e não o timbre. Ainda pode ser preciso usar prompts para garantir a estabilidade do timbre. + Se quiser que ele aprenda o timbre, aumente o número de etapas de treinamento, mas isso pode levar ao overfitting (sobreajuste). + +Após o treinamento, é preciso converter os pesos do LoRA em pesos regulares antes de realizar a inferência. + +```bash +python tools/llama/merge_lora.py \ + --lora-config r_8_alpha_16 \ + --base-weight checkpoints/fish-speech-1.4 \ + --lora-weight results/$project/checkpoints/step_000000010.ckpt \ + --output checkpoints/fish-speech-1.4-yth-lora/ +``` +!!! note + É possível também tentar outros checkpoints. Sugerimos usar o checkpoint que melhor atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora da distribuição (OOD). diff --git a/docs/pt/index.md b/docs/pt/index.md new file mode 100644 index 0000000000000000000000000000000000000000..7d2d7ef0baccce0d8385241820a445ea0ef1ebbe --- /dev/null +++ b/docs/pt/index.md @@ -0,0 +1,131 @@ +# Introdução + +
+ +Discord + + +QQ + + +Docker + +
+ +!!! warning + Não nos responsabilizamos por qualquer uso ilegal do código-fonte. Consulte as leis locais sobre DMCA (Digital Millennium Copyright Act) e outras leis relevantes em sua região.
+ Este repositório de código e os modelos são distribuídos sob a licença CC-BY-NC-SA-4.0. + +

+ +

+ +## Requisitos + +- Memória da GPU: 4GB (para inferência), 8GB (para ajuste fino) +- Sistema: Linux, Windows + +## Configuração para Windows + +No Windows, usuários avançados podem considerar usar o WSL2 ou Docker para executar o código. + +Para Usuários comuns (não-avançados), siga os métodos abaixo para executar o código sem um ambiente Linux (incluindo suporte para `torch.compile`): + +
    +
  1. Extraia o arquivo compactado do projeto.
  2. +
  3. Prepare o ambiente conda: + +
  4. +
  5. Se você escolheu a versão de prévia com ambiente compilado (INSTALL_TYPE=preview), siga para a próxima etapa (opcional): +
      +
    1. Baixe o compilador LLVM usando os seguintes links: + +
    2. +
    3. Baixe e instale o pacote Microsoft Visual C++ Redistributable para resolver possíveis problemas de .dll ausentes. + +
    4. +
    5. Baixe e instale o Visual Studio Community Edition para obter as ferramentas de compilação MSVC++, resolvendo as dependências do arquivo de cabeçalho LLVM. +
        +
      • Download do Visual Studio
      • +
      • Após instalar o Visual Studio Installer, baixe o Visual Studio Community 2022.
      • +
      • Clique no botão Modificar, conforme mostrado abaixo, encontre a opção Desenvolvimento para desktop com C++ e marque-a para download.
      • +

        + +

        +
      +
    6. +
    7. Instale o CUDA Toolkit 12
    8. +
    +
  6. +
  7. Clique duas vezes em start.bat para entrar na página da WebUI de configuração de inferência de treinamento do Fish-Speech. + +
  8. +
  9. (Opcional) Clique duas vezes em run_cmd.bat para entrar na CLI do conda/python deste projeto.
  10. +
+ +## Configuração para Linux + +```bash +# Crie um ambiente virtual python 3.10, você também pode usar virtualenv +conda create -n fish-speech python=3.10 +conda activate fish-speech + +# Instale o pytorch +pip3 install torch torchvision torchaudio + +# Instale o fish-speech +pip3 install -e .[stable] + +# Para os Usuário do Ubuntu / Debian: Instale o sox +apt install libsox-dev +``` + +## Histórico de Alterações + +- 02/07/2024: Fish-Speech atualizado para a versão 1.2, removido o Decodificador VITS e aprimorado consideravelmente a capacidade de zero-shot. +- 10/05/2024: Fish-Speech atualizado para a versão 1.1, implementado o decodificador VITS para reduzir a WER e melhorar a similaridade de timbre. +- 22/04/2024: Finalizada a versão 1.0 do Fish-Speech, modificados significativamente os modelos VQGAN e LLAMA. +- 28/12/2023: Adicionado suporte para ajuste fino `lora`. +- 27/12/2023: Adicionado suporte para `gradient checkpointing`, `causual sampling` e `flash-attn`. +- 19/12/2023: Atualizada a interface web e a API HTTP. +- 18/12/2023: Atualizada a documentação de ajuste fino e exemplos relacionados. +- 17/12/2023: Atualizado o modelo `text2semantic`, suportando o modo sem fonemas. +- 13/12/2023: Versão beta lançada, incluindo o modelo VQGAN e um modelo de linguagem baseado em LLAMA (suporte apenas a fonemas). + +## Agradecimentos + +- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) +- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) +- [GPT VITS](https://github.com/innnky/gpt-vits) +- [MQTTS](https://github.com/b04901014/MQTTS) +- [GPT Fast](https://github.com/pytorch-labs/gpt-fast) +- [Transformers](https://github.com/huggingface/transformers) +- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) diff --git a/docs/pt/inference.md b/docs/pt/inference.md new file mode 100644 index 0000000000000000000000000000000000000000..6a8ff5c61f44bc141f1e4b5802077e91fbc6eaef --- /dev/null +++ b/docs/pt/inference.md @@ -0,0 +1,153 @@ +# Inferência + +Suporte para inferência por linha de comando, API HTTP e interface web (WebUI). + +!!! note + O processo de raciocínio, em geral, consiste em várias partes: + + 1. Codificar cerca de 10 segundos de voz usando VQGAN. + 2. Inserir os tokens semânticos codificados e o texto correspondente no modelo de linguagem como um exemplo. + 3. Dado um novo trecho de texto, fazer com que o modelo gere os tokens semânticos correspondentes. + 4. Inserir os tokens semânticos gerados no VITS / VQGAN para decodificar e gerar a voz correspondente. + +## Inferência por Linha de Comando + +Baixe os modelos `vqgan` e `llama` necessários do nosso repositório Hugging Face. + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +### 1. Gerar prompt a partir da voz: + +!!! note + Se quiser permitir que o modelo escolha aleatoriamente um timbre de voz, pule esta etapa. + +```bash +python tools/vqgan/inference.py \ + -i "paimon.wav" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +Você deverá obter um arquivo `fake.npy`. + +### 2. Gerar tokens semânticos a partir do texto: + +```bash +python tools/llama/generate.py \ + --text "O texto que você deseja converter" \ + --prompt-text "Seu texto de referência" \ + --prompt-tokens "fake.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4" \ + --num-samples 2 \ + --compile +``` + +Este comando criará um arquivo `codes_N` no diretório de trabalho, onde N é um número inteiro começando de 0. + +!!! note + Use `--compile` para fundir kernels CUDA para ter uma inferência mais rápida (~30 tokens/segundo -> ~500 tokens/segundo). + Mas, se não planeja usar a aceleração CUDA, comente o parâmetro `--compile`. + +!!! info + Para GPUs que não suportam bf16, pode ser necessário usar o parâmetro `--half`. + +### 3. Gerar vocais a partir de tokens semânticos: + +#### Decodificador VQGAN + +```bash +python tools/vqgan/inference.py \ + -i "codes_0.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +## Inferência por API HTTP + +Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para iniciar o servidor: + +```bash +python -m tools.api \ + --listen 0.0.0.0:8080 \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` + +Para acelerar a inferência, adicione o parâmetro `--compile`. + +Depois disso, é possível visualizar e testar a API em http://127.0.0.1:8080/. + +Abaixo está um exemplo de envio de uma solicitação usando `tools/post_api.py`. + +```bash +python -m tools.post_api \ + --text "Texto a ser inserido" \ + --reference_audio "Caminho para o áudio de referência" \ + --reference_text "Conteúdo de texto do áudio de referência" \ + --streaming True +``` + +O comando acima indica a síntese do áudio desejada de acordo com as informações do áudio de referência e a retorna em modo de streaming. + +Caso selecione, de forma aleatória, o áudio de referência com base em `{SPEAKER}` e `{EMOTION}`, o configure de acordo com as seguintes etapas: + +### 1. Crie uma pasta `ref_data` no diretório raiz do projeto. + +### 2. Crie uma estrutura de diretórios semelhante à seguinte dentro da pasta `ref_data`. + +``` +. +├── SPEAKER1 +│ ├──EMOTION1 +│ │ ├── 21.15-26.44.lab +│ │ ├── 21.15-26.44.wav +│ │ ├── 27.51-29.98.lab +│ │ ├── 27.51-29.98.wav +│ │ ├── 30.1-32.71.lab +│ │ └── 30.1-32.71.flac +│ └──EMOTION2 +│ ├── 30.1-32.71.lab +│ └── 30.1-32.71.mp3 +└── SPEAKER2 + └─── EMOTION3 + ├── 30.1-32.71.lab + └── 30.1-32.71.mp3 +``` + +Ou seja, primeiro coloque as pastas `{SPEAKER}` em `ref_data`, depois coloque as pastas `{EMOTION}` em cada pasta de orador (speaker) e coloque qualquer número de `pares áudio-texto` em cada pasta de emoção. + +### 3. Digite o seguinte comando no ambiente virtual + +```bash +python tools/gen_ref.py + +``` + +### 4. Chame a API. + +```bash +python -m tools.post_api \ + --text "Texto a ser inserido" \ + --speaker "${SPEAKER1}" \ + --emotion "${EMOTION1}" \ + --streaming True +``` + +O exemplo acima é apenas para fins de teste. + +## Inferência por WebUI + +Para iniciar a WebUI de Inferência execute o seguinte comando: + +```bash +python -m tools.webui \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` + +!!! note + É possível usar variáveis de ambiente do Gradio, como `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`, para configurar a WebUI. + +Divirta-se! diff --git a/docs/pt/samples.md b/docs/pt/samples.md new file mode 100644 index 0000000000000000000000000000000000000000..75a1669d06ea1da4e138856cc15b8386167543fa --- /dev/null +++ b/docs/pt/samples.md @@ -0,0 +1,223 @@ +# Amostras + +As amostras da v1.2 estão disponíveis em [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/). + +As seguintes amostras são do modelo v1.1. + +## Frase em Chinês 1 +``` +人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。 +流入她所注视的世间,也流入她如湖水般澄澈的目光。 +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
OradorÁudio de EntradaÁudio Sintetizado
Nahida (Genshin Impact)
Zhongli (Genshin Impact)
Furina (Genshin Impact)
Orador Aleatório 1 -
Orador Aleatório 2 -
+ + +## Frase em Chinês 2 +``` +你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么? +我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊? +你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊, +搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗? +一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊? +``` + + + + + + + + + + + + + + + + + + + + + +
OradorÁudio de EntradaÁudio Sintetizado
Nahida (Genshin Impact)
Orador Aleatório -
+ + +## Frase em Chinês 3 +``` +大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练, +我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。 +作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。 +``` + + + + + + + + + + + + + + + + + +
OradorÁudio de EntradaÁudio Sintetizado
Orador Aleatório -
+ +## Frase em Inglês 1 + +``` +In the realm of advanced technology, the evolution of artificial intelligence stands as a +monumental achievement. This dynamic field, constantly pushing the boundaries of what +machines can do, has seen rapid growth and innovation. From deciphering complex data +patterns to driving cars autonomously, AI's applications are vast and diverse. +``` + + + + + + + + + + + + + + + + + + + + + +
OradorÁudio de EntradaÁudio Sintetizado
Orador Aleatório 1 -
Orador Aleatório 2 -
+ +## Frase em Inglês 2 +``` +Hello everyone, I am an open-source text-to-speech model developed by +Fish Audio. After training with 150,000 hours of data, I have become proficient +in Chinese, Japanese, and English, and my language processing abilities +are close to human level. My voice is capable of a wide range of expressions. +As a model with only hundreds of millions of parameters, I believe community +members can easily run and fine-tune me on their personal devices, allowing +me to serve as your personal voice assistant. +``` + + + + + + + + + + + + + + + + +
OradorÁudio de EntradaÁudio Sintetizado
Orador Aleatório -
+ +## Frase em Japonês 1 + +``` +先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を +押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か +ら自動運転車の操縦まで、AIの応用は広範囲に及びます。 +``` + + + + + + + + + + + + + + + + + + + + + + +
OradorÁudio de EntradaÁudio Sintetizado
Orador Aleatório 1 -
Orador Aleatório 2 -
+ +## Frase em Japonês 2 +``` +皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ +キストから音声への変換モデルです。15万時間のデータトレーニングを経て、 +中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。 +声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ +のメンバーが個人のデバイスで簡単に実行し、微調整することができると +信じています。これにより、私を個人の音声アシスタントとして活用できます。 +``` + + + + + + + + + + + + + + + + +
OradorÁudio de EntradaÁudio Sintetizado
Orador Aleatório -
diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d6e145dbea1b9b26b2bddd7500e3f270b3eb0009 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,3 @@ +mkdocs-material +mkdocs-static-i18n[material] +mkdocs[i18n] diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css new file mode 100644 index 0000000000000000000000000000000000000000..a88af87b3cdbfd2d6b05f39877d5821bb7ebe119 --- /dev/null +++ b/docs/stylesheets/extra.css @@ -0,0 +1,3 @@ +.md-grid { + max-width: 1440px; +} diff --git a/docs/zh/finetune.md b/docs/zh/finetune.md new file mode 100644 index 0000000000000000000000000000000000000000..c72722cd76b0ce3b00adcf88db23c0ccfcb35da1 --- /dev/null +++ b/docs/zh/finetune.md @@ -0,0 +1,136 @@ +# 微调 + +显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好. + +在目前版本,你只需要微调'LLAMA'部分即可. + +## LLAMA 微调 +### 1. 准备数据集 + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 30.1-32.71.lab +│ └── 30.1-32.71.mp3 +└── SPK2 + ├── 38.79-40.85.lab + └── 38.79-40.85.mp3 +``` + +你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`. + +!!! warning + 建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤. + ```bash + fap loudness-norm data-raw data --clean + ``` + +### 2. 批量提取语义 token + +确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令: + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +对于中国大陆用户, 可使用 mirror 下载. + +```bash +HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +随后可运行以下命令来提取语义 token: + +```bash +python tools/vqgan/extract_vq.py data \ + --num-workers 1 --batch-size 16 \ + --config-name "firefly_gan_vq" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +!!! note + 你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制. + +该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示: + +``` +. +├── SPK1 +│ ├── 21.15-26.44.lab +│ ├── 21.15-26.44.mp3 +│ ├── 21.15-26.44.npy +│ ├── 27.51-29.98.lab +│ ├── 27.51-29.98.mp3 +│ ├── 27.51-29.98.npy +│ ├── 30.1-32.71.lab +│ ├── 30.1-32.71.mp3 +│ └── 30.1-32.71.npy +└── SPK2 + ├── 38.79-40.85.lab + ├── 38.79-40.85.mp3 + └── 38.79-40.85.npy +``` + +### 3. 打包数据集为 protobuf + +```bash +python tools/llama/build_dataset.py \ + --input "data" \ + --output "data/protos" \ + --text-extension .lab \ + --num-workers 16 +``` + +命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件. + + +### 4. 最后, 使用 LoRA 进行微调 + +同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令: + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +对于中国大陆用户, 可使用 mirror 下载. + +```bash +HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +最后, 你可以运行以下命令来启动微调: + +```bash +python fish_speech/train.py --config-name text2semantic_finetune \ + project=$project \ + +lora@model.model.lora_config=r_8_alpha_16 +``` + +!!! note + 你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存. + +!!! note + 对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题. + +训练结束后, 你可以参考 [推理](inference.md) 部分, 并携带 `--speaker SPK1` 参数来测试你的模型. + +!!! info + 默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性. + 如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合. + +训练完成后, 你需要先将 loRA 的权重转为普通权重, 然后再进行推理. + +```bash +python tools/llama/merge_lora.py \ + --lora-config r_8_alpha_16 \ + --base-weight checkpoints/fish-speech-1.4 \ + --lora-weight results/$project/checkpoints/step_000000010.ckpt \ + --output checkpoints/fish-speech-1.4-yth-lora/ +``` + +!!! note + 你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好. diff --git a/docs/zh/index.md b/docs/zh/index.md new file mode 100644 index 0000000000000000000000000000000000000000..b2c9b79c69ff521664552829610369da28b54127 --- /dev/null +++ b/docs/zh/index.md @@ -0,0 +1,191 @@ +# 介绍 + +
+ +Discord + + +QQ + + +Docker + +
+ +!!! warning + 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+ 此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布. + +

+ +

+ +## 要求 + +- GPU 内存: 4GB (用于推理), 8GB (用于微调) +- 系统: Linux, Windows + +## Windows 配置 + +Windows 专业用户可以考虑 WSL2 或 docker 来运行代码库。 + +```bash +# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv +conda create -n fish-speech python=3.10 +conda activate fish-speech + +# 安装 pytorch +pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 + +# 安装 fish-speech +pip3 install -e . + +# (开启编译加速) 安装 triton-windows +pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl +``` + +Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法(附带模型编译功能,即 `torch.compile`): + +1. 解压项目压缩包。 +2. 点击 `install_env.bat` 安装环境。 +3. 若需要开启编译加速则执行这一步: + 1. 使用如下链接下载 LLVM 编译器。 + - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) + - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) + - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。 + - 确认安装完成。 + 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。 + - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe) + 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。 + - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/) + - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022 + - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载 + 4. 下载安装 [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64) +4. 双击 `start.bat` 打开训练推理 WebUI 管理界面. 如有需要,可照下列提示修改`API_FLAGS`. + +!!! info "可选" + + 想启动 推理 WebUI 界面?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式: + ``` + --infer + # --api + # --listen ... + ... + ``` + +!!! info "可选" + + 想启动 API 服务器?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式: + ``` + # --infer + --api + --listen ... + ... + ``` + +!!! info "可选" + + 双击 `run_cmd.bat` 进入本项目的 conda/python 命令行环境 + +## Linux 配置 + +```bash +# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv +conda create -n fish-speech python=3.10 +conda activate fish-speech + +# 安装 pytorch +pip3 install torch torchvision torchaudio + +# 安装 fish-speech +pip3 install -e .[stable] + +# (Ubuntu / Debian 用户) 安装 sox +apt install libsox-dev +``` + +## Docker 配置 + +1. 安装 NVIDIA Container Toolkit: + + Docker 如果想使用 GPU 进行模型训练和推理,需要安装 NVIDIA Container Toolkit : + + 对于 Ubuntu 用户: + + ```bash + # 添加远程仓库 + curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ + && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ + sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ + sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list + # 安装 nvidia-container-toolkit + sudo apt-get update + sudo apt-get install -y nvidia-container-toolkit + # 重启 Docker 服务 + sudo systemctl restart docker + ``` + + 对于使用其他 Linux 发行版的用户,安装指南请参考:[NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)。 + + 注:对于中国大陆的用户,您可能需要使用代理来完成相关工具的安装。 + +2. 拉取并运行 fish-speech 镜像 + + ```shell + # 拉取镜像 + docker pull fishaudio/fish-speech + # 运行镜像 + docker run -it \ + --name fish-speech \ + --gpus all \ + -p 7860:7860 \ + fishaudio/fish-speech \ + zsh + # 如果需要使用其他端口,请修改 -p 参数为 YourPort:7860 + ``` + +3. 下载模型依赖 + + 确保您在 docker 容器内的终端,然后再从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。 + + ```bash + huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + ``` + + 对于中国大陆用户,可以通过镜像站下载。 + + ```bash + HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + ``` + +4. 配置环境变量,访问 WebUI + + 在 docker 容器内的终端,输入 `export GRADIO_SERVER_NAME="0.0.0.0"` ,从而让外部可以访问 docker 内的 gradio 服务。 + 接着在 docker 容器内的终端,输入 `python tools/webui.py` 即可开启 WebUI 服务。 + + 如果是 WSL 或者是 MacOS ,访问 [http://localhost:7860](http://localhost:7860) 即可打开 WebUI 界面。 + + 如果是部署在服务器上,更换 localhost 为您的服务器 ip 即可。 + +## 更新日志 + +- 2024/09/10: 更新了 Fish-Speech 到 1.4, 增加了数据集大小, quantizer n_groups 4 -> 8. +- 2024/07/02: 更新了 Fish-Speech 到 1.2 版本,移除 VITS Decoder,同时极大幅度提升 zero-shot 能力. +- 2024/05/10: 更新了 Fish-Speech 到 1.1 版本,引入了 VITS Decoder 来降低口胡和提高音色相似度. +- 2024/04/22: 完成了 Fish-Speech 1.0 版本, 大幅修改了 VQGAN 和 LLAMA 模型. +- 2023/12/28: 添加了 `lora` 微调支持. +- 2023/12/27: 添加了 `gradient checkpointing`, `causual sampling` 和 `flash-attn` 支持. +- 2023/12/19: 更新了 Webui 和 HTTP API. +- 2023/12/18: 更新了微调文档和相关例子. +- 2023/12/17: 更新了 `text2semantic` 模型, 支持无音素模式. +- 2023/12/13: 测试版发布, 包含 VQGAN 模型和一个基于 LLAMA 的语言模型 (只支持音素). + +## 致谢 + +- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) +- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) +- [GPT VITS](https://github.com/innnky/gpt-vits) +- [MQTTS](https://github.com/b04901014/MQTTS) +- [GPT Fast](https://github.com/pytorch-labs/gpt-fast) +- [Transformers](https://github.com/huggingface/transformers) +- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) diff --git a/docs/zh/inference.md b/docs/zh/inference.md new file mode 100644 index 0000000000000000000000000000000000000000..f783a525bc625a02db90bc9b90a1a87351e76e07 --- /dev/null +++ b/docs/zh/inference.md @@ -0,0 +1,134 @@ +# 推理 + +推理支持命令行, http api, 以及 webui 三种方式. + +!!! note + 总的来说, 推理分为几个部分: + + 1. 给定一段 ~10 秒的语音, 将它用 VQGAN 编码. + 2. 将编码后的语义 token 和对应文本输入语言模型作为例子. + 3. 给定一段新文本, 让模型生成对应的语义 token. + 4. 将生成的语义 token 输入 VQGAN 解码, 生成对应的语音. + +## 命令行推理 + +从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。 + +```bash +huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +对于中国大陆用户,可使用 mirror 下载。 + +```bash +HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +``` + +### 1. 从语音生成 prompt: + +!!! note + 如果你打算让模型随机选择音色, 你可以跳过这一步. + +```bash +python tools/vqgan/inference.py \ + -i "paimon.wav" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +你应该能得到一个 `fake.npy` 文件. + +### 2. 从文本生成语义 token: + +```bash +python tools/llama/generate.py \ + --text "要转换的文本" \ + --prompt-text "你的参考文本" \ + --prompt-tokens "fake.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4" \ + --num-samples 2 \ + --compile +``` + +该命令会在工作目录下创建 `codes_N` 文件, 其中 N 是从 0 开始的整数. + +!!! note + 您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 个 token/秒 -> ~500 个 token/秒). + 对应的, 如果你不打算使用加速, 你可以注释掉 `--compile` 参数. + +!!! info + 对于不支持 bf16 的 GPU, 你可能需要使用 `--half` 参数. + +### 3. 从语义 token 生成人声: + +#### VQGAN 解码 + +```bash +python tools/vqgan/inference.py \ + -i "codes_0.npy" \ + --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" +``` + +## HTTP API 推理 + +运行以下命令来启动 HTTP 服务: + +```bash +python -m tools.api \ + --listen 0.0.0.0:8080 \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` +如果你想要加速推理,可以加上`--compile`参数。 + +推荐中国大陆用户运行以下命令来启动 HTTP 服务: +```bash +HF_ENDPOINT=https://hf-mirror.com python -m ...(同上) +``` + +随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API. + +下面是使用`tools/post_api.py`发送请求的示例。 + +```bash +python -m tools.post_api \ + --text "要输入的文本" \ + --reference_audio "参考音频路径" \ + --reference_text "参考音频的文本内容" \ + --streaming True +``` + +上面的命令表示按照参考音频的信息,合成所需的音频并流式返回. + +下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。 + +```bash +python -m tools.post_api \ + --text "要输入的文本" \ + --reference_audio "参考音频路径1" "参考音频路径2" \ + --reference_text "参考音频的文本内容1" "参考音频的文本内容2"\ + --streaming False \ + --output "generated" \ + --format "mp3" +``` + +上面的命令表示按照多个参考音频的信息,合成所需的`MP3`格式音频,并保存为当前目录的`generated.mp3`文件。 + +## GUI 推理 +[下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases/tag/v0.1.0) + +## WebUI 推理 + +你可以使用以下命令来启动 WebUI: + +```bash +python -m tools.webui \ + --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --decoder-config-name firefly_gan_vq +``` + +!!! note + 你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI. + +祝大家玩得开心! diff --git a/docs/zh/samples.md b/docs/zh/samples.md new file mode 100644 index 0000000000000000000000000000000000000000..b4d0fab1d801ce6c55916e7a6f0a261ec4373849 --- /dev/null +++ b/docs/zh/samples.md @@ -0,0 +1,223 @@ +# 例子 + +v1.2 的样本可以在 [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/) 观看。 + +以下样本来自 v1.1 版本的模型。 + +## 中文句子 1 +``` +人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。 +流入她所注视的世间,也流入她如湖水般澄澈的目光。 +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
说话人输入音频合成音频
纳西妲 (原神)
钟离 (原神)
芙宁娜 (原神)
随机说话人 1 -
随机说话人 2 -
+ + +## 中文句子 2 +``` +你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么? +我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊? +你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊, +搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗? +一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊? +``` + + + + + + + + + + + + + + + + + + + + + +
说话人输入音频合成音频
纳西妲 (原神)
随机说话人 -
+ + +## 中文句子 3 +``` +大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练, +我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。 +作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。 +``` + + + + + + + + + + + + + + + + + +
说话人输入音频合成音频
随机说话人 -
+ +## 英文句子 1 + +``` +In the realm of advanced technology, the evolution of artificial intelligence stands as a +monumental achievement. This dynamic field, constantly pushing the boundaries of what +machines can do, has seen rapid growth and innovation. From deciphering complex data +patterns to driving cars autonomously, AI's applications are vast and diverse. +``` + + + + + + + + + + + + + + + + + + + + + +
说话人输入音频合成音频
随机说话人 1 -
随机说话人 2 -
+ +## 英文句子 2 +``` +Hello everyone, I am an open-source text-to-speech model developed by +Fish Audio. After training with 150,000 hours of data, I have become proficient +in Chinese, Japanese, and English, and my language processing abilities +are close to human level. My voice is capable of a wide range of expressions. +As a model with only hundreds of millions of parameters, I believe community +members can easily run and fine-tune me on their personal devices, allowing +me to serve as your personal voice assistant. +``` + + + + + + + + + + + + + + + + +
说话人输入音频合成音频
随机说话人 -
+ +## 日文句子 1 + +``` +先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を +押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か +ら自動運転車の操縦まで、AIの応用は広範囲に及びます。 +``` + + + + + + + + + + + + + + + + + + + + + + +
说话人输入音频合成音频
随机说话人 1 -
随机说话人 2 -
+ +## 日文句子 2 +``` +皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ +キストから音声への変換モデルです。15万時間のデータトレーニングを経て、 +中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。 +声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ +のメンバーが個人のデバイスで簡単に実行し、微調整することができると +信じています。これにより、私を個人の音声アシスタントとして活用できます。 +``` + + + + + + + + + + + + + + + + +
说话人输入音频合成音频
随机说话人 -
diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..d9e931429835cf454fd1a4e027b23bbee4875b65 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +CUDA_ENABLED=${CUDA_ENABLED:-true} +DEVICE="" + +if [ "${CUDA_ENABLED}" != "true" ]; then + DEVICE="--device cpu" +fi + +exec python tools/webui.py ${DEVICE} diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3 --- /dev/null +++ b/fish_speech/callbacks/__init__.py @@ -0,0 +1,3 @@ +from .grad_norm import GradNormMonitor + +__all__ = ["GradNormMonitor"] diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a --- /dev/null +++ b/fish_speech/callbacks/grad_norm.py @@ -0,0 +1,113 @@ +from typing import Optional, Union + +import lightning.pytorch as pl +import torch +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from torch import Tensor, nn +from torch.utils._foreach_utils import ( + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def grad_norm( + parameters: Union[Tensor, list[Tensor]], + norm_type: float = 2.0, +) -> float: + """ + Returns the norm of the gradients of the given parameters. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ # noqa: E501 + + if isinstance(parameters, Tensor): + parameters = [parameters] + + grads = [p.grad for p in parameters if p.grad is not None] + if len(grads) == 0: + return None + + first_device = grads[0].device + grouped_grads: dict[ + tuple[torch.device, torch.dtype], list[list[Tensor]] + ] = _group_tensors_by_device_and_dtype( + [[g.detach() for g in grads]] + ) # type: ignore[assignment] + + norms = [] + for (device, _), ([grads], _) in grouped_grads.items(): + if _has_foreach_support(grads, device=device): + norms.extend(torch._foreach_norm(grads, norm_type)) + else: + norms.extend([torch.norm(g, norm_type) for g in grads]) + + return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) + + +class GradNormMonitor(Callback): + """ + Callback that computes the gradient norm of the model parameters. + """ + + def __init__( + self, + norm_type: float = 2.0, + logging_interval: str = "step", + sub_module: Optional[Union[str, list[str]]] = None, + ) -> None: + """ + Args: + norm_type (float): type of the used p-norm. + logging_interval (str): "step" or "epoch". + """ + super().__init__() + + self.norm_type = norm_type + self.logging_interval = logging_interval + self.sub_module = sub_module + + def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: + """ + Computes the gradient norm of the model parameters and logs it to the logger. + + Args: + trainer (Trainer): The trainer object + model (LightningModule): The current lightningModule + """ + + lightning_model = model + + if self.sub_module is None: + return self.log_sub_module_grad_norm(lightning_model, model, "") + + sub_modules = self.sub_module + if isinstance(sub_modules, str): + sub_modules = [sub_modules] + + for sub_module in sub_modules: + self.log_sub_module_grad_norm( + lightning_model, getattr(model, sub_module), f"/{sub_module}" + ) + + def log_sub_module_grad_norm( + self, lightning_model: LightningModule, model: nn.Module, path: str + ) -> None: + grad_norm_val = grad_norm(model.parameters(), self.norm_type) + if grad_norm_val is None: + return + + on_step = self.logging_interval == "step" + lightning_model.log( + f"train{path}/grad_norm", + grad_norm_val, + on_step=on_step, + on_epoch=not on_step, + ) diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..99e6dab54d3f57bce4f6d29a9129a19a523cad75 --- /dev/null +++ b/fish_speech/configs/base.yaml @@ -0,0 +1,87 @@ +# Base configuration for training a model +paths: + run_dir: results/${project} + ckpt_dir: ${paths.run_dir}/checkpoints + +hydra: + run: + dir: ${paths.run_dir} + +# Lightning Trainer +trainer: + _target_: lightning.pytorch.trainer.Trainer + + default_root_dir: ${paths.run_dir} + accelerator: gpu + num_nodes: 1 + devices: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + process_group_backend: nccl # This should be override when training on windows + + precision: bf16-mixed + + # disable validation by epoch end + check_val_every_n_epoch: null + val_check_interval: 5000 + max_steps: 100_000 + + # Use torch.backends.cudnn.benchmark to speed up training + benchmark: true + +# Callbacks +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.ckpt_dir} + filename: "step_{step:09d}" + save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 5 # save 5 latest checkpoints + monitor: step # use step to monitor checkpoints + mode: max # save the latest checkpoint with the highest global_step + every_n_epochs: null # don't save checkpoints by epoch end + every_n_train_steps: 5000 # save checkpoints every 5000 steps + auto_insert_metric_name: false + + model_summary: + _target_: lightning.pytorch.callbacks.ModelSummary + max_depth: 2 # the maximum depth of layer nesting that the summary will include + + learning_rate_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: step + log_momentum: false + + grad_norm_monitor: + _target_: fish_speech.callbacks.GradNormMonitor + norm_type: 2 + logging_interval: step + +# Logger +logger: + tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.run_dir}/tensorboard/" + name: null + log_graph: false + default_hp_metric: true + prefix: "" + + # wandb: + # _target_: lightning.pytorch.loggers.wandb.WandbLogger + # # name: "" # name of the run (normally generated by wandb) + # save_dir: "${paths.run_dir}" + # offline: False + # id: null # pass correct id to resume experiment! + # anonymous: null # enable anonymous logging + # project: "fish-speech" + # log_model: False # upload lightning ckpts + # prefix: "" # a string to put at the beginning of metric keys + # # entity: "" # set to name of your wandb team + # group: "" + # tags: ["vq", "hq", "finetune"] + # job_type: "" + +# Loop +train: true +test: false diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..10aa8d4a522f0859ed8f541f5d48672d84b39c8f --- /dev/null +++ b/fish_speech/configs/firefly_gan_vq.yaml @@ -0,0 +1,33 @@ +_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture +spec_transform: + _target_: fish_speech.utils.spectrogram.LogMelSpectrogram + sample_rate: 44100 + n_mels: 160 + n_fft: 2048 + hop_length: 512 + win_length: 2048 +backbone: + _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder + input_channels: 160 + depths: [3, 3, 9, 3] + dims: [128, 256, 384, 512] + drop_path_rate: 0.2 + kernel_size: 7 +head: + _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator + hop_length: 512 + upsample_rates: [8, 8, 2, 2, 2] # aka. strides + upsample_kernel_sizes: [16, 16, 4, 4, 4] + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + num_mels: 512 + upsample_initial_channel: 512 + pre_conv_kernel_size: 13 + post_conv_kernel_size: 13 +quantizer: + _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize + input_dim: 512 + n_groups: 8 + n_codebooks: 1 + levels: [8, 5, 5, 5] + downsample_factor: [2, 2] diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aecc4d9766a18fe31c55941e01b1f590c95e77c9 --- /dev/null +++ b/fish_speech/configs/lora/r_8_alpha_16.yaml @@ -0,0 +1,4 @@ +_target_: fish_speech.models.text2semantic.lora.LoraConfig +r: 8 +lora_alpha: 16 +lora_dropout: 0.01 diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4c1993023099e122fc9e004bda55ec075ed5e1b --- /dev/null +++ b/fish_speech/configs/text2semantic_finetune.yaml @@ -0,0 +1,83 @@ +defaults: + - base + - _self_ + +project: text2semantic_finetune_dual_ar +max_length: 4096 +pretrained_ckpt_path: checkpoints/fish-speech-1.4 + +# Lightning Trainer +trainer: + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: "norm" + max_steps: 1000 + precision: bf16-true + limit_val_batches: 10 + val_check_interval: 100 + +# Dataset Configuration +tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: ${pretrained_ckpt_path} + +# Dataset Configuration +train_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +val_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +data: + _target_: fish_speech.datasets.semantic.SemanticDataModule + train_dataset: ${train_dataset} + val_dataset: ${val_dataset} + num_workers: 4 + batch_size: 8 + tokenizer: ${tokenizer} + max_length: ${max_length} + +# Model Configuration +model: + _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic + model: + _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained + path: ${pretrained_ckpt_path} + load_weights: true + max_length: ${max_length} + lora_config: null + + optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-4 + weight_decay: 0 + betas: [0.9, 0.95] + eps: 1e-5 + + lr_scheduler: + _target_: torch.optim.lr_scheduler.LambdaLR + _partial_: true + lr_lambda: + _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda + _partial_: true + num_warmup_steps: 10 + +# Callbacks +callbacks: + model_checkpoint: + every_n_train_steps: ${trainer.val_check_interval} diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ca0ef9181754eda7e6b49e01abeafbe07fb00f --- /dev/null +++ b/fish_speech/conversation.py @@ -0,0 +1,2 @@ +SEMANTIC_TOKEN = "<|semantic|>" +CODEBOOK_PAD_TOKEN_ID = 0 diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa596b95a572ee15c5570cbdb792c9a78e62dfa --- /dev/null +++ b/fish_speech/datasets/concat_repeat.py @@ -0,0 +1,53 @@ +import bisect +import random +from typing import Iterable + +from torch.utils.data import Dataset, IterableDataset + + +class ConcatRepeatDataset(Dataset): + datasets: list[Dataset] + cumulative_sizes: list[int] + repeats: list[int] + + @staticmethod + def cumsum(sequence, repeats): + r, s = [], 0 + for dataset, repeat in zip(sequence, repeats): + l = len(dataset) * repeat + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): + super().__init__() + + self.datasets = list(datasets) + self.repeats = repeats + + assert len(self.datasets) > 0, "datasets should not be an empty iterable" + assert len(self.datasets) == len( + repeats + ), "datasets and repeats should have the same length" + + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatRepeatDataset does not support IterableDataset" + + self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + + dataset = self.datasets[dataset_idx] + + return dataset[sample_idx % len(dataset)] diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto new file mode 100644 index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379 --- /dev/null +++ b/fish_speech/datasets/protos/text-data.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package text_data; + +message Semantics { + repeated uint32 values = 1; +} + +message Sentence { + repeated string texts = 1; + repeated Semantics semantics = 3; +} + +message TextData { + string source = 1; + string name = 2; + repeated Sentence sentences = 4; +} + +message SampledData { + string source = 1; + string name = 2; + repeated Sentence samples = 3; +} diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e --- /dev/null +++ b/fish_speech/datasets/protos/text_data_pb2.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: text-data.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_SEMANTICS"]._serialized_start = 30 + _globals["_SEMANTICS"]._serialized_end = 57 + _globals["_SENTENCE"]._serialized_start = 59 + _globals["_SENTENCE"]._serialized_end = 125 + _globals["_TEXTDATA"]._serialized_start = 127 + _globals["_TEXTDATA"]._serialized_end = 207 + _globals["_SAMPLEDDATA"]._serialized_start = 209 + _globals["_SAMPLEDDATA"]._serialized_end = 290 +# @@protoc_insertion_point(module_scope) diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107 --- /dev/null +++ b/fish_speech/datasets/protos/text_data_stream.py @@ -0,0 +1,36 @@ +import struct + +from .text_data_pb2 import TextData + + +def read_pb_stream(f): + while True: + buf = f.read(4) + if len(buf) == 0: + break + size = struct.unpack("I", buf)[0] + buf = f.read(size) + text_data = TextData() + text_data.ParseFromString(buf) + yield text_data + + +def write_pb_stream(f, text_data): + buf = text_data.SerializeToString() + f.write(struct.pack("I", len(buf))) + f.write(buf) + + +def pack_pb_stream(text_data): + buf = text_data.SerializeToString() + return struct.pack("I", len(buf)) + buf + + +def split_pb_stream(f): + while True: + head = f.read(4) + if len(head) == 0: + break + size = struct.unpack("I", head)[0] + buf = f.read(size) + yield head + buf diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..3c64e01077ae253bdc4e4d9cd948f8fb50df7418 --- /dev/null +++ b/fish_speech/datasets/semantic.py @@ -0,0 +1,496 @@ +import random +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from random import Random +from typing import Optional, Union + +import numpy as np +import pyarrow.parquet as pq +import torch +import torch.nn.functional as F +from datasets.download.streaming_download_manager import xopen +from huggingface_hub import HfApi +from lightning import LightningDataModule +from torch.distributed import get_rank, get_world_size, is_initialized +from torch.utils.data import DataLoader, IterableDataset, get_worker_info +from transformers import AutoTokenizer + +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.datasets.protos.text_data_pb2 import SampledData +from fish_speech.datasets.protos.text_data_stream import read_pb_stream +from fish_speech.text.clean import clean_text +from fish_speech.utils import RankedLogger +from fish_speech.utils.braceexpand import braceexpand + +log = RankedLogger(__name__, rank_zero_only=True) + + +def split_by_rank_worker(files): + # We need to know the total number of devices + # to split the data properly + + total_devices = 1 + if is_initialized(): + total_devices = get_world_size() + + worker_info = get_worker_info() + if worker_info is not None: + total_devices *= worker_info.num_workers + + if len(files) < total_devices: + # Repeat the files N times to match the number of devices + files = files * (total_devices // len(files) + 1) + + # DDP + if is_initialized(): + files = files[get_rank() :: get_world_size()] + + # Split by worker + if worker_info is not None: + files = files[worker_info.id :: worker_info.num_workers] + + return files + + +class AutoTextSemanticInstructionDataset(IterableDataset): + """ + Auto Augment Dataset by Speaker + + 1. Random concatenate multiple sentences from the same speaker to form a longer sentence + 2. Automatically normalize the text + + For interactive mode, we use the following format (multiple sequences): + [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + + For non-interactive mode, we use the following format (one long sequence): + [INST] text [/INST] ... + """ + + def __init__( + self, + proto_files: list[str], + seed: int = 42, + interactive_prob: float = 0.5, + max_length: int = 1024, + tokenizer: AutoTokenizer = None, + use_speaker: bool | float = True, + causal: bool = True, + num_codebooks: Optional[int] = None, + skip_text_prob: float = 0.0, + ): + """ + Args: + proto_files: proto buf files if using local data + seed: random seed + interactive_prob: probability to use interactive mode + max_length: max length of the text + tokenizer: tokenizer + use_speaker: include speaker information in the prompt + causal: use causal sampling when using local data, disable will lead to random sampling + num_codebooks: number of codebooks, if None, it will be automatically detected + skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode + """ + + super().__init__() + + assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" + + self.seed = seed + self.max_length = max_length + self.tokenizer = tokenizer + self.interactive_prob = interactive_prob + self.use_speaker = use_speaker + self.proto_files = proto_files + self.causal = causal + self.num_codebooks = num_codebooks + self.skip_text_prob = skip_text_prob + + self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") + self.groups = None + + def init_mock_data_server(self): + if self.groups is not None: + return + + # Expand the proto files + expanded_proto_files = [] + for filename in self.proto_files: + for i in braceexpand(filename): + i = Path(i) + if i.is_file(): + expanded_proto_files.append(i) + elif i.is_dir(): + expanded_proto_files.extend(i.rglob("*.proto")) + expanded_proto_files.extend(i.rglob("*.protos")) + else: + raise ValueError(f"{i} is not a file or directory") + + expanded_proto_files = sorted(expanded_proto_files) + Random(self.seed).shuffle(expanded_proto_files) + + self.groups = [] + shard_proto_files = split_by_rank_worker(expanded_proto_files) + log.info( + f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" + ) + + count = 0 + for filename in shard_proto_files: + with open(filename, "rb") as f: + for text_data in read_pb_stream(f): + self.groups.append(text_data) + count += 1 + + log.info(f"Read total {count} groups of data") + + # Shuffle the lines + Random(self.seed).shuffle(self.groups) + self.group_weights = [len(i.sentences) for i in self.groups] + + def __iter__(self): + while True: + yield self.augment() + + def tokenize_sentence(self, sentence: str): + sentence = clean_text(sentence) + tokens = self.tokenizer.encode( + f"{sentence}", + max_length=10**6, + add_special_tokens=False, + truncation=False, + ) + return sentence, len(tokens) + + def sample_data(self): + if self.groups is None: + self.init_mock_data_server() + + # Shuffle unique lines, estimate that each sample is at least 20 tokens + num_samples = self.max_length // 20 + + # choice group based on their number of samples + group = random.choices(self.groups, weights=self.group_weights, k=1)[0] + + if self.causal: + # Sample in order + if num_samples >= len(group.sentences): + samples = group.sentences + else: + begin = random.randint(0, len(group.sentences) - num_samples) + samples = group.sentences[begin : begin + num_samples] + else: + samples = random.choices( + group.sentences, k=min(num_samples, len(group.sentences)) + ) + + return SampledData( + source=group.source, + name=group.name, + samples=samples, + ) + + def augment(self): + final_text, final_semantic = [], [] + response = self.sample_data() + if len(response.samples) == 0: + # Invalid group + return None + + samples = list(response.samples) + idx = 0 + use_interactive = random.random() < self.interactive_prob + + if use_interactive is False: + # Random sample based on speaker using a truncated normal distribution + a = torch.tensor([0], dtype=torch.float32) + torch.nn.init.trunc_normal_( + a, + mean=self.max_length // 2, + std=self.max_length // 4, + a=10, + b=self.max_length, + ) + remaining_tokens = a.long().item() - 4 + else: + remaining_tokens = self.max_length + + # Use speaker + if isinstance(self.use_speaker, float): + use_speaker = random.random() < self.use_speaker + else: + use_speaker = self.use_speaker + + all_tokens, all_labels = [], [] + while remaining_tokens > 0 and len(samples) > 0: + sentence = samples.pop(0) + + text = random.choice(sentence.texts) + text, length = self.tokenize_sentence(text) + remaining_tokens -= length + len(sentence.semantics[0].values) + + if use_interactive is False: + final_text.append(text) + final_semantic.append(sentence.semantics) + else: + # For interactive mode, we only apply speaker for the first sentence + # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + tokens, labels = self.pack_sentences( + sentences=[text], + semantics=[sentence.semantics], + speaker=response.name if use_speaker else None, + skip_text=random.random() < self.skip_text_prob, + ) + + all_tokens.append(tokens) + all_labels.append(labels) + + idx += 1 + + if use_interactive is False: + tokens, labels = self.pack_sentences( + final_text, + semantics=final_semantic, + speaker=response.name if use_speaker else None, + ) + all_tokens.append(tokens) + all_labels.append(labels) + + tokens = torch.cat(all_tokens, dim=1) + labels = torch.cat(all_labels, dim=1) + + # Verify that the length is correct + assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" + + data = {"tokens": tokens, "labels": labels} + + return data + + def pack_sentences( + self, + sentences: list[str], + semantics: list, + speaker: Optional[str] = None, + skip_text: bool = False, + ): + if speaker is None: + speaker = "assistant" + + cated_sentences = " ".join(sentences) + if skip_text: + cated_sentences = "<|skip_text|>" + + final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>" + final_text = final_text + f"<|im_start|>{speaker}\n" + + encoded = self.tokenizer.encode( + final_text, + add_special_tokens=False, + truncation=False, + max_length=10**6, + ) + semantic_length = sum([len(i[0].values) for i in semantics]) + prompt_length = len(encoded) + num_codebooks = ( + len(semantics[0]) if self.num_codebooks is None else self.num_codebooks + ) + + # Pack the tokens and semantics (add and to semantic tokens) + tokens = ( + encoded + + [self.semantic_token_id] * semantic_length + + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"]) + ) + + # Codebook bos/padding: 0, eos: 1 + codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)] + for segment in semantics: + for book_idx, book in zip(range(num_codebooks), segment): + for j in book.values: + codes[book_idx].append(int(j) + 1) + + for book in codes: + book.extend([CODEBOOK_PAD_TOKEN_ID] * 1) + + tokens = [tokens] + codes + + tokens = torch.tensor(tokens, dtype=torch.long) + labels = tokens.clone() + + if skip_text: + # If text is not provided, the sentence is used for condition only, all labels are -100 + torch.fill_(labels, -100) + return tokens, labels + + # Mask out the tokens for semantic, predict semantic tokens only + # Since we don't mask out the input tokens, the language modeling still works + labels[1:, :prompt_length] = -100 + + tokens = tokens[:, :-1] + labels = labels[:, 1:] + + # Verify the padding is correct, and the last token is eos + assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all() + assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all() + + return tokens, labels + + +@dataclass +class TextDataCollator: + tokenizer: AutoTokenizer + max_length: int = 1024 + + def __call__(self, examples): + if "negative_tokens" in examples: + positive_examples = [] + negative_examples = [] + + for i in examples: + positive_examples.append( + { + "tokens": i["tokens"], + "labels": i["labels"], + } + ) + negative_examples.append( + { + "tokens": i["negative_tokens"], + "labels": i["negative_labels"], + } + ) + + examples = positive_examples + negative_examples + + return self.batchify(examples) + + def batchify(self, examples, tokens_key="tokens", labels_key="labels"): + tokens, attention_masks, labels = [], [], [] + + # Calculate the max length + max_tokens_length = 0 + for example in examples: + max_tokens_length = max(max_tokens_length, example[tokens_key].size(1)) + max_tokens_length = min(max_tokens_length, self.max_length) + + for example in examples: + _tokens = example[tokens_key][:, :max_tokens_length] + _labels = example[labels_key][:, :max_tokens_length] + _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool) + tokens_length = _tokens.size(1) + _attention_mask[:tokens_length] = False + + assert tokens_length == _labels.size( + 1 + ), f"{tokens_length} != {_labels.size(1)}" + + if tokens_length < max_tokens_length: + _tokens = F.pad( + _tokens, + (0, max_tokens_length - tokens_length), + value=self.tokenizer.eos_token_id, + ) + _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID + _labels = F.pad( + _labels, (0, max_tokens_length - _labels.size(1)), value=-100 + ) + + tokens.append(_tokens) + attention_masks.append(_attention_mask) + labels.append(_labels) + + tokens = torch.stack(tokens, dim=0) + attention_masks = torch.stack(attention_masks, dim=0) + labels = torch.stack(labels, dim=0) + + return { + "inputs": tokens, + "attention_masks": attention_masks, + "labels": labels, + } + + +class InterleaveDataset(IterableDataset): + def __init__( + self, + datasets: list[IterableDataset], + probabilities: list[float], + seed: int = 42, + ): + super().__init__() + + self.datasets = datasets + self.probabilities = probabilities + self.seed = seed + + def __iter__(self): + rng = np.random.default_rng(self.seed) + dataset_iterators = [iter(dataset) for dataset in self.datasets] + + while True: + # Random choice one + dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) + dataset_iterator = dataset_iterators[dataset_idx] + + try: + yield next(dataset_iterator) + except StopIteration: + # Exhausted, create a new iterator + dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) + yield next(dataset_iterators[dataset_idx]) + + +class SemanticDataModule(LightningDataModule): + def __init__( + self, + train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + batch_size: int = 32, + tokenizer: AutoTokenizer = None, + max_length: int = 1024, + num_workers: int = 4, + ): + super().__init__() + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.batch_size = batch_size + self.tokenizer = tokenizer + self.max_length = max_length + self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + from tqdm import tqdm + + ds = AutoTextSemanticInstructionDataset( + ["data/protos"], + tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"), + use_speaker=False, + interactive_prob=1.0, + skip_text_prob=0.5, + ) + + for i in ds: + print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False)) + # i["labels"][0][i["labels"][0] == -100] = 0 + # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) + break diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a45583d22efb0feb9dc1e823bae1ef74534b299e --- /dev/null +++ b/fish_speech/datasets/vqgan.py @@ -0,0 +1,147 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import librosa +import numpy as np +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset + +from fish_speech.utils import RankedLogger + +logger = RankedLogger(__name__, rank_zero_only=False) + + +class VQGANDataset(Dataset): + def __init__( + self, + filelist: str, + sample_rate: int = 32000, + hop_length: int = 640, + slice_frames: Optional[int] = None, + ): + super().__init__() + + filelist = Path(filelist) + root = filelist.parent + + self.files = [ + root / line.strip() + for line in filelist.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + self.sample_rate = sample_rate + self.hop_length = hop_length + self.slice_frames = slice_frames + + def __len__(self): + return len(self.files) + + def get_item(self, idx): + file = self.files[idx] + + audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) + + # Slice audio and features + if ( + self.slice_frames is not None + and audio.shape[0] > self.slice_frames * self.hop_length + ): + start = np.random.randint( + 0, audio.shape[0] - self.slice_frames * self.hop_length + ) + audio = audio[start : start + self.slice_frames * self.hop_length] + + if len(audio) == 0: + return None + + max_value = np.abs(audio).max() + if max_value > 1.0: + audio = audio / max_value + + return { + "audio": torch.from_numpy(audio), + } + + def __getitem__(self, idx): + try: + return self.get_item(idx) + except Exception as e: + import traceback + + traceback.print_exc() + logger.error(f"Error loading {self.files[idx]}: {e}") + return None + + +@dataclass +class VQGANCollator: + def __call__(self, batch): + batch = [x for x in batch if x is not None] + + audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) + audio_maxlen = audio_lengths.max() + + # Rounds up to nearest multiple of 2 (audio_lengths) + audios = [] + for x in batch: + audios.append( + torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) + ) + + return { + "audios": torch.stack(audios), + "audio_lengths": audio_lengths, + } + + +class VQGANDataModule(LightningDataModule): + def __init__( + self, + train_dataset: VQGANDataset, + val_dataset: VQGANDataset, + batch_size: int = 32, + num_workers: int = 4, + val_batch_size: Optional[int] = None, + ): + super().__init__() + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.batch_size = batch_size + self.val_batch_size = val_batch_size or batch_size + self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + shuffle=True, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") + dataloader = DataLoader( + dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() + ) + + for batch in dataloader: + print(batch["audios"].shape) + print(batch["features"].shape) + print(batch["audio_lengths"]) + print(batch["feature_lengths"]) + break diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md new file mode 100644 index 0000000000000000000000000000000000000000..700902b09db20911ef1ad678cbdce5644b84aea2 --- /dev/null +++ b/fish_speech/i18n/README.md @@ -0,0 +1,27 @@ +## i18n Folder Attribution + +The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: + +### fish_speech/i18n/core.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) + +**Initial commit:** +add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) + +**Initial author:** +[@L4Ph](https://github.com/L4Ph) + +### fish_speech/i18n/scan.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) + +**Initial commit:** +File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) + +**Initial author:** +[@towzeur](https://github.com/towzeur) + +We appreciate the contributions of the RVC project and its authors. diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..981dbb3b3ecf28043ec9ff5757f947182821a246 --- /dev/null +++ b/fish_speech/i18n/__init__.py @@ -0,0 +1,3 @@ +from .core import i18n + +__all__ = ["i18n"] diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py new file mode 100644 index 0000000000000000000000000000000000000000..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd --- /dev/null +++ b/fish_speech/i18n/core.py @@ -0,0 +1,40 @@ +import json +import locale +from pathlib import Path + +I18N_FILE_PATH = Path(__file__).parent / "locale" +DEFAULT_LANGUAGE = "en_US" + + +def load_language_list(language): + with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: + language_list = json.load(f) + + return language_list + + +class I18nAuto: + def __init__(self): + i18n_file = Path(".locale") + + if i18n_file.exists(): + with open(i18n_file, "r", encoding="utf-8") as f: + language = f.read().strip() + else: + # getlocale can't identify the system's language ((None, None)) + language = locale.getdefaultlocale()[0] + + if (I18N_FILE_PATH / f"{language}.json").exists() is False: + language = DEFAULT_LANGUAGE + + self.language = language + self.language_map = load_language_list(language) + + def __call__(self, key): + return self.language_map.get(key, key) + + def __repr__(self): + return "Use Language: " + self.language + + +i18n = I18nAuto() diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json new file mode 100644 index 0000000000000000000000000000000000000000..6e280c236e9c79de2087ec33c7bf6f8e1a5296c4 --- /dev/null +++ b/fish_speech/i18n/locale/en_US.json @@ -0,0 +1,122 @@ +{ + "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Accumulate Gradient Batches", + "Add to Processing Area": "Add to Processing Area", + "Added path successfully!": "Added path successfully!", + "Advanced Config": "Advanced Config", + "Base LLAMA Model": "Base LLAMA Model", + "Batch Inference": "Batch Inference", + "Batch Size": "Batch Size", + "Changing with the Model Path": "Changing with the Model Path", + "Chinese": "Chinese", + "Compile Model": "Compile Model", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", + "Copy": "Copy", + "Data Preprocessing": "Data Preprocessing", + "Data Preprocessing Path": "Data Preprocessing Path", + "Data Source": "Data Source", + "Decoder Model Config": "Decoder Model Config", + "Decoder Model Path": "Decoder Model Path", + "Disabled": "Disabled", + "Enable Reference Audio": "Enable Reference Audio", + "English": "English", + "Error Message": "Error Message", + "File Preprocessing": "File Preprocessing", + "Generate": "Generate", + "Generated Audio": "Generated Audio", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", + "Infer interface is closed": "Infer interface is closed", + "Inference Configuration": "Inference Configuration", + "Inference Server Configuration": "Inference Server Configuration", + "Inference Server Error": "Inference Server Error", + "Inferring interface is launched at {}": "Inferring interface is launched at {}", + "Initial Learning Rate": "Initial Learning Rate", + "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", + "Input Text": "Input Text", + "Invalid path: {}": "Invalid path: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", + "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", + "Japanese": "Japanese", + "LLAMA Configuration": "LLAMA Configuration", + "LLAMA Model Config": "LLAMA Model Config", + "LLAMA Model Path": "LLAMA Model Path", + "Labeling Device": "Labeling Device", + "LoRA Model to be merged": "LoRA Model to be merged", + "Maximum Audio Duration": "Maximum Audio Duration", + "Maximum Length per Sample": "Maximum Length per Sample", + "Maximum Training Steps": "Maximum Training Steps", + "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", + "Merge": "Merge", + "Merge LoRA": "Merge LoRA", + "Merge successfully": "Merge successfully", + "Minimum Audio Duration": "Minimum Audio Duration", + "Model Output Path": "Model Output Path", + "Model Size": "Model Size", + "Move": "Move", + "Move files successfully": "Move files successfully", + "No audio generated, please check the input text.": "No audio generated, please check the input text.", + "No selected options": "No selected options", + "Number of Workers": "Number of Workers", + "Open Inference Server": "Open Inference Server", + "Open Labeler WebUI": "Open Labeler WebUI", + "Open Tensorboard": "Open Tensorboard", + "Opened labeler in browser": "Opened labeler in browser", + "Optional Label Language": "Optional Label Language", + "Optional online ver": "Optional online ver", + "Output Path": "Output Path", + "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", + "Precision": "Precision", + "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", + "Put your text here.": "Put your text here.", + "Reference Audio": "Reference Audio", + "Reference Text": "Reference Text", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", + "Remove Selected Data": "Remove Selected Data", + "Removed path successfully!": "Removed path successfully!", + "Repetition Penalty": "Repetition Penalty", + "Save model every n steps": "Save model every n steps", + "Select LLAMA ckpt": "Select LLAMA ckpt", + "Select VITS ckpt": "Select VITS ckpt", + "Select VQGAN ckpt": "Select VQGAN ckpt", + "Select source file processing method": "Select source file processing method", + "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", + "Selected: {}": "Selected: {}", + "Speaker": "Speaker", + "Speaker is identified by the folder name": "Speaker is identified by the folder name", + "Start Training": "Start Training", + "Streaming Audio": "Streaming Audio", + "Streaming Generate": "Streaming Generate", + "Tensorboard Host": "Tensorboard Host", + "Tensorboard Log Path": "Tensorboard Log Path", + "Tensorboard Port": "Tensorboard Port", + "Tensorboard interface is closed": "Tensorboard interface is closed", + "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", + "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.", + "Training Configuration": "Training Configuration", + "Training Error": "Training Error", + "Training stopped": "Training stopped", + "Type name of the speaker": "Type name of the speaker", + "Type the path or select from the dropdown": "Type the path or select from the dropdown", + "Use LoRA": "Use LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", + "Use filelist": "Use filelist", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", + "VITS Configuration": "VITS Configuration", + "VQGAN Configuration": "VQGAN Configuration", + "Validation Batch Size": "Validation Batch Size", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", + "WebUI Host": "WebUI Host", + "WebUI Port": "WebUI Port", + "Whisper Model": "Whisper Model", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", + "latest": "latest", + "new": "new", + "Realtime Transform Text": "Realtime Transform Text", + "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", + "Text Normalization": "Text Normalization" +} diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json new file mode 100644 index 0000000000000000000000000000000000000000..3285341f6893fe3e2ccbee6490dd8c90ed21854e --- /dev/null +++ b/fish_speech/i18n/locale/es_ES.json @@ -0,0 +1,122 @@ +{ + "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular lotes de gradientes", + "Add to Processing Area": "Agregar al Área de Procesamiento", + "Added path successfully!": "¡Ruta agregada exitosamente!", + "Advanced Config": "Configuración Avanzada", + "Base LLAMA Model": "Modelo Base LLAMA", + "Batch Inference": "Inferencia por Lote", + "Batch Size": "Tamaño del Lote", + "Changing with the Model Path": "Cambiando con la Ruta del Modelo", + "Chinese": "Chino", + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío", + "Copy": "Copiar", + "Data Preprocessing": "Preprocesamiento de Datos", + "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", + "Data Source": "Fuente de Datos", + "Decoder Model Config": "Configuración del modelo decodificador", + "Decoder Model Path": "Ruta del modelo decodificador", + "Disabled": "Desactivado", + "Enable Reference Audio": "Habilitar Audio de Referencia", + "English": "Inglés", + "Error Message": "Mensaje de Error", + "File Preprocessing": "Preprocesamiento de Archivos", + "Generate": "Generar", + "Generated Audio": "Audio Generado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", + "Infer interface is closed": "La interfaz de inferencia está cerrada", + "Inference Configuration": "Configuración de Inferencia", + "Inference Server Configuration": "Configuración del Servidor de Inferencia", + "Inference Server Error": "Error del Servidor de Inferencia", + "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", + "Initial Learning Rate": "Tasa de Aprendizaje Inicial", + "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Ruta inválida: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", + "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", + "Japanese": "Japonés", + "LLAMA Configuration": "Configuración de LLAMA", + "LLAMA Model Config": "Configuración del Modelo LLAMA", + "LLAMA Model Path": "Ruta del Modelo LLAMA", + "Labeling Device": "Dispositivo de Etiquetado", + "LoRA Model to be merged": "Modelo LoRA a fusionar", + "Maximum Audio Duration": "Duración máxima de audio", + "Maximum Length per Sample": "Longitud Máxima por Muestra", + "Maximum Training Steps": "Pasos Máximos de Entrenamiento", + "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", + "Merge": "Fusionar", + "Merge LoRA": "Fusionar LoRA", + "Merge successfully": "Fusionado exitosamente", + "Minimum Audio Duration": "Duración mínima de audio", + "Model Output Path": "Ruta de Salida del Modelo", + "Model Size": "Tamaño del Modelo", + "Move": "Mover", + "Move files successfully": "Archivos movidos exitosamente", + "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", + "No selected options": "No hay opciones seleccionadas", + "Number of Workers": "Número de Trabajadores", + "Open Inference Server": "Abrir Servidor de Inferencia", + "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "Se abrió el etiquetador en el navegador", + "Optional Label Language": "Idioma de Etiquetado Opcional", + "Optional online ver": "Ver en línea opcional", + "Output Path": "Ruta de Salida", + "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", + "Precision": "Precisión", + "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", + "Put your text here.": "Ponga su texto aquí.", + "Reference Audio": "Audio de Referencia", + "Reference Text": "Texto de Referencia", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", + "Remove Selected Data": "Eliminar Datos Seleccionados", + "Removed path successfully!": "¡Ruta eliminada exitosamente!", + "Repetition Penalty": "Penalización por Repetición", + "Save model every n steps": "Guardar modelo cada n pasos", + "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", + "Select VITS ckpt": "Seleccionar punto de control VITS", + "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", + "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", + "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", + "Selected: {}": "Seleccionado: {}", + "Speaker": "Hablante", + "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", + "Start Training": "Iniciar Entrenamiento", + "Streaming Audio": "transmisión de audio", + "Streaming Generate": "síntesis en flujo", + "Tensorboard Host": "Host de Tensorboard", + "Tensorboard Log Path": "Ruta de Registro de Tensorboard", + "Tensorboard Port": "Puerto de Tensorboard", + "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", + "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", + "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", + "Training Configuration": "Configuración de Entrenamiento", + "Training Error": "Error de Entrenamiento", + "Training stopped": "Entrenamiento detenido", + "Type name of the speaker": "Escriba el nombre del hablante", + "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", + "Use filelist": "Usar lista de archivos", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", + "VITS Configuration": "Configuración de VITS", + "VQGAN Configuration": "Configuración de VQGAN", + "Validation Batch Size": "Tamaño del Lote de Validación", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", + "WebUI Host": "Host de WebUI", + "WebUI Port": "Puerto de WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", + "latest": "más reciente", + "new": "nuevo", + "Realtime Transform Text": "Transformación de Texto en Tiempo Real", + "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", + "Text Normalization": "Normalización de Texto" +} diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json new file mode 100644 index 0000000000000000000000000000000000000000..d30bac7bcdf4f4c65b1f78b4dcf9d705c1d8eb39 --- /dev/null +++ b/fish_speech/i18n/locale/ja_JP.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", + "Accumulate Gradient Batches": "勾配バッチの累積", + "Add to Processing Area": "処理エリアに追加", + "Added path successfully!": "パスの追加に成功しました!", + "Advanced Config": "詳細設定", + "Base LLAMA Model": "基本LLAMAモデル", + "Batch Inference": "バッチ推論", + "Batch Size": "バッチサイズ", + "Changing with the Model Path": "モデルのパスに伴って変化する", + "Chinese": "中国語", + "Compile Model": "モデルのコンパイル", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", + "Copy": "コピー", + "Data Preprocessing": "データ前処理", + "Data Preprocessing Path": "データ前処理パス", + "Data Source": "データソース", + "Decoder Model Config": "デコーダーモデルの構成", + "Decoder Model Path": "デコーダーモデルのパス", + "Disabled": "無効", + "Enable Reference Audio": "リファレンスオーディオを有効にする", + "English": "英語", + "Error Message": "エラーメッセージ", + "File Preprocessing": "文書前处理", + "Generate": "生成", + "Generated Audio": "生成されたオーディオ", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", + "Infer interface is closed": "推論インターフェースが閉じられています", + "Inference Configuration": "推論設定", + "Inference Server Configuration": "推論サーバー設定", + "Inference Server Error": "推論サーバーエラー", + "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", + "Initial Learning Rate": "初期学習率", + "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", + "Input Text": "入力テキスト", + "Invalid path: {}": "無効なパス: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", + "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", + "Japanese": "日本語", + "LLAMA Configuration": "LLAMA設定", + "LLAMA Model Config": "LLAMAモデル設定", + "LLAMA Model Path": "LLAMAモデルパス", + "Labeling Device": "ラベリングデバイス", + "LoRA Model to be merged": "マージするLoRAモデル", + "Maximum Audio Duration": "最大オーディオの長さ", + "Maximum Length per Sample": "サンプルあたりの最大長", + "Maximum Training Steps": "最大トレーニングステップ数", + "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", + "Merge": "マージ", + "Merge LoRA": "LoRAのマージ", + "Merge successfully": "マージに成功しました", + "Minimum Audio Duration": "最小オーディオの長さ", + "Model Output Path": "モデル出力パス", + "Model Size": "モデルサイズ", + "Move": "移動", + "Move files successfully": "ファイルの移動に成功しました", + "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", + "No selected options": "選択されたオプションはありません", + "Number of Workers": "ワーカー数", + "Open Inference Server": "推論サーバーを開く", + "Open Labeler WebUI": "ラベラーWebUIを開く", + "Open Tensorboard": "Tensorboardを開く", + "Opened labeler in browser": "ブラウザでラベラーを開きました", + "Optional Label Language": "オプションのラベル言語", + "Optional online ver": "オプションのオンラインバージョン", + "Output Path": "出力パス", + "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", + "Precision": "精度", + "Probability of applying Speaker Condition": "話者条件を適用する確率", + "Put your text here.": "ここにテキストを入力してください。", + "Reference Audio": "リファレンスオーディオ", + "Reference Text": "リファレンステキスト", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", + "Remove Selected Data": "選択したデータを削除", + "Removed path successfully!": "パスの削除に成功しました!", + "Repetition Penalty": "反復ペナルティ", + "Save model every n steps": "nステップごとにモデルを保存", + "Select LLAMA ckpt": " LLAMA チェックポイントを選択", + "Select VITS ckpt": "VITS チェックポイントを選択", + "Select VQGAN ckpt": "VQGAN チェックポイントを選択", + "Select source file processing method": "ソースファイルの処理方法を選択", + "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", + "Selected: {}": "選択済み: {}", + "Speaker": "話者", + "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", + "Start Training": "トレーニング開始", + "Streaming Audio": "ストリーミングオーディオ", + "Streaming Generate": "ストリーミング合成", + "Tensorboard Host": "Tensorboardホスト", + "Tensorboard Log Path": "Tensorboardログパス", + "Tensorboard Port": "Tensorboardポート", + "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています", + "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました", + "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", + "Training Configuration": "トレーニング設定", + "Training Error": "トレーニングエラー", + "Training stopped": "トレーニングが停止しました", + "Type name of the speaker": "話者の名前を入力", + "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", + "Use LoRA": "LoRAを使用", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", + "Use filelist": "ファイルリストを使用", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", + "VITS Configuration": "VITS の構成", + "VQGAN Configuration": "VQGAN の構成", + "Validation Batch Size": "検証バッチサイズ", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", + "WebUI Host": "WebUIホスト", + "WebUI Port": "WebUIポート", + "Whisper Model": "Whisperモデル", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", + "latest": "最新", + "new": "新規", + "Realtime Transform Text": "リアルタイム変換テキスト", + "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", + "Text Normalization": "テキスト正規化" + +} diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json new file mode 100644 index 0000000000000000000000000000000000000000..385f20272e19053ab9b6cf6463a84c8ece768c68 --- /dev/null +++ b/fish_speech/i18n/locale/pt_BR.json @@ -0,0 +1,133 @@ +{ + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", + "Add to Processing Area": "Adicionar à Área de Processamento", + "Added path successfully!": "Caminho adicionado com sucesso!", + "Advanced Config": "Configuração Avançada", + "Base LLAMA Model": "Modelo LLAMA Base", + "Batch Inference": "Inferência em Lote", + "Batch Size": "Tamanho do Lote", + "Changing with the Model Path": "Alterando com o Caminho do Modelo", + + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", + "Copy": "Copiar", + "Data Preprocessing": "Pré-processamento de Dados", + "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", + "Data Source": "Fonte de Dados", + "Decoder Model Config": "Configuração do Modelo Decodificador", + "Decoder Model Path": "Caminho do Modelo Decodificador", + "Disabled": "Desativado", + "Enable Initial Prompt": "Habilitar Prompt Inicial", + "Enable Reference Audio": "Habilitar Áudio de Referência", + "English": "Inglês", + "Japanese": "Japonês", + "Chinese": "Chinês", + "Portuguese": "Português", + "Spanish": "Espanhol", + "Error Message": "Mensagem de Erro", + "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", + "File Preprocessing": "Pré-processamento de Arquivos", + "Generate": "Gerar", + "Generated Audio": "Áudio Gerado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", + "Infer interface is closed": "A interface de inferência foi fechada", + "Inference Configuration": "Configuração de Inferência", + "Inference Server Configuration": "Configuração do Servidor de Inferência", + "Inference Server Error": "Erro do Servidor de Inferência", + "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}", + "Initial Learning Rate": "Taxa de Aprendizagem Inicial", + "Initial Prompt": "Prompt Inicial", + "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", + "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Caminho inválido: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", + "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", + "LLAMA Configuration": "Configuração do LLAMA", + "LLAMA Model Config": "Configuração do Modelo LLAMA", + "LLAMA Model Path": "Caminho do Modelo LLAMA", + "Labeling Device": "Dispositivo de Rotulagem", + "LoRA Model to be merged": "Modelo LoRA para mesclagem", + "Maximum Length per Sample": "Comprimento Máximo por Amostra", + "Maximum Training Steps": "Etapas Máximas de Treinamento", + "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", + "Merge": "Mesclar", + "Merge LoRA": "Mesclar LoRA", + "Merge successfully": "Mesclado com sucesso", + "Model Output Path": "Caminho de Saída do Modelo", + "Model Quantization": "Quantização do Modelo", + "Model Size": "Tamanho do Modelo", + "Move": "Mover", + "Move files successfully": "Arquivos movidos com sucesso", + "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", + "No selected options": "Nenhuma opção selecionada", + "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", + "Number of Workers": "Número de Processos", + "Open Inference Server": "Abrir Servidor de Inferência", + "Open Labeler WebUI": "Abrir WebUI de Rotulagem", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", + "Optional Label Language": "Idioma do Rótulo (Opcional)", + "Optional online ver": "Versão online (opcional)", + "Output Path": "Caminho de Saída", + "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", + "Post-quantification Precision": "Precisão Pós-quantização", + "Precision": "Precisão", + "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", + "Put your text here.": "Insira seu texto aqui.", + "Quantify": "Quantizar", + "Quantify successfully": "Quantizado com sucesso", + "Realtime Transform Text": "Transformar Texto em Tempo Real", + "Reference Audio": "Áudio de Referência", + "Reference Text": "Texto de Referência", + "warning": "Aviso", + "Pre-processing begins...": "O pré-processamento começou!", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", + "Remove Selected Data": "Remover Dados Selecionados", + "Removed path successfully!": "Caminho removido com sucesso!", + "Repetition Penalty": "Penalidade de Repetição", + "Save model every n steps": "Salvar modelo a cada n etapas", + "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", + "Select source file processing method": "Escolha como processar o arquivo de origem", + "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", + "Selected: {}": "Selecionado: {}", + "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", + "Start Training": "Iniciar Treinamento", + "Streaming Audio": "Áudio em Streaming", + "Streaming Generate": "Geração em Streaming", + "Tensorboard Host": "Host do Tensorboard", + "Tensorboard Log Path": "Caminho de Log do Tensorboard", + "Tensorboard Port": "Porta do Tensorboard", + "Tensorboard interface is closed": "A interface do Tensorboard está fechada", + "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", + "Text Normalization": "Normalização de Texto", + "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", + "Training Configuration": "Configuração de Treinamento", + "Training Error": "Erro de Treinamento", + "Training stopped": "Treinamento interrompido!", + "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", + "Use filelist": "Usar lista de arquivos", + "VQGAN Configuration": "Configuração do VQGAN", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.", + "WebUI Host": "Host da WebUI", + "WebUI Port": "Porta da WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", + "auto": "automático", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", + "latest": "mais recente", + "new": "novo", + "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", + "You don't need to train this model!": "Não é necessário treinar este modelo!", + "Yes": "Sim", + "No": "Não", + "version:": "versão:", + "author:": "autor:" +} diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json new file mode 100644 index 0000000000000000000000000000000000000000..3dd1a5cd1ccf3860ca508238cc64a68ca4fc3276 --- /dev/null +++ b/fish_speech/i18n/locale/zh_CN.json @@ -0,0 +1,122 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", + "Accumulate Gradient Batches": "梯度累积批次", + "Add to Processing Area": "加入处理区", + "Added path successfully!": "添加路径成功!", + "Advanced Config": "高级参数", + "Base LLAMA Model": "基础 LLAMA 模型", + "Batch Inference": "批量推理", + "Batch Size": "批次大小", + "Changing with the Model Path": "随模型路径变化", + "Chinese": "中文", + "Compile Model": "编译模型", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", + "Copy": "复制", + "Data Preprocessing": "数据预处理", + "Data Preprocessing Path": "数据预处理路径", + "Data Source": "数据源", + "Decoder Model Config": "解码器模型配置", + "Decoder Model Path": "解码器模型路径", + "Disabled": "禁用", + "Enable Reference Audio": "启用参考音频", + "English": "英文", + "Error Message": "错误信息", + "File Preprocessing": "文件预处理", + "Generate": "生成", + "Generated Audio": "音频", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", + "Infer interface is closed": "推理界面已关闭", + "Inference Configuration": "推理配置", + "Inference Server Configuration": "推理服务器配置", + "Inference Server Error": "推理服务器错误", + "Inferring interface is launched at {}": "推理界面已在 {} 上启动", + "Initial Learning Rate": "初始学习率", + "Input Audio & Source Path for Transcription": "输入音频和转录源路径", + "Input Text": "输入文本", + "Invalid path: {}": "无效路径: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", + "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", + "Japanese": "日文", + "LLAMA Configuration": "LLAMA 配置", + "LLAMA Model Config": "LLAMA 模型配置", + "LLAMA Model Path": "LLAMA 模型路径", + "Labeling Device": "标注加速设备", + "LoRA Model to be merged": "要合并的 LoRA 模型", + "Maximum Audio Duration": "最大音频时长", + "Maximum Length per Sample": "每个样本的最大长度", + "Maximum Training Steps": "最大训练步数", + "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", + "Merge": "合并", + "Merge LoRA": "合并 LoRA", + "Merge successfully": "合并成功", + "Minimum Audio Duration": "最小音频时长", + "Model Output Path": "模型输出路径", + "Model Size": "模型规模", + "Move": "移动", + "Move files successfully": "移动文件成功", + "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", + "No selected options": "没有选择的选项", + "Number of Workers": "数据加载进程数", + "Open Inference Server": "打开推理服务器", + "Open Labeler WebUI": "打开标注工具", + "Open Tensorboard": "打开 Tensorboard", + "Opened labeler in browser": "在浏览器中打开标注工具", + "Optional Label Language": "[可选] 标注语言", + "Optional online ver": "[可选] 使用在线版", + "Output Path": "输出路径", + "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", + "Precision": "精度", + "Probability of applying Speaker Condition": "应用说话人条件的概率", + "Put your text here.": "在此处输入文本.", + "Reference Audio": "参考音频", + "Reference Text": "参考文本", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", + "Remove Selected Data": "移除选中数据", + "Removed path successfully!": "移除路径成功!", + "Repetition Penalty": "重复惩罚", + "Save model every n steps": "每 n 步保存模型", + "Select LLAMA ckpt": "选择 LLAMA 检查点", + "Select VITS ckpt": "选择 VITS 检查点", + "Select VQGAN ckpt": "选择 VQGAN 检查点", + "Select source file processing method": "选择源文件处理方法", + "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", + "Selected: {}": "已选择: {}", + "Speaker": "说话人", + "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", + "Start Training": "开始训练", + "Streaming Audio": "流式音频", + "Streaming Generate": "流式合成", + "Tensorboard Host": "Tensorboard 监听地址", + "Tensorboard Log Path": "Tensorboard 日志路径", + "Tensorboard Port": "Tensorboard 端口", + "Tensorboard interface is closed": "Tensorboard 界面已关闭", + "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", + "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", + "Training Configuration": "训练配置", + "Training Error": "训练错误", + "Training stopped": "训练已停止", + "Type name of the speaker": "输入说话人的名称", + "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", + "Use LoRA": "使用 LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", + "Use filelist": "使用文件列表", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", + "VITS Configuration": "VITS 配置", + "VQGAN Configuration": "VQGAN 配置", + "Validation Batch Size": "验证批次大小", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", + "WebUI Host": "WebUI 监听地址", + "WebUI Port": "WebUI 端口", + "Whisper Model": "Whisper 模型", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", + "latest": "最近的检查点", + "new": "创建新的检查点", + "Realtime Transform Text": "实时规范化文本", + "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", + "Text Normalization": "文本规范化" +} diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py new file mode 100644 index 0000000000000000000000000000000000000000..d0194c0f1a31dc95309c64626d13f04751a44ba1 --- /dev/null +++ b/fish_speech/i18n/scan.py @@ -0,0 +1,122 @@ +import ast +import glob +import json +from collections import OrderedDict +from pathlib import Path + +from loguru import logger + +from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH + + +def extract_i18n_strings(node): + i18n_strings = [] + + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "i18n" + ): + for arg in node.args: + if isinstance(arg, ast.Str): + i18n_strings.append(arg.s) + + for child_node in ast.iter_child_nodes(node): + i18n_strings.extend(extract_i18n_strings(child_node)) + + return i18n_strings + + +# scan the directory for all .py files (recursively) +# for each file, parse the code into an AST +# for each AST, extract the i18n strings + +strings = [] +folders = ["fish_speech", "tools"] +# for filename in glob.iglob("**/*.py", recursive=True): +for folder in folders: + for f in Path(folder).rglob("*.py"): + code = f.read_text(encoding="utf-8") + if "i18n(" in code: + tree = ast.parse(code) + i18n_strings = extract_i18n_strings(tree) + logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") + strings.extend(i18n_strings) + +code_keys = set(strings) +logger.info(f"Total unique: {len(code_keys)}") + + +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) +standard_keys = set(standard_data.keys()) + +# Define the standard file name +unused_keys = standard_keys - code_keys +logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}") +for unused_key in unused_keys: + logger.info(f"\t{unused_key}") + +missing_keys = code_keys - standard_keys +logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") +for missing_key in missing_keys: + logger.info(f"\t{missing_key}") + +code_keys_dict = OrderedDict() +for s in strings: + code_keys_dict[s] = s + +# write back +with open(standard_file, "w", encoding="utf-8") as f: + json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + +logger.info(f"Updated {standard_file}") + + +# Define the standard file name +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" + +# Find all JSON files in the directory +dir_path = I18N_FILE_PATH +languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] + +# Load the standard file +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) + +# Loop through each language file +for lang_file in languages: + # Load the language file + with open(lang_file, "r", encoding="utf-8") as f: + lang_data = json.load(f, object_pairs_hook=OrderedDict) + + # Find the difference between the language file and the standard file + diff = set(standard_data.keys()) - set(lang_data.keys()) + + miss = set(lang_data.keys()) - set(standard_data.keys()) + + # Add any missing keys to the language file + for key in diff: + lang_data[key] = "#!" + key + logger.info(f"Added missing key: {key} to {lang_file}") + + # Del any extra keys to the language file + for key in miss: + del lang_data[key] + logger.info(f"Del extra key: {key} from {lang_file}") + + # Sort the keys of the language file to match the order of the standard file + lang_data = OrderedDict( + sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) + ) + + # Save the updated language file + with open(lang_file, "w", encoding="utf-8") as f: + json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + + logger.info(f"Updated {lang_file}") + +logger.info("Done") diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py new file mode 100644 index 0000000000000000000000000000000000000000..df970400f8a073be4c4166a697245fabdf6b09b0 --- /dev/null +++ b/fish_speech/models/text2semantic/lit_module.py @@ -0,0 +1,202 @@ +from typing import Any, Optional + +import lightning as L +import torch +import torch.nn.functional as F +from lightning.pytorch.utilities.types import OptimizerLRScheduler + +import fish_speech.utils as utils +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.models.text2semantic.llama import NaiveTransformer + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +class TextToSemantic(L.LightningModule): + def __init__( + self, + model: NaiveTransformer, + optimizer: Any, + lr_scheduler: Any, + ): + super().__init__() + + self.model = model + self.optimizer_builder = optimizer + self.lr_scheduler_builder = lr_scheduler + + def forward(self, x): + return self.model(x) + + def on_save_checkpoint(self, checkpoint): + # Save only LoRA parameters + state_dict = checkpoint["state_dict"] + use_lora = any("lora" in name for name in state_dict.keys()) + if not use_lora: + return + + for name in list(state_dict.keys()): + if "lora" not in name: + state_dict.pop(name) + + def configure_optimizers(self) -> OptimizerLRScheduler: + # Get weight decay parameters + weight_decay_parameters, other_parameters = [], [] + for name, param in self.named_parameters(): + if ".bias" in name or "norm.weight" in name or ".embeddings." in name: + other_parameters.append(param) + else: + weight_decay_parameters.append(param) + + optimizer = self.optimizer_builder( + [ + {"params": weight_decay_parameters}, + {"params": other_parameters, "weight_decay": 0.0}, + ] + ) + + # Print the parameters and their weight decay + for i in optimizer.param_groups: + log.info( + f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" + ) + + lr_scheduler = self.lr_scheduler_builder(optimizer) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", + }, + } + + # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + assert logits.shape[:-1] == labels.shape + + labels = labels.clone() + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def _step(self, batch, batch_idx, stage: str): + is_train = stage == "train" + + if is_train: + # Key part to make lora work + # Otherwise the parameters are merged, which lead to incorrect gradients + self.model.train() + + # Do positive and negative samples in the same batch to speed up training + labels = batch["labels"] + outputs = self.model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.view(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + ) + + codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.view(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + ) + + loss = base_loss + semantic_loss + + self.log( + f"{stage}/loss", + loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/base_loss", + base_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/semantic_loss", + semantic_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + # Top-5 accuracy + accuracy = self.get_accuracy(codebook_logits, codebook_labels) + self.log( + f"{stage}/top_5_accuracy", + accuracy, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + return loss + + def get_accuracy(self, logits, labels): + mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) + if mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + _, indices = logits.topk(5, dim=-1) + correct = indices.eq(labels.unsqueeze(-1)) + correct[~mask] = 0 + correct = correct.sum() + accuracy = correct / mask.sum() + + return accuracy + + def training_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "val") diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..0725dfb9b78b1154753641b69c959a2faadba48c --- /dev/null +++ b/fish_speech/models/text2semantic/llama.py @@ -0,0 +1,779 @@ +import json +import math +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from loguru import logger +from torch import Tensor +from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.checkpoint import checkpoint +from transformers import AutoTokenizer + +from fish_speech.conversation import SEMANTIC_TOKEN +from fish_speech.utils import RankedLogger + +from .lora import LoraConfig, setup_lora + +log = RankedLogger(__name__, rank_zero_only=True) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class BaseModelArgs: + model_type: str = "base" + + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + max_seq_len: int = 2048 + dropout: float = 0.0 + tie_word_embeddings: bool = True + attention_qkv_bias: bool = False + + # Codebook configs + codebook_size: int = 160 + num_codebooks: int = 4 + + # Gradient checkpointing + use_gradient_checkpointing: bool = True + + # Initialize the model + initializer_range: float = 0.02 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + if path.is_dir(): + path = path / "config.json" + + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + match data["model_type"]: + case "naive": + cls = NaiveModelArgs + case "dual_ar": + cls = DualARModelArgs + case _: + raise ValueError(f"Unknown model type: {data['model_type']}") + + return cls(**data) + + def save(self, path: str): + with open(path, "w") as f: + json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) + + +@dataclass +class NaiveModelArgs(BaseModelArgs): + model_type: str = "naive" + + +@dataclass +class DualARModelArgs(BaseModelArgs): + model_type: str = "dual_ar" + n_fast_layer: int = 4 + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +@dataclass +class TransformerForwardResult: + token_logits: Tensor + codebook_logits: Tensor + + +@dataclass +class BaseTransformerForwardResult: + logits: Tensor + hidden_states: Tensor + + +class BaseTransformer(nn.Module): + def __init__( + self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True + ) -> None: + super().__init__() + self.config = config + self.tokenizer = tokenizer + + self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN) + + # Slow transformer + self.embeddings = nn.Embedding( + config.vocab_size, + config.dim, + ) + self.codebook_embeddings = nn.Embedding( + config.codebook_size * config.num_codebooks, + config.dim, + ) + self.layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + + if self.config.tie_word_embeddings is False: + self.output = nn.Linear( + config.dim, + config.vocab_size, + bias=False, + ) + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + config.max_seq_len, + config.dim // config.n_head, + config.rope_base, + ), + persistent=False, + ) + self.register_buffer( + "causal_mask", + torch.tril( + torch.ones( + config.max_seq_len, + config.max_seq_len, + dtype=torch.bool, + ) + ), + persistent=False, + ) + + # For kv cache + self.max_batch_size = -1 + self.max_seq_len = -1 + + if init_weights: + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: + return + + head_dim = self.config.dim // self.config.n_head + max_seq_len = find_multiple(max_seq_len, 8) + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_len, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ) + + def embed(self, x: Tensor) -> Tensor: + vocab_embeds = [self.embeddings(x[:, 0])] + for i in range(self.config.num_codebooks): + emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) + emb[x[:, 0] != self.semantic_token_id] = 0 + vocab_embeds.append(emb) + + x = torch.stack(vocab_embeds, dim=3) + x = x.sum(dim=3) + + return x + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> BaseTransformerForwardResult: + seq_len = inp.size(2) + + # Here we want to merge the embeddings of the codebooks + x = self.embed(inp) + + freqs_cis = self.freqs_cis[:seq_len] + + # Not that the causal mask here follows the definition of scaled_dot_product_attention + # That is, FALSE means masked out + # To maintain consistency, key_padding_mask use TRUE to mask out + mask = None + if key_padding_mask is not None: + mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) + mask = mask & key_padding_mask[:, None, None, :].logical_not() + + for layer in self.layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) + else: + x = layer(x, freqs_cis, mask) + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def forward_generate( + self, + x: Tensor, + input_pos: Optional[Tensor] = None, + return_all: bool = False, + ) -> BaseTransformerForwardResult: + # This is used for generation, optimized for torch compile + assert ( + self.max_seq_len != -1 and self.max_batch_size != -1 + ), "Please call setup_caches before forward_generate" + + x = self.embed(x) + + mask = self.causal_mask[ + None, None, input_pos, : self.max_seq_len + ] # (B, N, Q, K) + freqs_cis = self.freqs_cis[input_pos] + + for layer in self.layers: + x = layer(x, freqs_cis, mask, input_pos=input_pos) + + # If prefill, we only calculate the logits of last token + if x.size(1) > 1 and not return_all: + x = x[:, -1:] + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @staticmethod + def from_pretrained( + path: str, + load_weights: bool = False, + max_length: int | None = None, + lora_config: LoraConfig | None = None, + rope_base: int | None = None, + ) -> "BaseTransformer": + config = BaseModelArgs.from_pretrained(str(path)) + if max_length is not None: + config.max_seq_len = max_length + log.info(f"Override max_seq_len to {max_length}") + + if rope_base is not None: + config.rope_base = rope_base + log.info(f"Override rope_base to {rope_base}") + + match config.model_type: + case "naive": + model_cls = NaiveTransformer + case "dual_ar": + model_cls = DualARTransformer + case _: + raise ValueError(f"Unknown model type: {config.model_type}") + + tokenizer = AutoTokenizer.from_pretrained(str(path)) + log.info(f"Loading model from {path}, config: {config}") + model = model_cls(config, tokenizer=tokenizer) + + if lora_config is not None: + setup_lora(model, lora_config) + log.info(f"LoRA setup: {lora_config}") + + if load_weights is False: + log.info("Randomly initialized model") + else: + + if "int8" in str(Path(path)): + logger.info("Using int8 weight-only quantization!") + from tools.llama.quantize import WeightOnlyInt8QuantHandler + + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(Path(path)): + logger.info("Using int4 quantization!") + path_comps = path.name.split("-") + assert path_comps[-2].startswith("g") + groupsize = int(path_comps[-2][1:]) + from tools.llama.quantize import WeightOnlyInt4QuantHandler + + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + weights = torch.load( + Path(path) / "model.pth", map_location="cpu", mmap=True + ) + + if "state_dict" in weights: + logger.warning( + "Using a TextToSemantic LightningModule checkpoint, " + "please make sure it is a full model, not a LoRA model." + ) + weights = weights["state_dict"] + + if next(iter(weights.keys())).startswith("model."): + logger.info( + f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys" + ) + new_weights = OrderedDict() + for k, v in weights.items(): + new_weights[k.replace("model.", "")] = v + weights = new_weights + + # Verify the name and shape of parameters since strict=False in load_state_dict. + for k, v in model.named_parameters(): + if k not in weights: + logger.warning(f"No weight for {k}") + elif v.shape != weights[k].shape: + logger.warning( + f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}" + ) + + err = model.load_state_dict(weights, strict=False, assign=True) + log.info(f"Loaded weights with error: {err}") + + return model + + def save_pretrained(self, path: str, drop_lora: bool = False): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + self.config.save(path / "config.json") + state_dict = self.state_dict() + + if drop_lora: + for key in list(state_dict.keys()): + if "lora" not in key: + continue + + state_dict.pop(key) + log.info(f"Drop LoRA parameter: {key}") + + torch.save(state_dict, path / "model.pth") + self.tokenizer.save_pretrained(path) + + +class NaiveTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.codebook_output = nn.Linear( + config.dim, + config.codebook_size * config.num_codebooks, + bias=False, + ) + + self.apply(self._init_weights) + + def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: + token_logits = result.logits + x = result.hidden_states + + # Codebook + codebook_logits = self.codebook_output(self.codebook_norm(x)) + codebook_logits = rearrange( + codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + result = super().forward( + inp=inp, + key_padding_mask=key_padding_mask, + ) + return self.decode(result) + + def forward_generate( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> TransformerForwardResult: + result = super().forward_generate(x, input_pos) + return self.decode(result) + + +class DualARTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + # Fast transformer + self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim) + + # The equivalent bs is so large that sdpa doesn't work + self.fast_layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer) + ) + self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.fast_output = nn.Linear( + config.dim, + config.codebook_size, + bias=False, + ) + + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + super().setup_caches(max_batch_size, max_seq_len, dtype) + + head_dim = self.config.dim // self.config.n_head + + # Fast transformer + # The max seq len here is the number of codebooks + for b in self.fast_layers: + b.attention.kv_cache = KVCache( + max_batch_size, + self.config.num_codebooks, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + parent_result = super().forward(inp, key_padding_mask) + token_logits = parent_result.logits + x = parent_result.hidden_states + + # Fast transformer + fast_seq_len = self.config.num_codebooks + fast_mask = self.causal_mask[ + None, None, :fast_seq_len, :fast_seq_len + ] # (B, N, Q, K) + fast_freqs_cis = self.freqs_cis[:fast_seq_len] + + # Drop the last token and rotate left + codebooks = inp[:, 1:-1, 1:] + codebooks = F.pad(codebooks, (0, 1), value=0) + codebook_embeddings = self.fast_embeddings(codebooks) + x = torch.cat([x[:, None], codebook_embeddings], dim=1) + b, s = x.size(0), x.size(2) + x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len + + # Remove padded part + codebooks = rearrange(codebooks, "b n s -> (b s) n") + codebook_mask = (codebooks == 0).all(dim=-1) + + if torch.all(codebook_mask): + # If all codebooks are padded, we keep first 8 to make sure the model runs + codebook_mask[:8] = False + + x_bs, x_len = x.size(0), x.size(1) + x = x[~codebook_mask] + + for layer in self.fast_layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True) + else: + x = layer(x, fast_freqs_cis, fast_mask) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) + codebook_logits = self.fast_output(fast_out) + + # Re-pad the codebook_logits + buffer = torch.zeros( + x_bs, + x_len, + codebook_logits.size(-1), + device=codebook_logits.device, + dtype=codebook_logits.dtype, + ) + buffer[~codebook_mask] = codebook_logits + codebook_logits = buffer + + assert codebook_logits.shape[1] == self.config.num_codebooks + codebook_logits = rearrange( + codebook_logits, + "(b s) n d -> b s n d", + b=b, + s=s, + n=self.config.num_codebooks, + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward_generate_fast( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> Tensor: + # Fast transformer + x = x.view(1, 1, -1) + + fast_mask = self.causal_mask[ + None, None, input_pos, : self.config.num_codebooks + ] # (B, N, Q, K) + fast_freqs_cis = self.freqs_cis[input_pos] + + for layer in self.fast_layers: + x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) # only take the last token + codebook_logits = self.fast_output(fast_out) + + return codebook_logits + + +class TransformerBlock(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: + super().__init__() + self.attention = Attention(config, use_sdpa=use_sdpa) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear( + config.dim, total_head_dim, bias=config.attention_qkv_bias + ) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.dropout = config.dropout + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.use_sdpa = use_sdpa + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + if self.use_sdpa: + if mask is None: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + y = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + # No third party attn_mask here to use flash_attention + ) + else: + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + else: + y = self.eq_scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + return self.wo(y) + + def eq_scaled_dot_product_attention( + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + ) -> torch.Tensor: + # This is a standard scaled dot product attention + # It's low efficient, but it doesn't raise cuda error + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + return attn_weight @ value + + +class FeedForward(nn.Module): + def __init__(self, config: BaseModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..647ca6fcccf038e17d2cf91a2874281dff3e0938 --- /dev/null +++ b/fish_speech/models/text2semantic/lora.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass + +import loralib as lora + + +@dataclass +class LoraConfig: + r: int + lora_alpha: float + lora_dropout: float = 0.0 + + +def setup_lora(model, lora_config): + # Replace the embedding layer with a LoRA layer + model.embeddings = lora.Embedding( + num_embeddings=model.embeddings.num_embeddings, + embedding_dim=model.embeddings.embedding_dim, + padding_idx=model.embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + model.codebook_embeddings = lora.Embedding( + num_embeddings=model.codebook_embeddings.num_embeddings, + embedding_dim=model.codebook_embeddings.embedding_dim, + padding_idx=model.codebook_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Replace output layer with a LoRA layer + linears = [(model, "output")] + + # Replace all linear layers with LoRA layers + for layer in model.layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + if hasattr(model, "fast_layers"): + model.fast_embeddings = lora.Embedding( + num_embeddings=model.fast_embeddings.num_embeddings, + embedding_dim=model.fast_embeddings.embedding_dim, + padding_idx=model.fast_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Dual-AR model + linears.append((model, "fast_output")) + + for layer in model.fast_layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + for module, layer in linears: + updated_linear = lora.Linear( + in_features=getattr(module, layer).in_features, + out_features=getattr(module, layer).out_features, + bias=getattr(module, layer).bias, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + ) + setattr(module, layer, updated_linear) + + # Mark only the LoRA layers as trainable + lora.mark_only_lora_as_trainable(model, bias="none") + + +def get_merged_state_dict(model): + # This line will merge the state dict of the model and the LoRA parameters + model.eval() + + # Then we need to remove the LoRA parameters from the state dict + state_dict = model.state_dict() + for name in list(state_dict.keys()): + if "lora" in name: + state_dict.pop(name) + + return state_dict diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py new file mode 100644 index 0000000000000000000000000000000000000000..aa21839b544174d5d91378c5daf8fe1b376a154a --- /dev/null +++ b/fish_speech/models/vqgan/modules/firefly.py @@ -0,0 +1,596 @@ +import math +from functools import partial +from math import prod +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from torch.utils.checkpoint import checkpoint + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv1D") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +def unpad1d(x: torch.Tensor, paddings: tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d( + x: torch.Tensor, + paddings: tuple[int, int], + mode: str = "zeros", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right + before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class FishConvNet(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1 + ): + super(FishConvNet, self).__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + ) + self.stride = stride + self.kernel_size = (kernel_size - 1) * dilation + 1 + self.dilation = dilation + + def forward(self, x): + pad = self.kernel_size - self.stride + extra_padding = get_extra_padding_for_conv1d( + x, self.kernel_size, self.stride, pad + ) + x = pad1d(x, (pad, extra_padding), mode="constant", value=0) + return self.conv(x).contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +class FishTransConvNet(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1): + super(FishTransConvNet, self).__init__() + self.conv = nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride=stride, dilation=dilation + ) + self.stride = stride + self.kernel_size = kernel_size + + def forward(self, x): + x = self.conv(x) + pad = self.kernel_size - self.stride + padding_right = math.ceil(pad) + padding_left = pad - padding_right + x = unpad1d(x, (padding_left, padding_right)) + return x.contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + ).weight_norm(), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + ).weight_norm(), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.silu(x) + xt = c1(xt) + xt = F.silu(xt) + xt = c2(xt) + x = xt + x + return x + + def remove_parametrizations(self): + for conv in self.convs1: + remove_parametrizations(conv, tensor_name="weight") + for conv in self.convs2: + remove_parametrizations(conv, tensor_name="weight") + + +class ParallelBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_sizes: tuple[int] = (3, 7, 11), + dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + ): + super().__init__() + + assert len(kernel_sizes) == len(dilation_sizes) + + self.blocks = nn.ModuleList() + for k, d in zip(kernel_sizes, dilation_sizes): + self.blocks.append(ResBlock1(channels, k, d)) + + def forward(self, x): + return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0) + + def remove_parametrizations(self): + for block in self.blocks: + block.remove_parametrizations() + + +class HiFiGANGenerator(nn.Module): + def __init__( + self, + *, + hop_length: int = 512, + upsample_rates: tuple[int] = (8, 8, 2, 2, 2), + upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), + resblock_kernel_sizes: tuple[int] = (3, 7, 11), + resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + num_mels: int = 128, + upsample_initial_channel: int = 512, + pre_conv_kernel_size: int = 7, + post_conv_kernel_size: int = 7, + post_activation: Callable = partial(nn.SiLU, inplace=True), + ): + super().__init__() + + assert ( + prod(upsample_rates) == hop_length + ), f"hop_length must be {prod(upsample_rates)}" + + self.conv_pre = FishConvNet( + num_mels, + upsample_initial_channel, + pre_conv_kernel_size, + stride=1, + ).weight_norm() + + self.num_upsamples = len(upsample_rates) + self.num_kernels = len(resblock_kernel_sizes) + + self.noise_convs = nn.ModuleList() + self.ups = nn.ModuleList() + + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + FishTransConvNet( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + stride=u, + ).weight_norm() + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.resblocks.append( + ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes) + ) + + self.activation_post = post_activation() + self.conv_post = FishConvNet( + ch, 1, post_conv_kernel_size, stride=1 + ).weight_norm() + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.silu(x, inplace=True) + x = self.ups[i](x) + + if self.training and self.checkpointing: + x = checkpoint( + self.resblocks[i], + x, + use_reentrant=False, + ) + else: + x = self.resblocks[i](x) + + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_parametrizations(self): + for up in self.ups: + remove_parametrizations(up, tensor_name="weight") + for block in self.resblocks: + block.remove_parametrizations() + remove_parametrizations(self.conv_pre, tensor_name="weight") + remove_parametrizations(self.conv_post, tensor_name="weight") + + +# DropPath copied from timm library +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ # noqa: E501 + + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ # noqa: E501 + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None] * x + self.bias[:, None] + return x + + +# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py +class ConvNeXtBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + kernel_size (int): Kernel size for depthwise conv. Default: 7. + dilation (int): Dilation for depthwise conv. Default: 1. + """ # noqa: E501 + + def __init__( + self, + dim: int, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-6, + mlp_ratio: float = 4.0, + kernel_size: int = 7, + dilation: int = 1, + ): + super().__init__() + + self.dwconv = FishConvNet( + dim, + dim, + kernel_size=kernel_size, + # padding=int(dilation * (kernel_size - 1) / 2), + groups=dim, + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, int(mlp_ratio * dim) + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x, apply_residual: bool = True): + input = x + + x = self.dwconv(x) + x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.gamma is not None: + x = self.gamma * x + + x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) + x = self.drop_path(x) + + if apply_residual: + x = input + x + + return x + + +class ConvNeXtEncoder(nn.Module): + def __init__( + self, + input_channels: int = 3, + depths: list[int] = [3, 3, 9, 3], + dims: list[int] = [96, 192, 384, 768], + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-6, + kernel_size: int = 7, + ): + super().__init__() + assert len(depths) == len(dims) + + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + FishConvNet( + input_channels, + dims[0], + kernel_size=7, + # padding=3, + # padding_mode="replicate", + # padding_mode="zeros", + ), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + + for i in range(len(depths) - 1): + mid_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), + ) + self.downsample_layers.append(mid_layer) + + self.stages = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + cur = 0 + for i in range(len(depths)): + stage = nn.Sequential( + *[ + ConvNeXtBlock( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + kernel_size=kernel_size, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + for i in range(len(self.downsample_layers)): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + + return self.norm(x) + + +class FireflyArchitecture(nn.Module): + def __init__( + self, + backbone: nn.Module, + head: nn.Module, + quantizer: nn.Module, + spec_transform: nn.Module, + ): + super().__init__() + + self.backbone = backbone + self.head = head + self.quantizer = quantizer + self.spec_transform = spec_transform + self.downsample_factor = math.prod(self.quantizer.downsample_factor) + + def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor: + if self.spec_transform is not None: + x = self.spec_transform(x) + + x = self.backbone(x) + if mask is not None: + x = x * mask + + if self.quantizer is not None: + vq_result = self.quantizer(x) + x = vq_result.z + + if mask is not None: + x = x * mask + + x = self.head(x, template=template) + + if x.ndim == 2: + x = x[:, None, :] + + if self.vq is not None: + return x, vq_result + + return x + + def encode(self, audios, audio_lengths): + audios = audios.float() + + mels = self.spec_transform(audios) + mel_lengths = audio_lengths // self.spec_transform.hop_length + mel_masks = sequence_mask(mel_lengths, mels.shape[2]) + mel_masks_float_conv = mel_masks[:, None, :].float() + mels = mels * mel_masks_float_conv + + # Encode + encoded_features = self.backbone(mels) * mel_masks_float_conv + feature_lengths = mel_lengths // self.downsample_factor + + return self.quantizer.encode(encoded_features), feature_lengths + + def decode(self, indices, feature_lengths) -> torch.Tensor: + mel_masks = sequence_mask( + feature_lengths * self.downsample_factor, + indices.shape[2] * self.downsample_factor, + ) + mel_masks_float_conv = mel_masks[:, None, :].float() + audio_lengths = ( + feature_lengths * self.downsample_factor * self.spec_transform.hop_length + ) + + audio_masks = sequence_mask( + audio_lengths, + indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length, + ) + audio_masks_float_conv = audio_masks[:, None, :].float() + + z = self.quantizer.decode(indices) * mel_masks_float_conv + x = self.head(z) * audio_masks_float_conv + + return x, audio_lengths + + def remove_parametrizations(self): + if hasattr(self.backbone, "remove_parametrizations"): + self.backbone.remove_parametrizations() + + if hasattr(self.head, "remove_parametrizations"): + self.head.remove_parametrizations() + + @property + def device(self): + return next(self.parameters()).device diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea4853376b6e663404ff48d6c6b5f664dde4094 --- /dev/null +++ b/fish_speech/models/vqgan/modules/fsq.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from vector_quantize_pytorch import GroupedResidualFSQ + +from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet + + +@dataclass +class FSQResult: + z: torch.Tensor + codes: torch.Tensor + latents: torch.Tensor + + +class DownsampleFiniteScalarQuantize(nn.Module): + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + n_groups: int = 1, + levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10 + downsample_factor: tuple[int] = (2, 2), + downsample_dims: tuple[int] | None = None, + ): + super().__init__() + + if downsample_dims is None: + downsample_dims = [input_dim for _ in range(len(downsample_factor))] + + all_dims = (input_dim,) + tuple(downsample_dims) + + self.residual_fsq = GroupedResidualFSQ( + dim=all_dims[-1], + levels=levels, + num_quantizers=n_codebooks, + groups=n_groups, + ) + + self.downsample_factor = downsample_factor + self.downsample_dims = downsample_dims + + self.downsample = nn.Sequential( + *[ + nn.Sequential( + FishConvNet( + all_dims[idx], + all_dims[idx + 1], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx + 1]), + ) + for idx, factor in enumerate(downsample_factor) + ] + ) + + self.upsample = nn.Sequential( + *[ + nn.Sequential( + FishTransConvNet( + all_dims[idx + 1], + all_dims[idx], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx]), + ) + for idx, factor in reversed(list(enumerate(downsample_factor))) + ] + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, z) -> FSQResult: + original_shape = z.shape + z = self.downsample(z) + quantized, indices = self.residual_fsq(z.mT) + result = FSQResult( + z=quantized.mT, + codes=indices.mT, + latents=z, + ) + result.z = self.upsample(result.z) + + # Pad or crop z to match original shape + diff = original_shape[-1] - result.z.shape[-1] + left = diff // 2 + right = diff - left + + if diff > 0: + result.z = F.pad(result.z, (left, right)) + elif diff < 0: + result.z = result.z[..., left:-right] + + return result + + def encode(self, z): + z = self.downsample(z) + _, indices = self.residual_fsq(z.mT) + indices = rearrange(indices, "g b l r -> b (g r) l") + return indices + + def decode(self, indices: torch.Tensor): + indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) + z_q = self.residual_fsq.get_output_from_indices(indices) + z_q = self.upsample(z_q.mT) + return z_q diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac --- /dev/null +++ b/fish_speech/models/vqgan/utils.py @@ -0,0 +1,94 @@ +import matplotlib +import torch +from matplotlib import pyplot as plt + +matplotlib.use("Agg") + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def plot_mel(data, titles=None): + fig, axes = plt.subplots(len(data), 1, squeeze=False) + + if titles is None: + titles = [None for i in range(len(data))] + + plt.tight_layout() + + for i in range(len(data)): + mel = data[i] + + if isinstance(mel, torch.Tensor): + mel = mel.float().detach().cpu().numpy() + + axes[i][0].imshow(mel, origin="lower") + axes[i][0].set_aspect(2.5, adjustable="box") + axes[i][0].set_ylim(0, mel.shape[0]) + axes[i][0].set_title(titles[i], fontsize="medium") + axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) + axes[i][0].set_anchor("W") + + return fig + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) + ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(in_act, n_channels): + n_channels_int = n_channels[0] + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + + return acts + + +def avg_with_mask(x, mask): + assert mask.dtype == torch.float, "Mask should be float" + + if mask.ndim == 2: + mask = mask.unsqueeze(1) + + if mask.shape[1] == 1: + mask = mask.expand_as(x) + + return (x * mask).sum() / mask.sum() diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..43bed6a2210723a7d5e1ea0a48ba61140047ca29 --- /dev/null +++ b/fish_speech/scheduler.py @@ -0,0 +1,40 @@ +import math + + +def get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int, + num_cycles: float = 0.5, + final_lr_ratio: float = 0.0, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + + return max( + final_lr_ratio, + 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) + + +def get_constant_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int | None = None, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + return 1.0 diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d740bd8eed447d162e55b165965dec17130377ce --- /dev/null +++ b/fish_speech/text/__init__.py @@ -0,0 +1,4 @@ +from .clean import clean_text +from .spliter import split_text + +__all__ = ["clean_text", "split_text"] diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89 --- /dev/null +++ b/fish_speech/text/chn_text_norm/.gitignore @@ -0,0 +1,114 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# JetBrains PyCharm +.idea + +# Customize +references +url.txt + +# Git +.git diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6 --- /dev/null +++ b/fish_speech/text/chn_text_norm/README.md @@ -0,0 +1,36 @@ +# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works. + +# Chn Text Norm + +this is a repository for chinese text normalization (no longer maintained). + +## Quick Start ## + +### Git Clone Repo ### + +git clone this repo to the root directory of your project which need to use it. + + cd /path/to/proj + git clone https://github.com/Joee1995/chn-text-norm.git + +after that, your doc tree should be: +``` +proj # root of your project +|--- chn_text_norm # this chn-text-norm tool + |--- text.py + |--- ... +|--- text_normalize.py # your text normalization code +|--- ... +``` + +### How to Use ? ### + + # text_normalize.py + from chn_text_norm.text import * + + raw_text = 'your raw text' + text = Text(raw_text=raw_text).normalize() + +### How to add quantums ### + +打开test.py,然后你就知道怎么做了。 diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py new file mode 100644 index 0000000000000000000000000000000000000000..58d8f8eb7fc85d0861f106667d8f4e3e52b54761 --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_class.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +"""基本类 +中文字符类 +中文数字/数位类 +中文数字类 +中文数位类 +中文数字系统类 +中文数学符号类 +*中文其他符号类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES + + +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit( + power=index + 1, + simplified=value[0], + traditional=value[1], + big_s=value[1], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + else: + raise ValueError( + "Counting type should be in {0} ({1} provided).".format( + NUMBERING_TYPES, numbering_type + ) + ) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__( + self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None + ): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. + positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..9a65991b9a9d349a0571c80508633951e52749ef --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_constant.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +"""基本常量 +中文数字/数位/符号字符常量 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf6130be87f285eed9998186508ea489d3bac9e --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_util.py @@ -0,0 +1,342 @@ +# -*- coding: utf-8 -*- +"""基本方法 +创建中文数字系统 方法 +中文字符串 <=> 数字串 方法 +数字串 <=> 中文字符串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_class import * +from fish_speech.text.chn_text_norm.basic_constant import * + + +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + larger_units = [ + CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) + ] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + smaller_units = [ + CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) + ] + # digis + chinese_digis = zip( + CHINESE_DIGIS, + CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, + BIG_CHINESE_DIGIS_TRADITIONAL, + ) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) + point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [ + d.traditional, + d.simplified, + d.big_s, + d.big_t, + d.alt_s, + d.alt_t, + ]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, "" + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], [ + get_symbol(c, system) for c in dec_string + ] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance( + integer_symbols[-2], CNU + ): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None) + ) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if ( + isinstance(result[-i - 1], CNU) + and result[-i - 1].power < current_unit.power + ): + result[-i - 1] = CNU( + result[-i - 1].power + current_unit.power, + None, + None, + None, + None, + ) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. '两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = "".join([str(d.value) for d in dec_part]) + if dec_part: + return "{0}.{1}".format(int_str, dec_str) + else: + return int_str + + +def num2chn( + number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True, +): + + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip("0") + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next( + u for u in reversed(system.units) if u.power < len(striped_string) + ) + result_string = value_string[: -result_unit.power] + return ( + get_value(result_string) + + [result_unit] + + get_value(striped_string[-result_unit.power :]) + ) + + system = create_system(numbering_type) + + int_dec = number_string.split(".") + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string) + ) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND( + 2, + system.digits[2].alt_s, + system.digits[2].alt_t, + system.digits[2].big_s, + system.digits[2].big_t, + ) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = ( + result_symbols[i + 1] if i < len(result_symbols) - 1 else None + ) + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance( + previous_symbol, (CNU, type(None)) + ): + if next_symbol.power != 1 and ( + (previous_symbol is None) or (previous_symbol.power != 1) + ): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = "big_" + if traditional: + attr_name += "t" + else: + attr_name += "s" + else: + if traditional: + attr_name = "traditional" + else: + attr_name = "simplified" + + result = "".join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s + ) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s + ) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if ( + len(result) >= 2 + and result[1] + in [ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0], + ] + and result[0] + in [ + CHINESE_DIGIS[1], + BIG_CHINESE_DIGIS_SIMPLIFIED[1], + BIG_CHINESE_DIGIS_TRADITIONAL[1], + ] + ): + result = result[1:] + + return result + + +if __name__ == "__main__": + + # 测试程序 + all_chinese_number_string = ( + CHINESE_DIGIS + + BIG_CHINESE_DIGIS_SIMPLIFIED + + BIG_CHINESE_DIGIS_TRADITIONAL + + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL + + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL + + ZERO_ALT + + ONE_ALT + + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT) + ) + + print("num:", chn2num("一万零四百零三点八零五")) + print("num:", chn2num("一亿六点三")) + print("num:", chn2num("一亿零六点三")) + print("num:", chn2num("两千零一亿六点三")) + # print('num:', chn2num('一零零八六')) + print("txt:", num2chn("10260.03", alt_zero=True)) + print("txt:", num2chn("20037.090", numbering_type="low", traditional=True)) + print("txt:", num2chn("100860001.77", numbering_type="high", big=True)) + print( + "txt:", + num2chn( + "059523810880", + alt_one=True, + alt_two=False, + use_lzeros=True, + use_rzeros=True, + use_units=False, + ), + ) + + print(all_chinese_number_string) diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py new file mode 100644 index 0000000000000000000000000000000000000000..ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616 --- /dev/null +++ b/fish_speech/text/chn_text_norm/cardinal.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""CARDINAL类 (包含小数DECIMAL类) +纯数 <=> 中文字符串 方法 +中文字符串 <=> 纯数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +if __name__ == "__main__": + + # 测试程序 + print(Cardinal(cardinal="21357.230").cardinal2chntext()) diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py new file mode 100644 index 0000000000000000000000000000000000000000..77acfdb9a91df0fe3c615a0784f61aad87fbe56e --- /dev/null +++ b/fish_speech/text/chn_text_norm/date.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +"""DATE类 +日期 <=> 中文字符串 方法 +中文字符串 <=> 日期 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-07" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.digit import Digit + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split("年", maxsplit=1) + year = Digit(digit=year).digit2chntext() + "年" + except ValueError: + other = date + year = "" + if other: + try: + month, day = other.strip().split("月", maxsplit=1) + month = Cardinal(cardinal=month).cardinal2chntext() + "月" + except ValueError: + day = date + month = "" + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = "" + day = "" + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +if __name__ == "__main__": + + # 测试 + print(Date(date="09年3月16日").date2chntext()) diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py new file mode 100644 index 0000000000000000000000000000000000000000..47c0cd4ad0c700635f84470bfdacfbdafb4a6185 --- /dev/null +++ b/fish_speech/text/chn_text_norm/digit.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""DIGIT类 +数字串 <=> 中文字符串 方法 +中文字符串 <=> 数字串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +if __name__ == "__main__": + + # 测试程序 + print(Digit(digit="2016").digit2chntext()) diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..b43b6a7feb634d346d59a2b4ab84b77ac88df103 --- /dev/null +++ b/fish_speech/text/chn_text_norm/fraction.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +"""FRACTION类 +分数 <=> 中文字符串 方法 +中文字符串 <=> 分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split("分之") + return chn2num(numerator) + "/" + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split("/") + return num2chn(denominator) + "分之" + num2chn(numerator) + + +if __name__ == "__main__": + + # 测试程序 + print(Fraction(fraction="2135/7230").fraction2chntext()) + print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction()) diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c980d32134e1460e96e5bcbcc73d0d55974d2a --- /dev/null +++ b/fish_speech/text/chn_text_norm/money.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +"""MONEY类 +金钱 <=> 中文字符串 方法 +中文字符串 <=> 金钱 方法 +""" +import re + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-08" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() + ) + self.chntext = money + return self.chntext + + +if __name__ == "__main__": + + # 测试 + print(Money(money="21.5万元").money2chntext()) + print(Money(money="230块5毛").money2chntext()) diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py new file mode 100644 index 0000000000000000000000000000000000000000..46abbf545af62eb951d8f6fe40bcf684587f81b0 --- /dev/null +++ b/fish_speech/text/chn_text_norm/percentage.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +"""PERCENTAGE类 +百分数 <=> 中文字符串 方法 +中文字符串 <=> 百分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-06" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip("百分之")) + "%" + + def percentage2chntext(self): + return "百分之" + num2chn(self.percentage.strip().strip("%")) + + +if __name__ == "__main__": + + # 测试程序 + print(Percentage(chntext="百分之五十六点零三").chntext2percentage()) + print(Percentage(percentage="65.3%").percentage2chntext()) diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py new file mode 100644 index 0000000000000000000000000000000000000000..e72b546db628a3b807dc6235b59b188cae3153ff --- /dev/null +++ b/fish_speech/text/chn_text_norm/telephone.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +"""TELEPHONE类 +电话号码 <=> 中文字符串 方法 +中文字符串 <=> 电话号码 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + + if fixed: + sil_parts = self.telephone.split("-") + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + else: + sp_parts = self.telephone.strip("+").split() + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + return self.chntext + + +if __name__ == "__main__": + + # 测试程序 + print(TelePhone(telephone="0595-23980880").telephone2chntext()) + # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone()) diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..54086fd933c01e14c3c55cee9adb52eefb58fd31 --- /dev/null +++ b/fish_speech/text/chn_text_norm/text.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +""" +TEXT类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +import re + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.date import Date +from fish_speech.text.chn_text_norm.digit import Digit +from fish_speech.text.chn_text_norm.fraction import Fraction +from fish_speech.text.chn_text_norm.money import Money +from fish_speech.text.chn_text_norm.percentage import Percentage +from fish_speech.text.chn_text_norm.telephone import TelePhone + +CURRENCY_NAMES = ( + "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" + "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +) +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)" +) + + +class Text: + """ + Text类 + """ + + def __init__(self, raw_text, norm_text=None): + self.raw_text = "^" + raw_text + "$" + self.norm_text = norm_text + + def _particular(self): + text = self.norm_text + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) + self.norm_text = text + return self.norm_text + + def normalize(self): + text = self.raw_text + + # 规范化日期 + pattern = re.compile( + r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile( + r"\D+((\d+(\.\d+)?)[多余几]?" + + CURRENCY_UNITS + + "(\d" + + CURRENCY_UNITS + + "?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace( + matcher[0], Money(money=matcher[0]).money2chntext(), 1 + ) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace( + matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 + ) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace( + matcher[0], + TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), + 1, + ) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace( + matcher, Fraction(fraction=matcher).fraction2chntext(), 1 + ) + + # 规范化百分数 + text = text.replace("%", "%") + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace( + matcher[0], + Percentage(percentage=matcher[0]).percentage2chntext(), + 1, + ) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + self.norm_text = text + self._particular() + + return self.norm_text.lstrip("^").rstrip("$") + + +if __name__ == "__main__": + + # 测试程序 + print(Text(raw_text="固话:0595-23865596或23880880。").normalize()) + print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize()) + print(Text(raw_text="分数:32477/76391。").normalize()) + print(Text(raw_text="百分数:80.03%。").normalize()) + print(Text(raw_text="编号:31520181154418。").normalize()) + print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize()) + print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize()) + print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize()) + print(Text(raw_text="特殊:O2O或B2C。").normalize()) diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py new file mode 100644 index 0000000000000000000000000000000000000000..792db5d44e40e22ea3ba7b3d9e96b01f5ed46835 --- /dev/null +++ b/fish_speech/text/clean.py @@ -0,0 +1,45 @@ +import re + +SYMBOLS_MAPPING = { + "\n": ".", + "…": ".", + "“": "'", + "”": "'", + "‘": "'", + "’": "'", + "【": "", + "】": "", + "[": "", + "]": "", + "(": "", + ")": "", + "(": "", + ")": "", + "・": "", + "·": "", + "「": "'", + "」": "'", + "《": "'", + "》": "'", + "—": "", + "~": "", + "~": "", + ":": ",", + ";": ",", + ";": ",", + ":": ",", +} + +REPLACE_SYMBOL_REGEX = re.compile( + "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) +) + + +def clean_text(text): + # Clean the text + text = text.strip() + + # Replace all chinese symbols with their english counterparts + text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) + + return text diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py new file mode 100644 index 0000000000000000000000000000000000000000..d4bb995487c4f53818c6b2a16cf0a886b4e02e84 --- /dev/null +++ b/fish_speech/text/spliter.py @@ -0,0 +1,130 @@ +import re +import string + +from fish_speech.text.clean import clean_text + + +def utf_8_len(text): + return len(text.encode("utf-8")) + + +def break_text(texts, length, splits: set): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if char in splits: + yield curr + curr = "" + + if curr: + yield curr + + +def break_text_by_length(texts, length): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if utf_8_len(curr) >= length: + yield curr + curr = "" + + if curr: + yield curr + + +def add_cleaned(curr, segments): + curr = curr.strip() + if curr and not all(c.isspace() or c in string.punctuation for c in curr): + segments.append(curr) + + +def protect_float(text): + # Turns 3.14 into <3_f_14> to prevent splitting + return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) + + +def unprotect_float(text): + # Turns <3_f_14> into 3.14 + return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) + + +def split_text(text, length): + text = clean_text(text) + + # Break the text into pieces with following rules: + # 1. Split the text at ".", "!", "?" if text is NOT a float + # 2. If the text is longer than length, split at "," + # 3. If the text is still longer than length, split at " " + # 4. If the text is still longer than length, split at any character to length + + texts = [text] + texts = map(protect_float, texts) + texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) + texts = map(unprotect_float, texts) + texts = break_text(texts, length, {",", ","}) + texts = break_text(texts, length, {" "}) + texts = list(break_text_by_length(texts, length)) + + # Then, merge the texts into segments with length <= length + segments = [] + curr = "" + + for text in texts: + if utf_8_len(curr) + utf_8_len(text) <= length: + curr += text + else: + add_cleaned(curr, segments) + curr = text + + if curr: + add_cleaned(curr, segments) + + return segments + + +if __name__ == "__main__": + # Test the split_text function + + text = "This is a test sentence. This is another test sentence. And a third one." + + assert split_text(text, 50) == [ + "This is a test sentence.", + "This is another test sentence. And a third one.", + ] + assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] + assert split_text(" ", 10) == [] + assert split_text("a", 10) == ["a"] + + text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." + assert split_text(text, 50) == [ + "This is a test sentence with only commas,", + "and no dots, and no exclamation marks,", + "and no question marks, and no newlines.", + ] + + text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." + # First half split at " ", second half split at "," + assert split_text(text, 50) == [ + "This is a test sentence This is a test sentence", + "This is a test sentence. This is a test sentence,", + "This is a test sentence, This is a test sentence.", + ] + + text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" + assert split_text(text, 50) == [ + "这是一段很长的中文文本,", + "而且没有句号,也没有感叹号,", + "也没有问号,也没有换行符.", + ] diff --git a/fish_speech/train.py b/fish_speech/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e693f3adc4dda787bdd587aec29f53355f2b1653 --- /dev/null +++ b/fish_speech/train.py @@ -0,0 +1,141 @@ +import os + +os.environ["USE_LIBUV"] = "0" +import sys +from typing import Optional + +import hydra +import lightning as L +import pyrootutils +import torch +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from lightning.pytorch.strategies import DDPStrategy +from omegaconf import DictConfig, OmegaConf + +os.environ.pop("SLURM_NTASKS", None) +os.environ.pop("SLURM_JOB_NAME", None) +os.environ.pop("SLURM_NTASKS_PER_NODE", None) + +# register eval resolver and root +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +# Allow TF32 on Ampere GPUs +torch.set_float32_matmul_precision("high") +torch.backends.cudnn.allow_tf32 = True + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) + +import fish_speech.utils as utils + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> tuple[dict, dict]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + Args: + cfg (DictConfig): Configuration composed by Hydra. + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ # noqa: E501 + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=False) + + if cfg.get("deterministic"): + torch.use_deterministic_algorithms(True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=logger, + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + + ckpt_path = cfg.get("ckpt_path") + auto_resume = False + + resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) + if resume_ckpt_path is not None: + ckpt_path = resume_ckpt_path + auto_resume = True + + if ckpt_path is not None: + log.info(f"Resuming from checkpoint: {ckpt_path}") + + # resume weights only is disabled for auto-resume + if cfg.get("resume_weights_only") and auto_resume is False: + log.info("Resuming weights only!") + ckpt = torch.load(ckpt_path, map_location=model.device) + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + err = model.load_state_dict(ckpt, strict=False) + log.info(f"Error loading state dict: {err}") + ckpt_path = None + + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = cfg.get("ckpt_path") + + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main( + version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" +) +def main(cfg: DictConfig) -> Optional[float]: + # train the model + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05378519dbd18361c639e33413d011e7307c9adb --- /dev/null +++ b/fish_speech/utils/__init__.py @@ -0,0 +1,23 @@ +from .braceexpand import braceexpand +from .context import autocast_exclude_mps +from .file import get_latest_checkpoint +from .instantiators import instantiate_callbacks, instantiate_loggers +from .logger import RankedLogger +from .logging_utils import log_hyperparameters +from .rich_utils import enforce_tags, print_config_tree +from .utils import extras, get_metric_value, task_wrapper + +__all__ = [ + "enforce_tags", + "extras", + "get_metric_value", + "RankedLogger", + "instantiate_callbacks", + "instantiate_loggers", + "log_hyperparameters", + "print_config_tree", + "task_wrapper", + "braceexpand", + "get_latest_checkpoint", + "autocast_exclude_mps", +] diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974 --- /dev/null +++ b/fish_speech/utils/braceexpand.py @@ -0,0 +1,217 @@ +""" +Bash-style brace expansion +Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py +License: MIT +""" + +import re +import string +from itertools import chain, product +from typing import Iterable, Iterator, Optional + +__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] + + +class UnbalancedBracesError(ValueError): + pass + + +alphabet = string.ascii_uppercase + string.ascii_lowercase + +int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") +char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") +escape_re = re.compile(r"\\(.)") + + +def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: + """braceexpand(pattern) -> iterator over generated strings + + Returns an iterator over the strings resulting from brace expansion + of pattern. This function implements Brace Expansion as described in + bash(1), with the following limitations: + + * A pattern containing unbalanced braces will raise an + UnbalancedBracesError exception. In bash, unbalanced braces will either + be partly expanded or ignored. + + * A mixed-case character range like '{Z..a}' or '{a..Z}' will not + include the characters '[]^_`' between 'Z' and 'a'. + + When escape is True (the default), characters in pattern can be + prefixed with a backslash to cause them not to be interpreted as + special characters for brace expansion (such as '{', '}', ','). + To pass through a a literal backslash, double it ('\\\\'). + + When escape is False, backslashes in pattern have no special + meaning and will be preserved in the output. + + Examples: + + >>> from braceexpand import braceexpand + + # Integer range + >>> list(braceexpand('item{1..3}')) + ['item1', 'item2', 'item3'] + + # Character range + >>> list(braceexpand('{a..c}')) + ['a', 'b', 'c'] + + # Sequence + >>> list(braceexpand('index.html{,.backup}')) + ['index.html', 'index.html.backup'] + + # Nested patterns + >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) + ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] + + # Prefixing an integer with zero causes all numbers to be padded to + # the same width. + >>> list(braceexpand('{07..10}')) + ['07', '08', '09', '10'] + + # An optional increment can be specified for ranges. + >>> list(braceexpand('{a..g..2}')) + ['a', 'c', 'e', 'g'] + + # Ranges can go in both directions. + >>> list(braceexpand('{4..1}')) + ['4', '3', '2', '1'] + + # Numbers can be negative + >>> list(braceexpand('{2..-1}')) + ['2', '1', '0', '-1'] + + # Unbalanced braces raise an exception. + >>> list(braceexpand('{1{2,3}')) + Traceback (most recent call last): + ... + UnbalancedBracesError: Unbalanced braces: '{1{2,3}' + + # By default, the backslash is the escape character. + >>> list(braceexpand(r'{1\\{2,3}')) + ['1{2', '3'] + + # Setting 'escape' to False disables backslash escaping. + >>> list(braceexpand(r'\\{1,2}', escape=False)) + ['\\\\1', '\\\\2'] + + """ + return ( + escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) + ) + + +def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'pattern:', pattern + while pos < len(pattern): + if escape and pattern[pos] == "\\": + pos += 2 + continue + elif pattern[pos] == "{": + if bracketdepth == 0 and pos > start: + # print 'literal:', pattern[start:pos] + items.append([pattern[start:pos]]) + start = pos + bracketdepth += 1 + elif pattern[pos] == "}": + bracketdepth -= 1 + if bracketdepth == 0: + # print 'expression:', pattern[start+1:pos] + expr = pattern[start + 1 : pos] + item = parse_expression(expr, escape) + if item is None: # not a range or sequence + items.extend([["{"], parse_pattern(expr, escape), ["}"]]) + else: + items.append(item) + start = pos + 1 # skip the closing brace + pos += 1 + + if bracketdepth != 0: # unbalanced braces + raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) + + if start < pos: + items.append([pattern[start:]]) + + return ("".join(item) for item in product(*items)) + + +def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: + int_range_match = int_range_re.match(expr) + if int_range_match: + return make_int_range(*int_range_match.groups()) + + char_range_match = char_range_re.match(expr) + if char_range_match: + return make_char_range(*char_range_match.groups()) + + return parse_sequence(expr, escape) + + +def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: + # sequence -> chain(*sequence_items) + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'sequence:', seq + while pos < len(seq): + if escape and seq[pos] == "\\": + pos += 2 + continue + elif seq[pos] == "{": + bracketdepth += 1 + elif seq[pos] == "}": + bracketdepth -= 1 + elif seq[pos] == "," and bracketdepth == 0: + items.append(parse_pattern(seq[start:pos], escape)) + start = pos + 1 # skip the comma + pos += 1 + + if bracketdepth != 0: + raise UnbalancedBracesError + if not items: + return None + + # part after the last comma (may be the empty string) + items.append(parse_pattern(seq[start:], escape)) + return chain(*items) + + +def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: + if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): + padding = max(len(left), len(right)) + else: + padding = 0 + step = (int(incr) or 1) if incr else 1 + start = int(left) + end = int(right) + r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) + fmt = "%0{}d".format(padding) + return (fmt % i for i in r) + + +def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: + step = (int(incr) or 1) if incr else 1 + start = alphabet.index(left) + end = alphabet.index(right) + if start < end: + return alphabet[start : end + 1 : step] + else: + end = end or -len(alphabet) + return alphabet[start : end - 1 : -step] + + +if __name__ == "__main__": + import doctest + import sys + + failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) + if failed: + sys.exit(1) diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py new file mode 100644 index 0000000000000000000000000000000000000000..f04a99290ab32f7fe5b60656075a2d03af8468d6 --- /dev/null +++ b/fish_speech/utils/context.py @@ -0,0 +1,13 @@ +from contextlib import nullcontext + +import torch + + +def autocast_exclude_mps( + device_type: str, dtype: torch.dtype +) -> nullcontext | torch.autocast: + return ( + nullcontext() + if torch.backends.mps.is_available() + else torch.autocast(device_type, dtype) + ) diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py new file mode 100644 index 0000000000000000000000000000000000000000..78c82640a963fa556657107729f7543d2e7c3510 --- /dev/null +++ b/fish_speech/utils/file.py @@ -0,0 +1,16 @@ +import os +from pathlib import Path + + +def get_latest_checkpoint(path: Path | str) -> Path | None: + # Find the latest checkpoint + ckpt_dir = Path(path) + + if ckpt_dir.exists() is False: + return None + + ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) + if len(ckpts) == 0: + return None + + return ckpts[-1] diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e --- /dev/null +++ b/fish_speech/utils/instantiators.py @@ -0,0 +1,50 @@ +from typing import List + +import hydra +from omegaconf import DictConfig +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger + +from .logger import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc --- /dev/null +++ b/fish_speech/utils/logger.py @@ -0,0 +1,55 @@ +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = True, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429 --- /dev/null +++ b/fish_speech/utils/logging_utils.py @@ -0,0 +1,48 @@ +from lightning.pytorch.utilities import rank_zero_only + +from fish_speech.utils import logger as log + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a465f54d610779766d51e3d1a020a3b1517fd1f --- /dev/null +++ b/fish_speech/utils/rich_utils.py @@ -0,0 +1,100 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from fish_speech.utils import logger as log + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ # noqa: E501 + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. " + + f"Skipping '{field}' config printing..." + ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841 --- /dev/null +++ b/fish_speech/utils/spectrogram.py @@ -0,0 +1,122 @@ +import torch +import torchaudio.functional as F +from torch import Tensor, nn +from torchaudio.transforms import MelScale + + +class LinearSpectrogram(nn.Module): + def __init__( + self, + n_fft=2048, + win_length=2048, + hop_length=512, + center=False, + mode="pow2_sqrt", + ): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.mode = mode + + self.register_buffer("window", torch.hann_window(win_length), persistent=False) + + def forward(self, y: Tensor) -> Tensor: + if y.ndim == 3: + y = y.squeeze(1) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + (self.win_length - self.hop_length) // 2, + (self.win_length - self.hop_length + 1) // 2, + ), + mode="reflect", + ).squeeze(1) + + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + spec = torch.view_as_real(spec) + + if self.mode == "pow2_sqrt": + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + return spec + + +class LogMelSpectrogram(nn.Module): + def __init__( + self, + sample_rate=44100, + n_fft=2048, + win_length=2048, + hop_length=512, + n_mels=128, + center=False, + f_min=0.0, + f_max=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max or float(sample_rate // 2) + + self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) + + fb = F.melscale_fbanks( + n_freqs=self.n_fft // 2 + 1, + f_min=self.f_min, + f_max=self.f_max, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + norm="slaney", + mel_scale="slaney", + ) + self.register_buffer( + "fb", + fb, + persistent=False, + ) + + def compress(self, x: Tensor) -> Tensor: + return torch.log(torch.clamp(x, min=1e-5)) + + def decompress(self, x: Tensor) -> Tensor: + return torch.exp(x) + + def apply_mel_scale(self, x: Tensor) -> Tensor: + return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) + + def forward( + self, x: Tensor, return_linear: bool = False, sample_rate: int = None + ) -> Tensor: + if sample_rate is not None and sample_rate != self.sample_rate: + x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) + + linear = self.spectrogram(x) + x = self.apply_mel_scale(linear) + x = self.compress(x) + + if return_linear: + return x, self.compress(linear) + + return x diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c546bfa1eddd2ac6bf484cce1ec06da1d33fb121 --- /dev/null +++ b/fish_speech/utils/utils.py @@ -0,0 +1,114 @@ +import warnings +from importlib.util import find_spec +from typing import Callable + +from omegaconf import DictConfig + +from .logger import RankedLogger +from .rich_utils import enforce_tags, print_config_tree + +log = RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[dict, dict]: + + ... + + return metric_dict, object_dict + ``` + """ # noqa: E501 + + def wrap(cfg: DictConfig): + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or + # cause out-of-memory errors so when using hparam search + # plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.run_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css new file mode 100644 index 0000000000000000000000000000000000000000..3c7a22ecc31881a65a76369b0fd889330a0874c7 --- /dev/null +++ b/fish_speech/webui/css/style.css @@ -0,0 +1,161 @@ +:root { + --my-200: #80eeee; + --my-50: #ecfdf5; + --water-width: 300px; + --water-heigh: 300px; +} + + +/* general styled components */ +.tools { + align-items: center; + justify-content: center; +} + +.gradio-button { + max-width: 2.2em; + min-width: 2.2em !important; + height: 2.4em; + align-self: end; + line-height: 1em; + border-radius: 0.5em; + +} + +.gradio-button.secondary-down, .gradio-button.secondary-down:hover{ + box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; +} + +/* replace original footer with ours */ +a{ + font-weight: bold; + cursor: pointer; + color: #030C14 !important; +} + +footer { + display: none !important; +} + +#footer{ + text-align: center; +} + +#footer div{ + display: inline-block; +} + +#footer .versions{ + font-size: 85%; + opacity: 0.85; +} + +/*@keyframes moveBackground {*/ +/* 0% {*/ +/* background-position: 0 0;*/ +/* }*/ +/* 100% {*/ +/* background-position: -100px 100px;*/ +/* }*/ +/*}*/ +@keyframes moveJellyBackground { + 0% { + background-position: 0% 50%; + } + 50% { + background-position: 100% 50%; + } + 100% { + background-position: 0% 50%; + } +} + +.gradio-container { + position: absolute; + z-index: 10; +} + + +.quan { + position: absolute; + bottom: 0; + width: var(--water-width); + height: var(--water-heigh); + border-radius: 0; + /*border: 3px solid rgb(246, 247, 248);*/ + /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ + z-index: 0; + +} + +.quan:last-child { + margin-right: 0; +} + +.shui { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + background-color: rgb(23, 106, 201); + border-radius: 0; + overflow: hidden; + z-index: 0; +} + +.shui::after { + + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 40%; + background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); + animation: shi 5s linear infinite; +} + +@keyframes shi { + 0% { + transform: translate(-50%, -65%) rotate(0deg); + } + 100% { + transform: translate(-50%, -65%) rotate(360deg); + } +} + +.shui::before { + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 42%; + background-color: rgb(240, 228, 228, 0.2); + animation: xu 7s linear infinite; +} + +@keyframes xu { + 0% { + transform: translate(-50%, -60%) rotate(0deg); + } + 100% { + transform: translate(-50%, -60%) rotate(360deg); + } +} + +fieldset.data_src div.wrap label { + background: #f8bffee0 !important; +} + +.scrollable-component { + max-height: 100px; + overflow-y: auto; +} + +#file_accordion { + max-height: 220px !important; +} diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html new file mode 100644 index 0000000000000000000000000000000000000000..ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615 --- /dev/null +++ b/fish_speech/webui/html/footer.html @@ -0,0 +1,11 @@ +
+ API +  •  + Github +  •  + Gradio +
+
+
+{versions} +
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js new file mode 100644 index 0000000000000000000000000000000000000000..0637a541a8e704632a42b89bdf1471b26e7bb868 --- /dev/null +++ b/fish_speech/webui/js/animate.js @@ -0,0 +1,69 @@ + +function createGradioAnimation() { + const params = new URLSearchParams(window.location.search); + if (!params.has('__theme')) { + params.set('__theme', 'light'); + window.location.search = params.toString(); + } + + var gradioApp = document.querySelector('gradio-app'); + if (gradioApp) { + + document.documentElement.style.setProperty('--my-200', '#80eeee'); + document.documentElement.style.setProperty('--my-50', '#ecfdf5'); + + // gradioApp.style.position = 'relative'; + // gradioApp.style.backgroundSize = '200% 200%'; + // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; + // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; + // gradioApp.style.display = 'flex'; + // gradioApp.style.justifyContent = 'flex-start'; + // gradioApp.style.flexWrap = 'nowrap'; + // gradioApp.style.overflowX = 'auto'; + + // for (let i = 0; i < 6; i++) { + // var quan = document.createElement('div'); + // quan.className = 'quan'; + // gradioApp.insertBefore(quan, gradioApp.firstChild); + // quan.id = 'quan' + i.toString(); + // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; + // var quanContainer = document.querySelector('.quan'); + // if (quanContainer) { + // var shui = document.createElement('div'); + // shui.className = 'shui'; + // quanContainer.insertBefore(shui, quanContainer.firstChild) + // } + // } + } + + var container = document.createElement('div'); + container.id = 'gradio-animation'; + container.style.fontSize = '2em'; + container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; + container.style.fontWeight = 'bold'; + container.style.textAlign = 'center'; + container.style.marginBottom = '20px'; + + var text = 'Welcome to Fish-Speech!'; + for (var i = 0; i < text.length; i++) { + (function(i){ + setTimeout(function(){ + var letter = document.createElement('span'); + letter.style.opacity = '0'; + letter.style.transition = 'opacity 0.5s'; + letter.innerText = text[i]; + + container.appendChild(letter); + + setTimeout(function() { + letter.style.opacity = '1'; + }, 50); + }, i * 200); + })(i); + } + + var gradioContainer = document.querySelector('.gradio-container'); + gradioContainer.insertBefore(container, gradioContainer.firstChild); + + return 'Animation created'; +} diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f57b595a20177800dbedd71faef573ee8398418 --- /dev/null +++ b/fish_speech/webui/launch_utils.py @@ -0,0 +1,120 @@ +import importlib.util +import os +import subprocess +import sys +from functools import lru_cache +from pathlib import Path +from typing import Iterable + +import gradio as gr +from gradio.themes.base import Base +from gradio.themes.utils import colors, fonts, sizes + +GIT = ( + (Path(os.environ.get("GIT_HOME", "")) / "git").resolve() + if sys.platform == "win32" + else "git" +) +GIT = str(GIT) + + +def is_module_installed(module_name: str) -> bool: + spec = importlib.util.find_spec(module_name) + return spec is not None + + +@lru_cache() +def commit_hash(): + try: + return subprocess.check_output( + [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8" + ).strip() + except Exception: + return "" + + +def versions_html(): + import torch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = commit_hash() + hash = commit.strip("'").split(" ")[0] + + return f""" +version: {hash} + •  +python: {python_version} + •  +torch: {getattr(torch, '__long_version__',torch.__version__)} + •  +gradio: {gr.__version__} + •  +author: fishaudio +""" + + +def version_check(commit): + try: + import requests + + commits = requests.get( + "https://api.github.com/repos/fishaudio/fish-speech/branches/main" + ).json() + if commit != "" and commits["commit"]["sha"] != commit: + print("--------------------------------------------------------") + print("| You are not up to date with the most recent release. |") + print("| Consider running `git pull` to update. |") + print("--------------------------------------------------------") + elif commits["commit"]["sha"] == commit: + print("You are up to date with the most recent release.") + else: + print("Not a git clone, can't perform version check.") + except Exception as e: + print("version check failed", e) + + +class Seafoam(Base): + def __init__( + self, + *, + primary_hue: colors.Color | str = colors.emerald, + secondary_hue: colors.Color | str = colors.blue, + neutral_hue: colors.Color | str = colors.blue, + spacing_size: sizes.Size | str = sizes.spacing_md, + radius_size: sizes.Size | str = sizes.radius_md, + text_size: sizes.Size | str = sizes.text_lg, + font: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("Quicksand"), + "ui-sans-serif", + "sans-serif", + ), + font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("IBM Plex Mono"), + "ui-monospace", + "monospace", + ), + ): + super().__init__( + primary_hue=primary_hue, + secondary_hue=secondary_hue, + neutral_hue=neutral_hue, + spacing_size=spacing_size, + radius_size=radius_size, + text_size=text_size, + font=font, + font_mono=font_mono, + ) + super().set( + button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)", + button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)", + button_primary_text_color="white", + button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)", + slider_color="*secondary_300", + slider_color_dark="*secondary_600", + block_title_text_weight="600", + block_border_width="3px", + block_shadow="*shadow_drop_lg", + button_shadow="*shadow_drop_lg", + button_small_padding="0px", + button_large_padding="3px", + ) diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec3fcac25de3cc7d239c4903403d1a4cd81567b --- /dev/null +++ b/fish_speech/webui/manage.py @@ -0,0 +1,1239 @@ +from __future__ import annotations + +import os + +os.environ["USE_LIBUV"] = "0" +import datetime +import html +import json +import platform +import shutil +import signal +import subprocess +import sys +from pathlib import Path + +import gradio as gr +import psutil +import yaml +from loguru import logger +from tqdm import tqdm + +PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python") +sys.path.insert(0, "") +print(sys.path) +cur_work_dir = Path(os.getcwd()).resolve() +print("You are in ", str(cur_work_dir)) + +from fish_speech.i18n import i18n +from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html + +config_path = cur_work_dir / "fish_speech" / "configs" +vqgan_yml_path = config_path / "firefly_gan_vq.yaml" +llama_yml_path = config_path / "text2semantic_finetune.yaml" + +env = os.environ.copy() +env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0" + +seafoam = Seafoam() + + +def build_html_error_message(error): + return f""" +
+ {html.escape(error)} +
+ """ + + +def build_html_ok_message(msg): + return f""" +
+ {html.escape(msg)} +
+ """ + + +def build_html_href(link, desc, msg): + return f""" + + {html.escape(msg)} + {desc} + + """ + + +def load_data_in_raw(path): + with open(path, "r", encoding="utf-8") as file: + data = file.read() + return str(data) + + +def kill_proc_tree(pid, including_parent=True): + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + # Process already terminated + return + + children = parent.children(recursive=True) + for child in children: + try: + os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + if including_parent: + try: + os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + + +system = platform.system() +p_label = None +p_infer = None +p_tensorboard = None + + +def kill_process(pid): + if system == "Windows": + cmd = "taskkill /t /f /pid %s" % pid + # os.system(cmd) + subprocess.run(cmd) + else: + kill_proc_tree(pid) + + +def change_label(if_label): + global p_label + if if_label == True and p_label is None: + url = "http://localhost:3000" + remote_url = "https://text-labeler.pages.dev/" + try: + p_label = subprocess.Popen( + [ + ( + "asr-label-linux-x64" + if sys.platform == "linux" + else "asr-label-win-x64.exe" + ) + ] + ) + except FileNotFoundError: + logger.warning("asr-label execution not found!") + + yield build_html_href( + link=remote_url, + desc=i18n("Optional online ver"), + msg=i18n("Opened labeler in browser"), + ) + + elif if_label == False and p_label is not None: + kill_process(p_label.pid) + p_label = None + yield build_html_ok_message("Nothing") + + +def clean_infer_cache(): + import tempfile + + temp_dir = Path(tempfile.gettempdir()) + gradio_dir = str(temp_dir / "gradio") + try: + shutil.rmtree(gradio_dir) + logger.info(f"Deleted cached audios: {gradio_dir}") + except PermissionError: + logger.info(f"Permission denied: Unable to delete {gradio_dir}") + except FileNotFoundError: + logger.info(f"{gradio_dir} was not found") + except Exception as e: + logger.info(f"An error occurred: {e}") + + +def change_infer( + if_infer, + host, + port, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, +): + global p_infer + if if_infer == True and p_infer == None: + env = os.environ.copy() + + env["GRADIO_SERVER_NAME"] = host + env["GRADIO_SERVER_PORT"] = port + # 启动第二个进程 + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Inferring interface is launched at {}").format(url) + ) + + clean_infer_cache() + + p_infer = subprocess.Popen( + [ + PYTHON, + "tools/webui.py", + "--decoder-checkpoint-path", + infer_decoder_model, + "--decoder-config-name", + infer_decoder_config, + "--llama-checkpoint-path", + infer_llama_model, + ] + + (["--compile"] if infer_compile == "Yes" else []), + env=env, + ) + + elif if_infer == False and p_infer is not None: + kill_process(p_infer.pid) + p_infer = None + yield build_html_error_message(i18n("Infer interface is closed")) + + +js = load_data_in_raw("fish_speech/webui/js/animate.js") +css = load_data_in_raw("fish_speech/webui/css/style.css") + +data_pre_output = (cur_work_dir / "data").resolve() +default_model_output = (cur_work_dir / "results").resolve() +default_filelist = data_pre_output / "detect.list" +data_pre_output.mkdir(parents=True, exist_ok=True) + +items = [] +dict_items = {} + + +def load_yaml_data_in_fact(yml_path): + with open(yml_path, "r", encoding="utf-8") as file: + yml = yaml.safe_load(file) + return yml + + +def write_yaml_data_in_fact(yml, yml_path): + with open(yml_path, "w", encoding="utf-8") as file: + yaml.safe_dump(yml, file, allow_unicode=True) + return yml + + +def generate_tree(directory, depth=0, max_depth=None, prefix=""): + if max_depth is not None and depth > max_depth: + return "" + + tree_str = "" + files = [] + directories = [] + for item in os.listdir(directory): + if os.path.isdir(os.path.join(directory, item)): + directories.append(item) + else: + files.append(item) + + entries = directories + files + for i, entry in enumerate(entries): + connector = "├── " if i < len(entries) - 1 else "└── " + tree_str += f"{prefix}{connector}{entry}
" + if i < len(directories): + extension = "│ " if i < len(entries) - 1 else " " + tree_str += generate_tree( + os.path.join(directory, entry), + depth + 1, + max_depth, + prefix=prefix + extension, + ) + return tree_str + + +def new_explorer(data_path, max_depth): + return gr.Markdown( + elem_classes=["scrollable-component"], + value=generate_tree(data_path, max_depth=max_depth), + ) + + +def add_item( + folder: str, + method: str, + label_lang: str, + if_initial_prompt: bool, + initial_prompt: str | None, +): + folder = folder.strip(" ").strip('"') + + folder_path = Path(folder) + + if folder and folder not in items and data_pre_output not in folder_path.parents: + if folder_path.is_dir(): + items.append(folder) + dict_items[folder] = dict( + type="folder", + method=method, + label_lang=label_lang, + initial_prompt=initial_prompt if if_initial_prompt else None, + ) + elif folder: + err = folder + return gr.Checkboxgroup(choices=items), build_html_error_message( + i18n("Invalid path: {}").format(err) + ) + + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info("After Adding: " + formatted_data) + gr.Info(formatted_data) + return gr.Checkboxgroup(choices=items), build_html_ok_message( + i18n("Added path successfully!") + ) + + +def remove_items(selected_items): + global items, dict_items + to_remove = [item for item in items if item in selected_items] + for item in to_remove: + del dict_items[item] + items = [item for item in items if item in dict_items.keys()] + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info(formatted_data) + gr.Warning("After Removing: " + formatted_data) + return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message( + i18n("Removed path successfully!") + ) + + +def show_selected(options): + selected_options = ", ".join(options) + + if options: + return i18n("Selected: {}").format(selected_options) + else: + return i18n("No selected options") + + +from pydub import AudioSegment + + +def convert_to_mono_in_place(audio_path: Path): + audio = AudioSegment.from_file(audio_path) + if audio.channels > 1: + mono_audio = audio.set_channels(1) + mono_audio.export(audio_path, format=audio_path.suffix[1:]) + logger.info(f"Convert {audio_path} successfully") + + +def list_copy(list_file_path, method): + wav_root = data_pre_output + lst = [] + with list_file_path.open("r", encoding="utf-8") as file: + for line in tqdm(file, desc="Processing audio/transcript"): + wav_path, speaker_name, language, text = line.strip().split("|") + original_wav_path = Path(wav_path) + target_wav_path = ( + wav_root / original_wav_path.parent.name / original_wav_path.name + ) + lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}") + if target_wav_path.is_file(): + continue + target_wav_path.parent.mkdir(parents=True, exist_ok=True) + if method == i18n("Copy"): + shutil.copy(original_wav_path, target_wav_path) + else: + shutil.move(original_wav_path, target_wav_path.parent) + convert_to_mono_in_place(target_wav_path) + original_lab_path = original_wav_path.with_suffix(".lab") + target_lab_path = ( + wav_root + / original_wav_path.parent.name + / original_wav_path.with_suffix(".lab").name + ) + if target_lab_path.is_file(): + continue + if method == i18n("Copy"): + shutil.copy(original_lab_path, target_lab_path) + else: + shutil.move(original_lab_path, target_lab_path.parent) + + if method == i18n("Move"): + with list_file_path.open("w", encoding="utf-8") as file: + file.writelines("\n".join(lst)) + + del lst + return build_html_ok_message(i18n("Use filelist")) + + +def check_files(data_path: str, max_depth: int, label_model: str, label_device: str): + global dict_items + data_path = Path(data_path) + gr.Warning("Pre-processing begins...") + for item, content in dict_items.items(): + item_path = Path(item) + tar_path = data_path / item_path.name + + if content["type"] == "folder" and item_path.is_dir(): + if content["method"] == i18n("Copy"): + os.makedirs(tar_path, exist_ok=True) + shutil.copytree( + src=str(item_path), dst=str(tar_path), dirs_exist_ok=True + ) + elif not tar_path.is_dir(): + shutil.move(src=str(item_path), dst=str(tar_path)) + + for suf in ["wav", "flac", "mp3"]: + for audio_path in tar_path.glob(f"**/*.{suf}"): + convert_to_mono_in_place(audio_path) + + cur_lang = content["label_lang"] + initial_prompt = content["initial_prompt"] + + transcribe_cmd = [ + PYTHON, + "tools/whisper_asr.py", + "--model-size", + label_model, + "--device", + label_device, + "--audio-dir", + tar_path, + "--save-dir", + tar_path, + "--language", + cur_lang, + ] + + if initial_prompt is not None: + transcribe_cmd += ["--initial-prompt", initial_prompt] + + if cur_lang != "IGNORE": + try: + gr.Warning("Begin To Transcribe") + subprocess.run( + transcribe_cmd, + env=env, + ) + except Exception: + print("Transcription error occurred") + + elif content["type"] == "file" and item_path.is_file(): + list_copy(item_path, content["method"]) + + return build_html_ok_message(i18n("Move files successfully")), new_explorer( + data_path, max_depth=max_depth + ) + + +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + +def train_process( + data_path: str, + option: str, + # llama config + llama_ckpt, + llama_base_config, + llama_lr, + llama_maxsteps, + llama_data_num_workers, + llama_data_batch_size, + llama_data_max_length, + llama_precision, + llama_check_interval, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, +): + + backend = "nccl" if sys.platform == "linux" else "gloo" + + new_project = generate_folder_name() + print("New Project Name: ", new_project) + + if option == "VQGAN": + msg = "Skipped VQGAN Training." + gr.Warning(msg) + logger.info(msg) + + if option == "LLAMA": + msg = "LLAMA Training begins..." + gr.Warning(msg) + logger.info(msg) + subprocess.run( + [ + PYTHON, + "tools/vqgan/extract_vq.py", + str(data_pre_output), + "--num-workers", + "1", + "--batch-size", + "16", + "--config-name", + "firefly_gan_vq", + "--checkpoint-path", + "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ] + ) + + subprocess.run( + [ + PYTHON, + "tools/llama/build_dataset.py", + "--input", + str(data_pre_output), + "--text-extension", + ".lab", + "--num-workers", + "16", + ] + ) + ckpt_path = "checkpoints/fish-speech-1.4/model.pth" + lora_prefix = "lora_" if llama_use_lora else "" + llama_name = lora_prefix + "text2semantic_" + new_project + latest = next( + iter( + sorted( + [ + str(p.relative_to("results")) + for p in Path("results").glob(lora_prefix + "text2sem*/") + ], + reverse=True, + ) + ), + llama_name, + ) + project = ( + llama_name + if llama_ckpt == i18n("new") + else ( + latest + if llama_ckpt == i18n("latest") + else Path(llama_ckpt).relative_to("results") + ) + ) + logger.info(project) + + if llama_check_interval > llama_maxsteps: + llama_check_interval = llama_maxsteps + + train_cmd = [ + PYTHON, + "fish_speech/train.py", + "--config-name", + "text2semantic_finetune", + f"project={project}", + f"trainer.strategy.process_group_backend={backend}", + f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"model.optimizer.lr={llama_lr}", + f"trainer.max_steps={llama_maxsteps}", + f"data.num_workers={llama_data_num_workers}", + f"data.batch_size={llama_data_batch_size}", + f"max_length={llama_data_max_length}", + f"trainer.precision={llama_precision}", + f"trainer.val_check_interval={llama_check_interval}", + f"trainer.accumulate_grad_batches={llama_grad_batches}", + f"train_dataset.interactive_prob={llama_use_speaker}", + ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else []) + logger.info(train_cmd) + subprocess.run(train_cmd) + + return build_html_ok_message(i18n("Training stopped")) + + +def tensorboard_process( + if_tensorboard: bool, + tensorboard_dir: str, + host: str, + port: str, +): + global p_tensorboard + if if_tensorboard == True and p_tensorboard == None: + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Tensorboard interface is launched at {}").format(url) + ) + prefix = ["tensorboard"] + if Path("fishenv").exists(): + prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"] + + p_tensorboard = subprocess.Popen( + prefix + + [ + "--logdir", + tensorboard_dir, + "--host", + host, + "--port", + port, + "--reload_interval", + "120", + ] + ) + elif if_tensorboard == False and p_tensorboard != None: + kill_process(p_tensorboard.pid) + p_tensorboard = None + yield build_html_error_message(i18n("Tensorboard interface is closed")) + + +def fresh_tb_dir(): + return gr.Dropdown( + choices=[str(p) for p in Path("results").glob("**/tensorboard/")] + ) + + +def list_decoder_models(): + paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")] + if not paths: + logger.warning("No decoder model found") + return paths + + +def list_llama_models(): + choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")] + choices = sorted(choices, reverse=True) + if not choices: + logger.warning("No LLaMA model found") + return choices + + +def list_lora_llama_models(): + choices = sorted( + [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True + ) + if not choices: + logger.warning("No LoRA LLaMA model found") + return choices + + +def fresh_decoder_model(): + return gr.Dropdown(choices=list_decoder_models()) + + +def fresh_llama_ckpt(llama_use_lora): + return gr.Dropdown( + choices=[i18n("latest"), i18n("new")] + + ( + [str(p) for p in Path("results").glob("text2sem*/")] + if not llama_use_lora + else [str(p) for p in Path("results").glob("lora_*/")] + ) + ) + + +def fresh_llama_model(): + return gr.Dropdown(choices=list_llama_models()) + + +def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output): + if ( + lora_weight is None + or not Path(lora_weight).exists() + or not Path(llama_weight).exists() + ): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + gr.Warning("Merging begins...") + merge_cmd = [ + PYTHON, + "tools/llama/merge_lora.py", + "--lora-config", + "r_8_alpha_16", + "--lora-weight", + lora_weight, + "--output", + llama_lora_output + "_" + generate_folder_name(), + ] + logger.info(merge_cmd) + subprocess.run(merge_cmd) + return build_html_ok_message(i18n("Merge successfully")) + + +def llama_quantify(llama_weight, quantify_mode): + if llama_weight is None or not Path(llama_weight).exists(): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + + gr.Warning("Quantifying begins...") + + now = generate_folder_name() + quantify_cmd = [ + PYTHON, + "tools/llama/quantize.py", + "--checkpoint-path", + llama_weight, + "--mode", + quantify_mode, + "--timestamp", + now, + ] + logger.info(quantify_cmd) + subprocess.run(quantify_cmd) + if quantify_mode == "int8": + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}" + ) + else: + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}" + ) + return build_html_ok_message( + i18n("Quantify successfully") + f"Path: {quantize_path}" + ) + + +init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path) +init_llama_yml = load_yaml_data_in_fact(llama_yml_path) + +with gr.Blocks( + head="", + js=js, + theme=seafoam, + analytics_enabled=False, + title="Fish Speech", +) as demo: + with gr.Row(): + with gr.Column(): + with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")): + with gr.Row(): + textbox = gr.Textbox( + label="\U0000270F " + + i18n("Input Audio & Source Path for Transcription"), + info=i18n("Speaker is identified by the folder name"), + interactive=True, + ) + with gr.Row(equal_height=False): + with gr.Column(): + output_radio = gr.Radio( + label="\U0001F4C1 " + + i18n("Select source file processing method"), + choices=[i18n("Copy"), i18n("Move")], + value=i18n("Copy"), + interactive=True, + ) + with gr.Column(): + error = gr.HTML(label=i18n("Error Message")) + if_label = gr.Checkbox( + label=i18n("Open Labeler WebUI"), scale=0, show_label=True + ) + + with gr.Row(): + label_device = gr.Dropdown( + label=i18n("Labeling Device"), + info=i18n( + "It is recommended to use CUDA, if you have low configuration, use CPU" + ), + choices=["cpu", "cuda"], + value="cuda", + interactive=True, + ) + label_model = gr.Dropdown( + label=i18n("Whisper Model"), + info=i18n("Faster Whisper, Up to 5g GPU memory usage"), + choices=["large-v3", "medium"], + value="large-v3", + interactive=True, + ) + label_radio = gr.Dropdown( + label=i18n("Optional Label Language"), + info=i18n( + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format" + ), + choices=[ + (i18n("Chinese"), "zh"), + (i18n("English"), "en"), + (i18n("Japanese"), "ja"), + (i18n("Disabled"), "IGNORE"), + (i18n("auto"), "auto"), + ], + value="IGNORE", + interactive=True, + ) + + with gr.Row(): + if_initial_prompt = gr.Checkbox( + value=False, + label=i18n("Enable Initial Prompt"), + min_width=120, + scale=0, + ) + initial_prompt = gr.Textbox( + label=i18n("Initial Prompt"), + info=i18n( + "Initial prompt can provide contextual or vocabulary-specific guidance to the model." + ), + placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.", + interactive=False, + ) + + with gr.Row(): + add_button = gr.Button( + "\U000027A1 " + i18n("Add to Processing Area"), + variant="primary", + ) + remove_button = gr.Button( + "\U000026D4 " + i18n("Remove Selected Data") + ) + + with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")): + with gr.Row(): + model_type_radio = gr.Radio( + label=i18n( + "Select the model to be trained (Depending on the Tab page you are on)" + ), + interactive=False, + choices=["VQGAN", "LLAMA"], + value="VQGAN", + ) + with gr.Row(): + with gr.Tabs(): + with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: + gr.HTML("You don't need to train this model!") + + with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page: + with gr.Row(equal_height=False): + llama_use_lora = gr.Checkbox( + label=i18n("Use LoRA"), + info=i18n( + "Use LoRA can save GPU memory, but may reduce the quality of the model" + ), + value=True, + interactive=True, + ) + llama_ckpt = gr.Dropdown( + label=i18n("Select LLAMA ckpt"), + choices=[i18n("latest"), i18n("new")] + + [ + str(p) + for p in Path("results").glob("text2sem*/") + ] + + [str(p) for p in Path("results").glob("lora*/")], + value=i18n("latest"), + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lr_slider = gr.Slider( + label=i18n("Initial Learning Rate"), + info=i18n( + "lr smaller -> usually train slower but more stable" + ), + interactive=True, + minimum=1e-5, + maximum=1e-4, + step=1e-5, + value=5e-5, + ) + llama_maxsteps_slider = gr.Slider( + label=i18n("Maximum Training Steps"), + info=i18n( + "recommend: max_steps = num_audios // batch_size * (2 to 5)" + ), + interactive=True, + minimum=1, + maximum=10000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_base_config = gr.Dropdown( + label=i18n("Model Size"), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + ) + llama_data_num_workers_slider = gr.Slider( + label=i18n("Number of Workers"), + minimum=1, + maximum=32, + step=1, + value=4, + ) + with gr.Row(equal_height=False): + llama_data_batch_size_slider = gr.Slider( + label=i18n("Batch Size"), + interactive=True, + minimum=1, + maximum=32, + step=1, + value=2, + ) + llama_data_max_length_slider = gr.Slider( + label=i18n("Maximum Length per Sample"), + interactive=True, + minimum=1024, + maximum=4096, + step=128, + value=2048, + ) + with gr.Row(equal_height=False): + llama_precision_dropdown = gr.Dropdown( + label=i18n("Precision"), + info=i18n( + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU" + ), + interactive=True, + choices=["32", "bf16-true", "16-mixed"], + value="bf16-true", + ) + llama_check_interval_slider = gr.Slider( + label=i18n("Save model every n steps"), + info=i18n( + "make sure that it's not greater than max_steps" + ), + interactive=True, + minimum=1, + maximum=1000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_grad_batches = gr.Slider( + label=i18n("Accumulate Gradient Batches"), + interactive=True, + minimum=1, + maximum=20, + step=1, + value=init_llama_yml["trainer"][ + "accumulate_grad_batches" + ], + ) + llama_use_speaker = gr.Slider( + label=i18n( + "Probability of applying Speaker Condition" + ), + interactive=True, + minimum=0.1, + maximum=1.0, + step=0.05, + value=init_llama_yml["train_dataset"][ + "interactive_prob" + ], + ) + + with gr.Tab(label=i18n("Merge LoRA"), id=4): + with gr.Row(equal_height=False): + llama_weight = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "checkpoints/fish-speech-1.4/model.pth", + ], + value="checkpoints/fish-speech-1.4/model.pth", + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + lora_weight = gr.Dropdown( + label=i18n("LoRA Model to be merged"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + str(p) + for p in Path("results").glob("lora*/**/*.ckpt") + ], + allow_custom_value=True, + interactive=True, + ) + lora_llama_config = gr.Dropdown( + label=i18n("LLAMA Model Config"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + allow_custom_value=True, + ) + with gr.Row(equal_height=False): + llama_lora_output = gr.Dropdown( + label=i18n("Output Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/merged", + choices=["checkpoints/merged"], + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lora_merge_btn = gr.Button( + value=i18n("Merge"), variant="primary" + ) + + with gr.Tab(label=i18n("Model Quantization"), id=5): + with gr.Row(equal_height=False): + llama_weight_to_quantify = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_llama_models(), + value="checkpoints/fish-speech-1.4", + allow_custom_value=True, + interactive=True, + ) + quantify_mode = gr.Dropdown( + label=i18n("Post-quantification Precision"), + info=i18n( + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase" + ), + choices=["int8", "int4"], + value="int8", + allow_custom_value=False, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_quantify_btn = gr.Button( + value=i18n("Quantify"), variant="primary" + ) + + with gr.Tab(label="Tensorboard", id=6): + with gr.Row(equal_height=False): + tb_host = gr.Textbox( + label=i18n("Tensorboard Host"), value="127.0.0.1" + ) + tb_port = gr.Textbox( + label=i18n("Tensorboard Port"), value="11451" + ) + with gr.Row(equal_height=False): + tb_dir = gr.Dropdown( + label=i18n("Tensorboard Log Path"), + allow_custom_value=True, + choices=[ + str(p) + for p in Path("results").glob("**/tensorboard/") + ], + ) + with gr.Row(equal_height=False): + if_tb = gr.Checkbox( + label=i18n("Open Tensorboard"), + ) + + with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")): + with gr.Column(): + with gr.Row(): + with gr.Accordion( + label="\U0001F5A5 " + + i18n("Inference Server Configuration"), + open=False, + ): + with gr.Row(): + infer_host_textbox = gr.Textbox( + label=i18n("WebUI Host"), value="127.0.0.1" + ) + infer_port_textbox = gr.Textbox( + label=i18n("WebUI Port"), value="7862" + ) + with gr.Row(): + infer_decoder_model = gr.Dropdown( + label=i18n("Decoder Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_decoder_models(), + value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + allow_custom_value=True, + ) + infer_decoder_config = gr.Dropdown( + label=i18n("Decoder Model Config"), + info=i18n("Changing with the Model Path"), + value="firefly_gan_vq", + choices=[ + "firefly_gan_vq", + ], + allow_custom_value=True, + ) + with gr.Row(): + infer_llama_model = gr.Dropdown( + label=i18n("LLAMA Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/fish-speech-1.4", + choices=list_llama_models(), + allow_custom_value=True, + ) + + with gr.Row(): + infer_compile = gr.Radio( + label=i18n("Compile Model"), + info=i18n( + "Compile the model can significantly reduce the inference time, but will increase cold start time" + ), + choices=["Yes", "No"], + value=( + "Yes" if (sys.platform == "linux") else "No" + ), + interactive=is_module_installed("triton"), + ) + + with gr.Row(): + infer_checkbox = gr.Checkbox( + label=i18n("Open Inference Server") + ) + infer_error = gr.HTML(label=i18n("Inference Server Error")) + + with gr.Column(): + train_error = gr.HTML(label=i18n("Training Error")) + checkbox_group = gr.CheckboxGroup( + label="\U0001F4CA " + i18n("Data Source"), + info=i18n( + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list." + ), + elem_classes=["data_src"], + ) + train_box = gr.Textbox( + label=i18n("Data Preprocessing Path"), + value=str(data_pre_output), + interactive=False, + ) + model_box = gr.Textbox( + label="\U0001F4BE " + i18n("Model Output Path"), + value=str(default_model_output), + interactive=False, + ) + + with gr.Accordion( + i18n( + "View the status of the preprocessing folder (use the slider to control the depth of the tree)" + ), + elem_classes=["scrollable-component"], + elem_id="file_accordion", + ): + tree_slider = gr.Slider( + minimum=0, + maximum=3, + value=0, + step=1, + show_label=False, + container=False, + ) + file_markdown = new_explorer(str(data_pre_output), 0) + with gr.Row(equal_height=False): + admit_btn = gr.Button( + "\U00002705 " + i18n("File Preprocessing"), + variant="primary", + ) + fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80) + help_button = gr.Button("\U00002753", scale=0, min_width=80) # question + train_btn = gr.Button(i18n("Start Training"), variant="primary") + + footer = load_data_in_raw("fish_speech/webui/html/footer.html") + footer = footer.format( + versions=versions_html(), + api_docs="https://speech.fish.audio/inference/#http-api", + ) + gr.HTML(footer, elem_id="footer") + vqgan_page.select(lambda: "VQGAN", None, model_type_radio) + llama_page.select(lambda: "LLAMA", None, model_type_radio) + add_button.click( + fn=add_item, + inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt], + outputs=[checkbox_group, error], + ) + remove_button.click( + fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error] + ) + checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error]) + help_button.click( + fn=None, + js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, ' + 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}', + ) + if_label.change(fn=change_label, inputs=[if_label], outputs=[error]) + if_initial_prompt.change( + fn=lambda x: gr.Textbox(value="", interactive=x), + inputs=[if_initial_prompt], + outputs=[initial_prompt], + ) + train_btn.click( + fn=train_process, + inputs=[ + train_box, + model_type_radio, + # llama config + llama_ckpt, + llama_base_config, + llama_lr_slider, + llama_maxsteps_slider, + llama_data_num_workers_slider, + llama_data_batch_size_slider, + llama_data_max_length_slider, + llama_precision_dropdown, + llama_check_interval_slider, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, + ], + outputs=[train_error], + ) + if_tb.change( + fn=tensorboard_process, + inputs=[if_tb, tb_dir, tb_host, tb_port], + outputs=[train_error], + ) + tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir]) + infer_decoder_model.change( + fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model] + ) + infer_llama_model.change( + fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model] + ) + llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight]) + admit_btn.click( + fn=check_files, + inputs=[train_box, tree_slider, label_model, label_device], + outputs=[error, file_markdown], + ) + fresh_btn.click( + fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] + ) + llama_use_lora.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + llama_ckpt.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + lora_weight.change( + fn=lambda: gr.Dropdown(choices=list_lora_llama_models()), + inputs=[], + outputs=[lora_weight], + ) + llama_lora_merge_btn.click( + fn=llama_lora_merge, + inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output], + outputs=[train_error], + ) + llama_quantify_btn.click( + fn=llama_quantify, + inputs=[llama_weight_to_quantify, quantify_mode], + outputs=[train_error], + ) + infer_checkbox.change( + fn=change_infer, + inputs=[ + infer_checkbox, + infer_host_textbox, + infer_port_textbox, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, + ], + outputs=[infer_error], + ) + +demo.launch(inbrowser=True) diff --git a/inference.ipynb b/inference.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e690a80d4f751342569ffe78c43b0b5c327c7f7a --- /dev/null +++ b/inference.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fish Speech" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### For Windows User / win用户" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "bat" + } + }, + "outputs": [], + "source": [ + "!chcp 65001" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### For Linux User / Linux 用户" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import locale\n", + "locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For Chinese users, you probably want to use mirror to accelerate downloading\n", + "# !set HF_ENDPOINT=https://hf-mirror.com\n", + "# !export HF_ENDPOINT=https://hf-mirror.com \n", + "\n", + "!huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## WebUI Inference\n", + "\n", + "> You can use --compile to fuse CUDA kernels for faster inference (10x)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "!python tools/webui.py \\\n", + " --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n", + " --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n", + " # --compile" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Break-down CLI Inference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Encode reference audio: / 从语音生成 prompt: \n", + "\n", + "You should get a `fake.npy` file.\n", + "\n", + "你应该能得到一个 `fake.npy` 文件." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "## Enter the path to the audio file here\n", + "src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n", + "\n", + "!python tools/vqgan/inference.py \\\n", + " -i {src_audio} \\\n", + " --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n", + "\n", + "from IPython.display import Audio, display\n", + "audio = Audio(filename=\"fake.wav\")\n", + "display(audio)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Generate semantic tokens from text: / 从文本生成语义 token:\n", + "\n", + "> This command will create a codes_N file in the working directory, where N is an integer starting from 0.\n", + "\n", + "> You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~300 tokens/second).\n", + "\n", + "> 该命令会在工作目录下创建 codes_N 文件, 其中 N 是从 0 开始的整数.\n", + "\n", + "> 您可以使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 tokens/秒 -> ~300 tokens/秒)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "!python tools/llama/generate.py \\\n", + " --text \"hello world\" \\\n", + " --prompt-text \"The text corresponding to reference audio\" \\\n", + " --prompt-tokens \"fake.npy\" \\\n", + " --checkpoint-path \"checkpoints/fish-speech-1.4\" \\\n", + " --num-samples 2\n", + " # --compile" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Generate speech from semantic tokens: / 从语义 token 生成人声:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "!python tools/vqgan/inference.py \\\n", + " -i \"codes_0.npy\" \\\n", + " --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n", + "\n", + "from IPython.display import Audio, display\n", + "audio = Audio(filename=\"fake.wav\")\n", + "display(audio)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/install_env.bat b/install_env.bat new file mode 100644 index 0000000000000000000000000000000000000000..590cd5cb6c866313385315937c85c1cb7db84df2 --- /dev/null +++ b/install_env.bat @@ -0,0 +1,179 @@ +@echo off +chcp 65001 + +set USE_MIRROR=true +echo "USE_MIRROR: %USE_MIRROR%" +setlocal enabledelayedexpansion + +cd /D "%~dp0" + +set PATH="%PATH%";%SystemRoot%\system32 + +echo %PATH% + + +echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && ( + echo. + echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && ( + goto end + ) +) + + +set TMP=%CD%\fishenv +set TEMP=%CD%\fishenv + +(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul + +set INSTALL_DIR=%cd%\fishenv +set CONDA_ROOT_PREFIX=%cd%\fishenv\conda +set INSTALL_ENV_DIR=%cd%\fishenv\env +set PIP_CMD=%cd%\fishenv\env\python -m pip +set PYTHON_CMD=%cd%\fishenv\env\python +set API_FLAG_PATH=%~dp0API_FLAGS.txt +set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe +if "!USE_MIRROR!" == "true" ( + set MINICONDA_DOWNLOAD_URL=https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe +) +set MINICONDA_CHECKSUM=307194e1f12bbeb52b083634e89cc67db4f7980bd542254b43d3309eaf7cb358 +set conda_exists=F + +call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1 +if "%ERRORLEVEL%" EQU "0" set conda_exists=T + +if "%conda_exists%" == "F" ( + echo. + echo Downloading Miniconda... + mkdir "%INSTALL_DIR%" 2>nul + call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe" + if errorlevel 1 ( + echo. + echo Failed to download miniconda. + goto end + ) + for /f %%a in (' + certutil -hashfile "%INSTALL_DIR%\miniconda_installer.exe" sha256 + ^| find /i /v " " + ^| find /i "%MINICONDA_CHECKSUM%" + ') do ( + set "hash=%%a" + ) + if not defined hash ( + echo. + echo Miniconda hash mismatched! + del "%INSTALL_DIR%\miniconda_installer.exe" + goto end + ) else ( + echo. + echo Miniconda hash matched successfully. + ) + echo Downloaded "%CONDA_ROOT_PREFIX%" + start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX% + + call "%CONDA_ROOT_PREFIX%\_conda.exe" --version + if errorlevel 1 ( + echo. + echo Cannot install Miniconda. + goto end + ) else ( + echo. + echo Miniconda Install success. + ) + + del "%INSTALL_DIR%\miniconda_installer.exe" +) + + +if not exist "%INSTALL_ENV_DIR%" ( + echo. + echo Creating Conda Environment... + if "!USE_MIRROR!" == "true" ( + call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ python=3.10 + ) else ( + call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10 + ) + + if errorlevel 1 ( + echo. + echo Failed to Create Environment. + goto end + ) +) + +if not exist "%INSTALL_ENV_DIR%\python.exe" ( + echo. + echo Conda Env does not exist. + goto end +) + +set PYTHONNOUSERSITE=1 +set PYTHONPATH= +set PYTHONHOME= +set "CUDA_PATH=%INSTALL_ENV_DIR%" +set "CUDA_HOME=%CUDA_PATH%" + +call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" + +if errorlevel 1 ( + echo. + echo Failed to activate Env. + goto end +) else ( + echo. + echo successfully create env. +) + +set "HF_ENDPOINT=https://huggingface.co" +set "no_proxy=" +if "%USE_MIRROR%"=="true" ( + set "HF_ENDPOINT=https://hf-mirror.com" + set "no_proxy=localhost,127.0.0.1,0.0.0.0" +) + +echo "HF_ENDPOINT: !HF_ENDPOINT!" +echo "NO_PROXY: !no_proxy!" + +%PIP_CMD% install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 + +%PIP_CMD% install -e . --upgrade-strategy only-if-needed + +call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^ + "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/triton_windows-0.1.0-py3-none-any.whl?download=true" ^ + "2cc998638180f37cf5025ab65e48c7f629aa5a369176cfa32177d2bd9aa26a0a" + + +endlocal +echo "Environment Check: Success." +pause + +goto :EOF + + +:download_and_install +setlocal + +set "WHEEL_FILE=%1" +set "URL=%2" +set "CHKSUM=%3" + +:DOWNLOAD +if not exist "%WHEEL_FILE%" ( + call curl -Lk "%URL%" --output "%WHEEL_FILE%" +) + +for /f "delims=" %%I in ("certutil -hashfile %WHEEL_FILE% SHA256 ^| find /i %CHKSUM%") do ( + set "FILE_VALID=true" +) + +if not defined FILE_VALID ( + echo File checksum does not match, re-downloading... + del "%WHEEL_FILE%" + goto DOWNLOAD +) + +echo "OK for %WHEEL_FILE%" +%PIP_CMD% install "%WHEEL_FILE%" --no-warn-script-location +del "%WHEEL_FILE%" + +endlocal +goto :EOF diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000000000000000000000000000000000000..01dcbe571e53012d8e6ebeac7da3dde8503c7f5d --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,109 @@ +site_name: Fish Speech +site_description: Targeting SOTA TTS solutions. +site_url: https://speech.fish.audio + +# Repository +repo_name: fishaudio/fish-speech +repo_url: https://github.com/fishaudio/fish-speech +edit_uri: blob/main/docs + +# Copyright +copyright: Copyright © 2023-2024 by Fish Audio + +theme: + name: material + language: en + features: + - content.action.edit + - content.action.view + - navigation.tracking + - navigation.footer + # - navigation.tabs + - search + - search.suggest + - search.highlight + - search.share + - content.code.copy + icon: + logo: fontawesome/solid/fish + + palette: + # Palette toggle for automatic mode + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to light mode + + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + toggle: + icon: material/brightness-7 + name: Switch to dark mode + primary: black + font: + code: Roboto Mono + + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + toggle: + icon: material/brightness-4 + name: Switch to light mode + primary: black + font: + code: Roboto Mono + +# Plugins +plugins: + - search: + separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' + lang: + - en + - zh + - ja + - pt + - i18n: + docs_structure: folder + languages: + - locale: en + name: English + default: true + build: true + - locale: zh + name: 简体中文 + build: true + - locale: ja + name: 日本語 + build: true + - locale: pt + name: Português (Brasil) + build: true + +markdown_extensions: + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences + - admonition + - pymdownx.details + - pymdownx.superfences + - attr_list + - md_in_html + - pymdownx.superfences + +extra_css: + - stylesheets/extra.css + +extra: + social: + - icon: fontawesome/brands/discord + link: https://discord.gg/Es5qTB9BcN + - icon: fontawesome/brands/docker + link: https://hub.docker.com/r/fishaudio/fish-speech + - icon: fontawesome/brands/qq + link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093 + homepage: https://speech.fish.audio diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..23de46286e3ca2b9b7ee0d2d9cdabce8621b31e7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[project] +name = "fish-speech" +version = "0.1.0" +authors = [ + {name = "Lengyue", email = "lengyue@lengyue.me"}, +] +description = "Fish Speech" +readme = "README.md" +requires-python = ">=3.10" +keywords = ["TTS", "Speech"] +license = {text = "CC BY-NC-SA 4.0"} +classifiers = [ + "Programming Language :: Python :: 3", +] +dependencies = [ + "numpy<=1.26.4", + "transformers>=4.35.2", + "datasets==2.18.0", + "lightning>=2.1.0", + "hydra-core>=1.3.2", + "tensorboard>=2.14.1", + "natsort>=8.4.0", + "einops>=0.7.0", + "librosa>=0.10.1", + "rich>=13.5.3", + "gradio>=4.0.0", + "wandb>=0.15.11", + "grpcio>=1.58.0", + "kui>=1.6.0", + "uvicorn>=0.30.0", + "loguru>=0.6.0", + "loralib>=0.1.2", + "natsort>=8.4.0", + "pyrootutils>=1.0.4", + "vector_quantize_pytorch>=1.14.24", + "resampy>=0.4.3", + "einx[torch]==0.2.2", + "zstandard>=0.22.0", + "pydub", + "faster_whisper", + "modelscope==1.17.1", + "funasr==1.1.5", + "opencc-python-reimplemented==0.1.7", + "silero-vad", + "ormsgpack", +] + +[project.optional-dependencies] +stable = [ + "torch>=2.3.1", + "torchaudio", +] + +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["fish_speech", "tools"] diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..ad1493530f7f6d8fa476dbe0b76e6239fce2d7e7 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,6 @@ +{ + "exclude": [ + "data", + "filelists" + ] +} diff --git a/run_cmd.bat b/run_cmd.bat new file mode 100644 index 0000000000000000000000000000000000000000..c2af8a9b6fb75df7b7c81ff5986286845e247fb9 --- /dev/null +++ b/run_cmd.bat @@ -0,0 +1,50 @@ +@echo off +chcp 65001 + +set no_proxy="127.0.0.1, 0.0.0.0, localhost" +setlocal enabledelayedexpansion + +cd /D "%~dp0" + +set PATH="%PATH%";%SystemRoot%\system32 + + +echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && ( + echo. + echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && ( + goto end + ) +) + + +set TMP=%CD%\fishenv +set TEMP=%CD%\fishenv + + +(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul + + +set CONDA_ROOT_PREFIX=%cd%\fishenv\conda +set INSTALL_ENV_DIR=%cd%\fishenv\env + + +set PYTHONNOUSERSITE=1 +set PYTHONPATH=%~dp0 +set PYTHONHOME= + + +call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" + +if errorlevel 1 ( + echo. + echo Environment activation failed. + goto end +) else ( + echo. + echo Environment activation succeeded. +) + +cmd /k "%*" + +:end +pause diff --git a/start.bat b/start.bat new file mode 100644 index 0000000000000000000000000000000000000000..40c7f4d3bfd340753079add8e8e6b5db7abc3fcc --- /dev/null +++ b/start.bat @@ -0,0 +1,97 @@ +@echo off +chcp 65001 + +set USE_MIRROR=true +set PYTHONPATH=%~dp0 +set PYTHON_CMD=python +if exist "fishenv" ( + set PYTHON_CMD=%cd%\fishenv\env\python +) + +set API_FLAG_PATH=%~dp0API_FLAGS.txt +set KMP_DUPLICATE_LIB_OK=TRUE + +setlocal enabledelayedexpansion + +set "HF_ENDPOINT=https://huggingface.co" +set "no_proxy=" +if "%USE_MIRROR%" == "true" ( + set "HF_ENDPOINT=https://hf-mirror.com" + set "no_proxy=localhost, 127.0.0.1, 0.0.0.0" +) +echo "HF_ENDPOINT: !HF_ENDPOINT!" +echo "NO_PROXY: !no_proxy!" + +echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && ( + echo. + echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && ( + goto end + ) +) + +%PYTHON_CMD% .\tools\download_models.py + +set "API_FLAGS=" +set "flags=" + +if exist "%API_FLAG_PATH%" ( + for /f "usebackq tokens=*" %%a in ("%API_FLAG_PATH%") do ( + set "line=%%a" + if not "!line:~0,1!"=="#" ( + set "line=!line: =!" + set "line=!line:\=!" + set "line=!line:= !" + if not "!line!"=="" ( + set "API_FLAGS=!API_FLAGS!!line! " + ) + ) + ) +) + + +if not "!API_FLAGS!"=="" set "API_FLAGS=!API_FLAGS:~0,-1!" + +set "flags=" + +echo !API_FLAGS! | findstr /C:"--api" >nul 2>&1 +if !errorlevel! equ 0 ( + echo. + echo Start HTTP API... + set "mode=api" + goto process_flags +) + +echo !API_FLAGS! | findstr /C:"--infer" >nul 2>&1 +if !errorlevel! equ 0 ( + echo. + echo Start WebUI Inference... + set "mode=infer" + goto process_flags +) + + +:process_flags +for %%p in (!API_FLAGS!) do ( + if not "%%p"=="--!mode!" ( + set "flags=!flags! %%p" + ) +) + +if not "!flags!"=="" set "flags=!flags:~1!" + +echo Debug: flags = !flags! + +if "!mode!"=="api" ( + %PYTHON_CMD% -m tools.api !flags! +) else if "!mode!"=="infer" ( + %PYTHON_CMD% -m tools.webui !flags! +) + +echo. +echo Next launch the page... +%PYTHON_CMD% fish_speech\webui\manage.py + + +:end +endlocal +pause diff --git a/tools/api.py b/tools/api.py new file mode 100644 index 0000000000000000000000000000000000000000..fb885005c1890ada79881110f228dcf355cdd75b --- /dev/null +++ b/tools/api.py @@ -0,0 +1,440 @@ +import base64 +import io +import json +import queue +import random +import sys +import traceback +import wave +from argparse import ArgumentParser +from http import HTTPStatus +from pathlib import Path +from typing import Annotated, Any, Literal, Optional + +import numpy as np +import ormsgpack +import pyrootutils +import soundfile as sf +import torch +import torchaudio +from baize.datastructures import ContentType +from kui.asgi import ( + Body, + FactoryClass, + HTTPException, + HttpRequest, + HttpView, + JSONResponse, + Kui, + OpenAPI, + StreamResponse, +) +from kui.asgi.routing import MultimethodRoutes +from loguru import logger +from pydantic import BaseModel, Field, conint + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +# from fish_speech.models.vqgan.lit_module import VQGAN +from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText +from fish_speech.utils import autocast_exclude_mps +from tools.commons import ServeReferenceAudio, ServeTTSRequest +from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text +from tools.llama.generate import ( + GenerateRequest, + GenerateResponse, + WrappedGenerateResponse, + launch_thread_safe_queue, +) +from tools.vqgan.inference import load_model as load_decoder_model + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +# Define utils for web server +async def http_execption_handler(exc: HTTPException): + return JSONResponse( + dict( + statusCode=exc.status_code, + message=exc.content, + error=HTTPStatus(exc.status_code).phrase, + ), + exc.status_code, + exc.headers, + ) + + +async def other_exception_handler(exc: "Exception"): + traceback.print_exc() + + status = HTTPStatus.INTERNAL_SERVER_ERROR + return JSONResponse( + dict(statusCode=status, message=str(exc), error=status.phrase), + status, + ) + + +def load_audio(reference_audio, sr): + if len(reference_audio) > 255 or not Path(reference_audio).exists(): + audio_data = reference_audio + reference_audio = io.BytesIO(audio_data) + + waveform, original_sr = torchaudio.load( + reference_audio, backend="ffmpeg" if sys.platform == "linux" else "soundfile" + ) + + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if original_sr != sr: + resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr) + waveform = resampler(waveform) + + audio = waveform.squeeze().numpy() + return audio + + +def encode_reference(*, decoder_model, reference_audio, enable_reference_audio): + if enable_reference_audio and reference_audio is not None: + # Load audios, and prepare basic info here + reference_audio_content = load_audio( + reference_audio, decoder_model.spec_transform.sample_rate + ) + + audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[ + None, None, : + ] + audio_lengths = torch.tensor( + [audios.shape[2]], device=decoder_model.device, dtype=torch.long + ) + logger.info( + f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds" + ) + + # VQ Encoder + if isinstance(decoder_model, FireflyArchitecture): + prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0] + + logger.info(f"Encoded prompt: {prompt_tokens.shape}") + else: + prompt_tokens = None + logger.info("No reference audio provided") + + return prompt_tokens + + +def decode_vq_tokens( + *, + decoder_model, + codes, +): + feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device) + logger.info(f"VQ features: {codes.shape}") + + if isinstance(decoder_model, FireflyArchitecture): + # VQGAN Inference + return decoder_model.decode( + indices=codes[None], + feature_lengths=feature_lengths, + )[0].squeeze() + + raise ValueError(f"Unknown model type: {type(decoder_model)}") + + +routes = MultimethodRoutes(base_class=HttpView) + + +def get_content_type(audio_format): + if audio_format == "wav": + return "audio/wav" + elif audio_format == "flac": + return "audio/flac" + elif audio_format == "mp3": + return "audio/mpeg" + else: + return "application/octet-stream" + + +@torch.inference_mode() +def inference(req: ServeTTSRequest): + + idstr: str | None = req.reference_id + if idstr is not None: + ref_folder = Path("references") / idstr + ref_folder.mkdir(parents=True, exist_ok=True) + ref_audios = list_files( + ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False + ) + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + + else: + # Parse reference audio aka prompt + refs = req.references + if refs is None: + refs = [] + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=ref.audio, + enable_reference_audio=True, + ) + for ref in refs + ] + prompt_texts = [ref.text for ref in refs] + + # LLAMA Inference + request = dict( + device=decoder_model.device, + max_new_tokens=req.max_new_tokens, + text=( + req.text + if not req.normalize + else ChnNormedText(raw_text=req.text).normalize() + ), + top_p=req.top_p, + repetition_penalty=req.repetition_penalty, + temperature=req.temperature, + compile=args.compile, + iterative_prompt=req.chunk_length > 0, + chunk_length=req.chunk_length, + max_length=2048, + prompt_tokens=prompt_tokens, + prompt_text=prompt_texts, + ) + + response_queue = queue.Queue() + llama_queue.put( + GenerateRequest( + request=request, + response_queue=response_queue, + ) + ) + + if req.streaming: + yield wav_chunk_header() + + segments = [] + while True: + result: WrappedGenerateResponse = response_queue.get() + if result.status == "error": + raise result.response + break + + result: GenerateResponse = result.response + if result.action == "next": + break + + with autocast_exclude_mps( + device_type=decoder_model.device.type, dtype=args.precision + ): + fake_audios = decode_vq_tokens( + decoder_model=decoder_model, + codes=result.codes, + ) + + fake_audios = fake_audios.float().cpu().numpy() + + if req.streaming: + yield (fake_audios * 32768).astype(np.int16).tobytes() + else: + segments.append(fake_audios) + + if req.streaming: + return + + if len(segments) == 0: + raise HTTPException( + HTTPStatus.INTERNAL_SERVER_ERROR, + content="No audio generated, please check the input text.", + ) + + fake_audios = np.concatenate(segments, axis=0) + yield fake_audios + + +async def inference_async(req: ServeTTSRequest): + for chunk in inference(req): + yield chunk + + +async def buffer_to_async_generator(buffer): + yield buffer + + +@routes.http.post("/v1/tts") +async def api_invoke_model( + req: Annotated[ServeTTSRequest, Body(exclusive=True)], +): + """ + Invoke model and generate audio + """ + + if args.max_text_length > 0 and len(req.text) > args.max_text_length: + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content=f"Text is too long, max length is {args.max_text_length}", + ) + + if req.streaming and req.format != "wav": + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content="Streaming only supports WAV format", + ) + + if req.streaming: + return StreamResponse( + iterable=inference_async(req), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + else: + fake_audios = next(inference(req)) + buffer = io.BytesIO() + sf.write( + buffer, + fake_audios, + decoder_model.spec_transform.sample_rate, + format=req.format, + ) + + return StreamResponse( + iterable=buffer_to_async_generator(buffer.getvalue()), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + + +@routes.http.post("/v1/health") +async def api_health(): + """ + Health check + """ + + return JSONResponse({"status": "ok"}) + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--llama-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.4", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-text-length", type=int, default=0) + parser.add_argument("--listen", type=str, default="127.0.0.1:8080") + parser.add_argument("--workers", type=int, default=1) + + return parser.parse_args() + + +# Define Kui app +openapi = OpenAPI( + { + "title": "Fish Speech API", + }, +).routes + + +class MsgPackRequest(HttpRequest): + async def data(self) -> Annotated[Any, ContentType("application/msgpack")]: + if self.content_type == "application/msgpack": + return ormsgpack.unpackb(await self.body) + + raise HTTPException( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + headers={"Accept": "application/msgpack"}, + ) + + +app = Kui( + routes=routes + openapi[1:], # Remove the default route + exception_handlers={ + HTTPException: http_execption_handler, + Exception: other_exception_handler, + }, + factory_class=FactoryClass(http=MsgPackRequest), + cors_config={}, +) + + +if __name__ == "__main__": + + import uvicorn + + args = parse_args() + args.precision = torch.half if args.half else torch.bfloat16 + + logger.info("Loading Llama model...") + llama_queue = launch_thread_safe_queue( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + logger.info("Llama model loaded, loading VQ-GAN model...") + + decoder_model = load_decoder_model( + config_name=args.decoder_config_name, + checkpoint_path=args.decoder_checkpoint_path, + device=args.device, + ) + + logger.info("VQ-GAN model loaded, warming up...") + + # Dry run to check if the model is loaded correctly and avoid the first-time latency + list( + inference( + ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=1024, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.2, + temperature=0.7, + emotion=None, + format="wav", + ) + ) + ) + + logger.info(f"Warming up done, starting server at http://{args.listen}") + host, port = args.listen.split(":") + uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info") diff --git a/tools/commons.py b/tools/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..f81cadec1efd6e4f749c279e64a65ea9caaa3f53 --- /dev/null +++ b/tools/commons.py @@ -0,0 +1,35 @@ +from typing import Annotated, Literal, Optional + +from pydantic import BaseModel, Field, conint + + +class ServeReferenceAudio(BaseModel): + audio: bytes + text: str + + +class ServeTTSRequest(BaseModel): + text: str + chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 + # Audio format + format: Literal["wav", "pcm", "mp3"] = "wav" + mp3_bitrate: Literal[64, 128, 192] = 128 + # References audios for in-context learning + references: list[ServeReferenceAudio] = [] + # Reference id + # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/ + # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 + reference_id: str | None = None + # Normalize text for en & zh, this increase stability for numbers + normalize: bool = True + mp3_bitrate: Optional[int] = 64 + opus_bitrate: Optional[int] = -1000 + # Balance mode will reduce latency to 300ms, but may decrease stability + latency: Literal["normal", "balanced"] = "normal" + # not usually used below + streaming: bool = False + emotion: Optional[str] = None + max_new_tokens: int = 1024 + top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 + repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2 + temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 diff --git a/tools/download_models.py b/tools/download_models.py new file mode 100644 index 0000000000000000000000000000000000000000..9e79c34c43b424a8e47c43dd3edf003634fc667e --- /dev/null +++ b/tools/download_models.py @@ -0,0 +1,55 @@ +import os + +from huggingface_hub import hf_hub_download + + +# Download +def check_and_download_files(repo_id, file_list, local_dir): + os.makedirs(local_dir, exist_ok=True) + for file in file_list: + file_path = os.path.join(local_dir, file) + if not os.path.exists(file_path): + print(f"{file} 不存在,从 Hugging Face 仓库下载...") + hf_hub_download( + repo_id=repo_id, + filename=file, + resume_download=True, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + else: + print(f"{file} 已存在,跳过下载。") + + +# 1st +repo_id_1 = "fishaudio/fish-speech-1.4" +local_dir_1 = "./checkpoints/fish-speech-1.4" +files_1 = [ + "model.pth", + "README.md", + "special_tokens_map.json", + "tokenizer_config.json", + "tokenizer.json", + "config.json", + "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +] + +# 3rd +repo_id_3 = "fishaudio/fish-speech-1" +local_dir_3 = "./" +files_3 = [ + "ffmpeg.exe", + "ffprobe.exe", +] + +# 4th +repo_id_4 = "SpicyqSama007/fish-speech-packed" +local_dir_4 = "./" +files_4 = [ + "asr-label-win-x64.exe", +] + +check_and_download_files(repo_id_1, files_1, local_dir_1) + +check_and_download_files(repo_id_3, files_3, local_dir_3) +check_and_download_files(repo_id_4, files_4, local_dir_4) diff --git a/tools/extract_model.py b/tools/extract_model.py new file mode 100644 index 0000000000000000000000000000000000000000..97fe62507b7282890319d8dc1eaa3cbca0e1f60a --- /dev/null +++ b/tools/extract_model.py @@ -0,0 +1,21 @@ +import click +import torch +from loguru import logger + + +@click.command() +@click.argument("model_path") +@click.argument("output_path") +def main(model_path, output_path): + if model_path == output_path: + logger.error("Model path and output path are the same") + return + + logger.info(f"Loading model from {model_path}") + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + torch.save(state_dict, output_path) + logger.info(f"Model saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/file.py b/tools/file.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a0597365252e7aecf887897ff391a061275c3f --- /dev/null +++ b/tools/file.py @@ -0,0 +1,125 @@ +import base64 +from pathlib import Path +from typing import Union + +from loguru import logger +from natsort import natsorted + +AUDIO_EXTENSIONS = { + ".mp3", + ".wav", + ".flac", + ".ogg", + ".m4a", + ".wma", + ".aac", + ".aiff", + ".aif", + ".aifc", +} + +VIDEO_EXTENSIONS = { + ".mp4", + ".avi", +} + + +def audio_to_bytes(file_path): + if not file_path or not Path(file_path).exists(): + return None + with open(file_path, "rb") as wav_file: + wav = wav_file.read() + return wav + + +def read_ref_text(ref_text): + path = Path(ref_text) + if path.exists() and path.is_file(): + with path.open("r", encoding="utf-8") as file: + return file.read() + return ref_text + + +def list_files( + path: Union[Path, str], + extensions: set[str] = None, + recursive: bool = False, + sort: bool = True, +) -> list[Path]: + """List files in a directory. + + Args: + path (Path): Path to the directory. + extensions (set, optional): Extensions to filter. Defaults to None. + recursive (bool, optional): Whether to search recursively. Defaults to False. + sort (bool, optional): Whether to sort the files. Defaults to True. + + Returns: + list: List of files. + """ + + if isinstance(path, str): + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Directory {path} does not exist.") + + files = [file for ext in extensions for file in path.rglob(f"*{ext}")] + + if sort: + files = natsorted(files) + + return files + + +def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: + """ + Load a Bert-VITS2 style filelist. + """ + + files = set() + results = [] + count_duplicated, count_not_found = 0, 0 + + LANGUAGE_TO_LANGUAGES = { + "zh": ["zh", "en"], + "jp": ["jp", "en"], + "en": ["en"], + } + + with open(path, "r", encoding="utf-8") as f: + for line in f.readlines(): + splits = line.strip().split("|", maxsplit=3) + if len(splits) != 4: + logger.warning(f"Invalid line: {line}") + continue + + filename, speaker, language, text = splits + file = Path(filename) + language = language.strip().lower() + + if language == "ja": + language = "jp" + + assert language in ["zh", "jp", "en"], f"Invalid language {language}" + languages = LANGUAGE_TO_LANGUAGES[language] + + if file in files: + logger.warning(f"Duplicated file: {file}") + count_duplicated += 1 + continue + + if not file.exists(): + logger.warning(f"File not found: {file}") + count_not_found += 1 + continue + + results.append((file, speaker, languages, text)) + + if count_duplicated > 0: + logger.warning(f"Total duplicated files: {count_duplicated}") + + if count_not_found > 0: + logger.warning(f"Total files not found: {count_not_found}") + + return results diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5ef120cce2e04b24f0f897e49f022cb1946c97 --- /dev/null +++ b/tools/llama/build_dataset.py @@ -0,0 +1,169 @@ +import itertools +import os +import re +from collections import defaultdict +from functools import partial +from multiprocessing import Pool +from pathlib import Path + +import click +import numpy as np +from loguru import logger +from tqdm import tqdm + +from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData +from fish_speech.datasets.protos.text_data_stream import pack_pb_stream +from tools.file import load_filelist + +# To avoid CPU overload +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" + + +def task_generator_folder(root: Path, text_extension: str): + files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}")) + files = sorted(files) + + grouped_files = defaultdict(list) + for file in tqdm(files, desc=f"Grouping {root}"): + p = str(file.parent) + speaker = file.parent.name + + try: + if isinstance(text_extension, str): + texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")] + else: + texts = [ + file.with_suffix(ext).read_text(encoding="utf-8") + for ext in text_extension + ] + except Exception as e: + logger.error(f"Failed to read text {file}: {e}") + continue + + grouped_files[p].append((speaker, file, texts)) + + logger.info( + f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..." + ) + + for i in grouped_files.values(): + subset = [(f, t) for _, f, t in i] + yield i[0][0], subset, "folder" + + +def task_generator_filelist(filelist): + grouped_files = defaultdict(list) + for filename, speaker, _, text in load_filelist(filelist): + grouped_files[speaker].append((Path(filename), [text])) + + logger.info(f"Found {len(grouped_files)} groups in {filelist}") + for speaker, values in grouped_files.items(): + yield speaker, values, "filelist" + + +def run_task(task): + name, subset, source = task + + # Parse the files + sentences = [] + for file, texts in subset: + np_file = file.with_suffix(".npy") + if np_file.exists() is False: + logger.warning(f"Can't find {np_file}") + continue + + new_texts = [] + + for text in texts: + # Simple cleaning: replace { xxx } and < xxx > with space + text = re.sub(r"\{.*?\}", " ", text) + text = re.sub(r"<.*?>", " ", text) + text = re.sub(r"\s+", " ", text) + new_texts.append(text) + + try: + semantics = np.load(np_file) + except Exception as e: + logger.error(f"Failed to parse {file}: {e}") + continue + + if isinstance(semantics, np.ndarray): + semantics = semantics.tolist() + + sentences.append( + Sentence( + texts=new_texts, + semantics=[Semantics(values=s) for s in semantics], + ) + ) + + # Pack the sentences + return pack_pb_stream( + TextData( + source=source, + name=name, + sentences=sentences, + ) + ) + + +@click.command() +@click.option( + "--input", + type=click.Path(path_type=Path), + required=True, + help="A folder containing the dataset or a filelist", + multiple=True, +) +@click.option( + "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft" +) +@click.option("--num-workers", type=int, default=16) +@click.option("--text-extension", type=str, default=[".txt"], multiple=True) +@click.option( + "--shard-size", type=int, default=10, help="The maximum size of each shard in mb" +) +def main(input, output, num_workers, text_extension, shard_size): + generator_fns = [] + + for f in input: + assert f.exists(), f"{f} not found" + + if f.is_dir(): + generator_fn = task_generator_folder(f, text_extension) + else: + generator_fn = task_generator_filelist(f) + + generator_fns.append(generator_fn) + + generator_fn = itertools.chain(*generator_fns) + output.mkdir(parents=True, exist_ok=True) + + dataset_fp = None + tar_idx = 0 + written_size = 0 + + with Pool(num_workers) as p: + for result in tqdm(p.imap_unordered(run_task, generator_fn)): + if dataset_fp is None: + dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb") + + dataset_fp.write(result) + written_size += len(result) + + if written_size > shard_size * 1024 * 1024: + logger.info(f"Finished writing {tar_idx} shards to {output}") + dataset_fp.close() + dataset_fp = None + written_size = 0 + tar_idx += 1 + + if dataset_fp is not None: + dataset_fp.close() + + logger.info(f"Finished writing {tar_idx + 1} shards to {output}") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/eval_in_context.py b/tools/llama/eval_in_context.py new file mode 100644 index 0000000000000000000000000000000000000000..30d70940487388185381246d8210a49a58e55743 --- /dev/null +++ b/tools/llama/eval_in_context.py @@ -0,0 +1,171 @@ +import pyrootutils +import torch +import torch.nn.functional as F +from matplotlib import pyplot as plt +from transformers import AutoTokenizer + +# register eval resolver and root +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from torch.utils.data import DataLoader + +from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator +from tools.llama.generate import load_model + + +def smooth( + scalars: list[float], weight: float +) -> list[float]: # Weight between 0 and 1 + last = scalars[0] # First value in the plot (first timestep) + smoothed = list() + for point in scalars: + smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value + smoothed.append(smoothed_val) # Save it + last = smoothed_val # Anchor the last smoothed value + + return smoothed + + +@torch.inference_mode() +def analyze_one_model(loader, config, weight, max_length): + device = "cuda" if torch.cuda.is_available() else "cpu" + model = load_model( + config, + weight, + device, + torch.bfloat16, + max_length, + compile=False, + )[0] + + current_step = 0 + model.eval() + + semantic_loss_sum = torch.zeros( + max_length, + dtype=torch.float32, + device=device, + ) + counter = torch.zeros( + max_length, + dtype=torch.long, + device=device, + ) + + for batch in loader: + batch = {k: v.to(device) for k, v in batch.items()} + + labels = batch["labels"] + outputs = model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.reshape(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + reduction="none", + ) + + codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.reshape(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + reduction="none", + ) + + base_loss = base_loss.reshape(labels[:, 0].shape) + semantic_loss = semantic_loss.reshape(codebook_labels.shape) + + semantic_loss_frame = semantic_loss.mean(-1) + pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks + + for loss_sample, pad in zip(semantic_loss_frame, pad_pos): + semantic_loss_sum[~pad] += loss_sample[~pad] + counter[~pad] += 1 + + current_step += 1 + if current_step == 10: + break + + semantic_loss = semantic_loss.cpu() + counter = counter.cpu() + xs, ys = [], [] + + for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)): + if count > 0: + xs.append(i) + ys.append((loss / count).item()) # for better loss visualization + + smoothed_ys = smooth(ys, 0.95) + + # Unload model + del model + torch.cuda.empty_cache() + + return xs, ys, smoothed_ys + + +def main(): + tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1") + max_length = 4096 + + ds = AutoAugTextDataset( + ["data/protos/sft/云天河"], + tokenizer=tokenizer, + use_speaker=False, + interactive_prob=1.0, + max_length=max_length, + ) + + loader = DataLoader( + ds, + batch_size=8, + collate_fn=TextDataCollator(tokenizer, max_length=max_length), + num_workers=0, + shuffle=False, + ) + + plt.figure(figsize=(10, 5), dpi=200) + + plt.xlabel("Frame") + plt.ylabel("Loss") + plt.yscale("log") + plt.title("Semantic Loss") + plt.grid(which="both", axis="both") + plt.xlim(0, max_length) + + tests = [ + ( + "pertrain-medium", + "dual_ar_2_codebook_medium", + "checkpoints/text2semantic-pretrain-medium-2k-v1.pth", + ), + ( + "sft-medium", + "dual_ar_2_codebook_medium", + "checkpoints/text2semantic-sft-medium-v1.1-4k.pth", + ), + ( + "sft-large", + "dual_ar_2_codebook_large", + "checkpoints/text2semantic-sft-large-v1.1-4k.pth", + ), + ] + + for name, config, weight in tests: + xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length) + plt.plot(xs, smoothed_ys, label=name) + + plt.legend() + plt.savefig("semantic_loss.png") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/generate.py b/tools/llama/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..f4561a2ff0414d0f8f9a330e5a268e937ff6b74e --- /dev/null +++ b/tools/llama/generate.py @@ -0,0 +1,708 @@ +import os +import queue +import threading +import time +from contextlib import nullcontext +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Optional, Tuple, Union + +import click +import hydra +import numpy as np +import torch +import torch._dynamo.config +import torch._inductor.config +from loguru import logger +from tqdm import tqdm + +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.text import clean_text, split_text + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True + +if hasattr(torch._inductor.config, "fx_graph_cache"): + # Experimental feature to reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + + +from fish_speech.models.text2semantic.llama import ( + BaseTransformer, + DualARTransformer, + NaiveTransformer, +) + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs( + logits, + previous_tokens: Optional[torch.Tensor] = None, + temperature: torch.Tensor = 1.0, + top_p: torch.Tensor = 1.0, + repetition_penalty: torch.Tensor = 1.0, +) -> torch.Tensor: + # Apply repetition penalty + if previous_tokens is not None: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, score * repetition_penalty, score / repetition_penalty + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) + + # Apply top-p sampling + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample( + logits, + previous_tokens: Optional[torch.Tensor] = None, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + probs = logits_to_probs( + logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def decode_one_token_ar( + model: DualARTransformer, + x: torch.Tensor, + input_pos: torch.Tensor, + previous_tokens: torch.Tensor = None, + **sampling_kwargs, +) -> torch.Tensor: + x = model.forward_generate(x, input_pos) + + sampling_kwargs_main = sampling_kwargs.copy() + sampling_kwargs_main["temperature"] = 0.1 + sampling_kwargs_main["top_p"] = 0.1 + sampling_kwargs_main["repetition_penalty"] = 1.0 + + codebooks = [ + sample( + x.logits, + previous_tokens=None, # Disable repetition penalty for the token codebook + **sampling_kwargs_main, + )[0] + ] + + x = x.hidden_states + + # Cleanup the cache + for layer in model.fast_layers: + layer.attention.kv_cache.k_cache.fill_(0) + layer.attention.kv_cache.v_cache.fill_(0) + + for codebook_idx in range(model.config.num_codebooks): + input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long) + logits = model.forward_generate_fast(x, input_pos) + a = sample( + logits, + previous_tokens=( + previous_tokens[codebook_idx + 1] + if previous_tokens is not None + else None + ), + **sampling_kwargs, + )[0] + x = model.fast_embeddings(a) + codebooks.append(a) + + return torch.stack(codebooks, dim=0) + + +def decode_one_token_naive( + model: NaiveTransformer, + x: torch.Tensor, + input_pos: torch.Tensor, + previous_tokens: torch.Tensor = None, + **sampling_kwargs, +) -> torch.Tensor: + x = model.forward_generate(x, input_pos) + + sampling_kwargs_main = sampling_kwargs.copy() + sampling_kwargs_main["temperature"] = 0.1 + sampling_kwargs_main["top_p"] = 0.1 + sampling_kwargs_main["repetition_penalty"] = 1.0 + + codebooks = [ + sample( + x.logits, + previous_tokens=None, # Disable repetition penalty for the token codebook + **sampling_kwargs_main, + )[0] + ] + + for i in range(model.config.num_codebooks): + codebooks.append( + sample( + x.codebook_logits[:, :, i], + previous_tokens=( + previous_tokens[i + 1] if previous_tokens is not None else None + ), + **sampling_kwargs, + )[0] + ) + + return torch.stack(codebooks, dim=0) + + +def decode_n_tokens( + model: NaiveTransformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + im_end_id: int = 4, + decode_one_token=decode_one_token_naive, + **sampling_kwargs, +): + previous_tokens = torch.zeros( + (model.config.num_codebooks + 1, model.config.max_seq_len), + dtype=torch.int, + device=cur_token.device, + ) + + for i in tqdm(range(num_new_tokens)): + # We need to get windowed repeat penalty + win_size = 16 + if i < win_size: + window = previous_tokens[:, :win_size] + else: + window = previous_tokens[:, i - win_size : i] + + with ( + torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ) + if torch.cuda.is_available() + else nullcontext() + ): # Actually better for Inductor to codegen attention here + next_token = decode_one_token( + model=model, + x=cur_token, + input_pos=input_pos, + previous_tokens=window, + **sampling_kwargs, + ) + + input_pos += 1 + cur_token = next_token.view(1, model.config.num_codebooks + 1, -1) + previous_tokens[:, i : i + 1] = next_token.view( + model.config.num_codebooks + 1, -1 + ) + + if cur_token[0, 0, -1] == im_end_id: + break + + return previous_tokens[:, : i + 1] + + +@torch.no_grad() +@torch.inference_mode() +def generate( + *, + model: NaiveTransformer, + prompt: torch.Tensor, + max_new_tokens: int, + im_end_id: int = 4, + decode_one_token=decode_one_token_naive, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(1) + + device, dtype = prompt.device, prompt.dtype + + codebook_dim = 1 + model.config.num_codebooks + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty( + (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device + ) + empty[:, :T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + # Use non-accelerated version for now, to avoid compilation overhead + prefill_decode = ( + decode_one_token_naive + if isinstance(model, NaiveTransformer) + else decode_one_token_ar + ) + + next_token = prefill_decode( + model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs + ) + seq[:, T : T + 1] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + x = decode_n_tokens( + model, + next_token.view(1, codebook_dim, -1), + input_pos, + max_new_tokens - 1, + im_end_id=im_end_id, + decode_one_token=decode_one_token, + **sampling_kwargs, + ) + # x = torch.cat(generated_tokens, dim=1) + seq = seq[:, : T + 1 + x.size(1)] + seq[:, T + 1 :] = x + + return seq + + +def encode_tokens( + tokenizer, + string, + device="cuda", + prompt_tokens=None, + num_codebooks=4, +): + string = clean_text(string) + string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n" + + new_tokens = tokenizer.encode( + string, + add_special_tokens=False, + max_length=10**6, + truncation=False, + ) + tokens = torch.tensor([new_tokens], dtype=torch.int, device=device) + + # Codebooks + zeros = ( + torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device) + * CODEBOOK_PAD_TOKEN_ID + ) + prompt = torch.cat((tokens, zeros), dim=0) + + if prompt_tokens is None: + return prompt + + # Get prompt tokens + if prompt_tokens.ndim == 3: + assert ( + prompt_tokens.shape[0] == 1 + ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)" + prompt_tokens = prompt_tokens[0] + + assert prompt_tokens.ndim == 2 + data = prompt_tokens + 1 + + if prompt_tokens.shape[0] > num_codebooks: + logger.warning( + f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks" + ) + data = data[:num_codebooks] + + # Add pad token for each codebook + data = torch.cat( + (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)), + dim=1, + ) + + # Since 1.0, we use <|semantic|> + s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>") + end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + main_token_ids = ( + torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id + ) + main_token_ids[0, -1] = end_token_id + + data = torch.cat((main_token_ids, data), dim=0) + prompt = torch.cat((prompt, data), dim=1) + + return prompt + + +def load_model(checkpoint_path, device, precision, compile=False): + model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained( + checkpoint_path, load_weights=True + ) + + model = model.to(device=device, dtype=precision) + logger.info(f"Restored model from checkpoint") + + if isinstance(model, DualARTransformer): + decode_one_token = decode_one_token_ar + logger.info("Using DualARTransformer") + else: + decode_one_token = decode_one_token_naive + logger.info("Using NaiveTransformer") + + if compile: + logger.info("Compiling function...") + decode_one_token = torch.compile( + decode_one_token, + fullgraph=True, + backend="inductor" if torch.cuda.is_available() else "aot_eager", + mode="reduce-overhead" if torch.cuda.is_available() else None, + ) + + return model.eval(), decode_one_token + + +@dataclass +class GenerateResponse: + action: Literal["sample", "next"] + codes: Optional[torch.Tensor] = None + text: Optional[str] = None + + +def generate_long( + *, + model, + device: str | torch.device, + decode_one_token: callable, + text: str, + num_samples: int = 1, + max_new_tokens: int = 0, + top_p: int = 0.7, + repetition_penalty: float = 1.5, + temperature: float = 0.7, + compile: bool = False, + iterative_prompt: bool = True, + max_length: int = 2048, + chunk_length: int = 150, + prompt_text: Optional[str | list[str]] = None, + prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None, +): + assert 0 < top_p <= 1, "top_p must be in (0, 1]" + assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)" + assert 0 < temperature < 2, "temperature must be in (0, 2)" + + use_prompt = prompt_text is not None and prompt_tokens is not None + if use_prompt and isinstance(prompt_text, str): + prompt_text = [prompt_text] + prompt_tokens = [prompt_tokens] + + assert use_prompt is False or len(prompt_text) == len( + prompt_tokens + ), "Prompt text and tokens must have the same length" + + model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) + tokenizer = model.tokenizer + im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + + encoded = [] + texts = split_text(text, chunk_length) if iterative_prompt else [text] + encoded_prompts = [] + + if use_prompt: + for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)): + encoded_prompts.append( + encode_tokens( + tokenizer, + string=t, + device=device, + prompt_tokens=c, + num_codebooks=model.config.num_codebooks, + ) + ) + + for idx, text in enumerate(texts): + encoded.append( + encode_tokens( + tokenizer, + string=text, + device=device, + num_codebooks=model.config.num_codebooks, + ) + ) + logger.info(f"Encoded text: {text}") + + # Move temperature, top_p, repetition_penalty to device + # This is important so that changing params doesn't trigger recompile + temperature = torch.tensor(temperature, device=device, dtype=torch.float) + top_p = torch.tensor(top_p, device=device, dtype=torch.float) + repetition_penalty = torch.tensor( + repetition_penalty, device=device, dtype=torch.float + ) + + for sample_idx in range(num_samples): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + global_encoded = [] + seg_idx = 0 + + while seg_idx < len(encoded): + logger.info( + f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}" + ) + + seg = encoded[seg_idx] + global_encoded.append(seg) + + lengths = reversed([seg.size(1) for seg in global_encoded]) + + # Pick last 2000 tokens + count = 0 + for i, length in enumerate(lengths): + count += length + if count + length > max_length - 1024 - sum( + t.shape[1] for t in encoded_prompts + ): + break + + if i != 0 and i % 2 == 0: + i -= 1 + + # Rotate the list, always make sure first segment is included to avoid drift + if i < len(global_encoded) - 2: + partial_encoded = global_encoded[:2] + global_encoded[-i:] + else: + partial_encoded = global_encoded + + if use_prompt: + partial_encoded = encoded_prompts + partial_encoded + + cat_encoded = torch.cat(partial_encoded, dim=1) + prompt_length = cat_encoded.size(1) + + t0 = time.perf_counter() + y = generate( + model=model, + prompt=cat_encoded, + max_new_tokens=max_new_tokens, + im_end_id=im_end_id, + decode_one_token=decode_one_token, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + if sample_idx == 0 and seg_idx == 0 and compile: + logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + t = time.perf_counter() - t0 + + tokens_generated = y.size(1) - prompt_length + tokens_sec = tokens_generated / t + logger.info( + f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec" + ) + logger.info( + f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" + ) + + if torch.cuda.is_available(): + logger.info( + f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB" + ) + + # Put the generated tokens + # since there is and tokens, we remove last 2 tokens + codes = y[1:, prompt_length:-1].clone() + codes = codes - 1 + assert (codes >= 0).all(), f"Negative code found" + + decoded = y[:, prompt_length:-1].clone() + # But for global encoding, we should keep the token + + global_encoded.append(decoded) + assert (codes >= 0).all(), f"Negative code found: {codes}" + yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx]) + seg_idx += 1 + + # This indicates the end of the current sample + yield GenerateResponse(action="next") + + +@dataclass +class WrappedGenerateResponse: + status: Literal["success", "error"] + response: Optional[GenerateResponse | Exception] = None + + +@dataclass +class GenerateRequest: + request: dict + response_queue: queue.Queue + + +def launch_thread_safe_queue( + checkpoint_path, + device, + precision, + compile: bool = False, +): + input_queue = queue.Queue() + init_event = threading.Event() + + def worker(): + model, decode_one_token = load_model( + checkpoint_path, device, precision, compile=compile + ) + with torch.device(device): + model.setup_caches( + max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype + ) + init_event.set() + + while True: + item: GenerateRequest | None = input_queue.get() + if item is None: + break + + kwargs = item.request + response_queue = item.response_queue + + try: + for chunk in generate_long( + model=model, decode_one_token=decode_one_token, **kwargs + ): + response_queue.put( + WrappedGenerateResponse(status="success", response=chunk) + ) + except Exception as e: + response_queue.put(WrappedGenerateResponse(status="error", response=e)) + + threading.Thread(target=worker, daemon=True).start() + init_event.wait() + + return input_queue + + +@click.command() +@click.option( + "--text", + type=str, + default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.", +) +@click.option("--prompt-text", type=str, default=None, multiple=True) +@click.option( + "--prompt-tokens", + type=click.Path(path_type=Path, exists=True), + default=None, + multiple=True, +) +@click.option("--num-samples", type=int, default=1) +@click.option("--max-new-tokens", type=int, default=1024) +@click.option("--top-p", type=float, default=0.7) +@click.option("--repetition-penalty", type=float, default=1.2) +@click.option("--temperature", type=float, default=0.7) +@click.option( + "--checkpoint-path", + type=click.Path(path_type=Path, exists=True), + default="checkpoints/fish-speech-1.4", +) +@click.option("--device", type=str, default="cuda") +@click.option("--compile/--no-compile", default=False) +@click.option("--seed", type=int, default=42) +@click.option("--half/--no-half", default=False) +@click.option("--iterative-prompt/--no-iterative-prompt", default=True) +@click.option("--chunk-length", type=int, default=100) +def main( + text: str, + prompt_text: Optional[list[str]], + prompt_tokens: Optional[list[Path]], + num_samples: int, + max_new_tokens: int, + top_p: int, + repetition_penalty: float, + temperature: float, + checkpoint_path: Path, + device: str, + compile: bool, + seed: int, + half: bool, + iterative_prompt: bool, + chunk_length: int, +) -> None: + + precision = torch.half if half else torch.bfloat16 + + if prompt_text is not None and len(prompt_text) != len(prompt_tokens): + raise ValueError( + f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same" + ) + + logger.info("Loading model ...") + t0 = time.time() + model, decode_one_token = load_model( + checkpoint_path, device, precision, compile=compile + ) + with torch.device(device): + model.setup_caches( + max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + logger.info(f"Time to load model: {time.time() - t0:.02f} seconds") + + if prompt_tokens is not None: + prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens] + + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + generator = generate_long( + model=model, + device=device, + decode_one_token=decode_one_token, + text=text, + num_samples=num_samples, + max_new_tokens=max_new_tokens, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + compile=compile, + iterative_prompt=iterative_prompt, + chunk_length=chunk_length, + prompt_text=prompt_text, + prompt_tokens=prompt_tokens, + ) + + idx = 0 + codes = [] + + for response in generator: + if response.action == "sample": + codes.append(response.codes) + logger.info(f"Sampled text: {response.text}") + elif response.action == "next": + if codes: + np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy()) + logger.info(f"Saved codes to codes_{idx}.npy") + logger.info(f"Next sample") + codes = [] + idx += 1 + else: + logger.error(f"Error: {response}") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bd3cbd725c4eccbe78f711d9718dfb278a6aa7 --- /dev/null +++ b/tools/llama/merge_lora.py @@ -0,0 +1,95 @@ +import shutil +from copy import deepcopy +from pathlib import Path + +import click +import hydra +import torch +from hydra import compose, initialize +from hydra.utils import instantiate +from loguru import logger + +from fish_speech.models.text2semantic.llama import BaseTransformer +from fish_speech.models.text2semantic.lora import get_merged_state_dict + + +@click.command() +@click.option("--lora-config", type=str, default="r_8_alpha_16") +@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4") +@click.option("--lora-weight", type=str, required=True) +@click.option("--output", type=str, required=True) +def merge(lora_config, base_weight, lora_weight, output): + output = Path(output) + logger.info( + f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}" + ) + + with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"): + cfg = compose(config_name=lora_config) + + lora_config = instantiate(cfg) + logger.info(f"Loaded lora model with config {lora_config}") + + llama_model = BaseTransformer.from_pretrained( + path=base_weight, + load_weights=True, + lora_config=lora_config, + ) + logger.info(f"Loaded llama model") + + llama_state_dict = llama_model.state_dict() + llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k} + llama_state_dict_copy = deepcopy(llama_state_dict) + lora_state_dict = torch.load(lora_weight, map_location="cpu") + + if "state_dict" in llama_state_dict: + llama_state_dict = llama_state_dict["state_dict"] + + if "state_dict" in lora_state_dict: + lora_state_dict = lora_state_dict["state_dict"] + + # remove prefix model. + if any(k.startswith("model.") for k in llama_state_dict.keys()): + llama_state_dict = { + k.replace("model.", ""): v + for k, v in llama_state_dict.items() + if k.startswith("model.") + } + if any(k.startswith("model.") for k in lora_state_dict.keys()): + lora_state_dict = { + k.replace("model.", ""): v + for k, v in lora_state_dict.items() + if k.startswith("model.") + } + + logger.info(f"Found {len(llama_state_dict)} keys in llama model") + logger.info(f"Found {len(lora_state_dict)} keys in lora model") + + merged_state_dict = llama_state_dict | lora_state_dict + llama_model.load_state_dict(merged_state_dict, strict=True) + logger.info(f"Merged model loaded") + + # Trigger eval mode to merge lora + llama_model.eval() + llama_model.save_pretrained(output, drop_lora=True) + logger.info(f"Saved merged model to {output}, validating") + + new_state_dict = torch.load(output / "model.pth", map_location="cpu") + original_keys = set(llama_state_dict_copy.keys()) + merged_keys = set(new_state_dict.keys()) + + assert original_keys == merged_keys, "Keys should be same" + + for key in original_keys: + diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item() + if diff_l1 != 0: + break + else: + logger.error("Merged model is same as the original model") + exit(1) + + logger.info("Merged model is different from the original model, check passed") + + +if __name__ == "__main__": + merge() diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..e629d944b5d1e262f6c0517480980fcac01dad86 --- /dev/null +++ b/tools/llama/quantize.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +import datetime +import shutil + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import time +from pathlib import Path + +import click +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fish_speech.models.text2semantic.llama import find_multiple +from tools.llama.generate import load_model + +##### Quantization Primitives ###### + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + +def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) + ) + + return w_int32 + + +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) + ) + return w_dq + + +def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_scales_and_zeros(scales_and_zeros) + return group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit, groupsize + ) + + +class QuantHandler: + def __init__(self, mod): + self.mod = mod + + def create_quantized_state_dict(self) -> "StateDict": + pass + + def convert_for_runtime(self) -> "nn.Module": + pass + + +##### Weight-only int8 per-channel quantized code ###### + + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr( + module, + name, + WeightOnlyInt8Linear(child.in_features, child.out_features), + ) + else: + replace_linear_weight_only_int8_per_channel(child) + + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel( + mod.weight.float(), -128, 127, torch.int8 + ) + cur_state_dict[f"{fqn}.weight"] = int8_weight + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=torch.int8) + ) + self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + + +##### weight only int4 per channel groupwise quantized code ###### + + +def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + weight_int32, scales_and_zeros = group_quantize_tensor( + weight_bf16, n_bit=4, groupsize=groupsize + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weight_int32, inner_k_tiles + ) + return weight_int4pack, scales_and_zeros + + +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm( + x, weight_int4pack, groupsize, scales_and_zeros + ) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=False, + ), + ) + elif padding: + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=True, + ), + ) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding: + import torch.nn.functional as F + + print( + f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" + ) + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) + else: + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) + continue + ( + weight_int4pack, + scales_and_zeros, + ) = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16).to("cuda"), + self.groupsize, + self.inner_k_tiles, + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) + return self.mod + + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, + padding: bool = True, + ) -> None: + super().__init__() + self.padding = padding + if padding: + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales_and_zeros", + torch.empty( + (in_features // groupsize, out_features, 2), dtype=torch.bfloat16 + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + + +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + +@click.command() +@click.option( + "--checkpoint-path", + type=click.Path(path_type=Path, exists=True), + default="checkpoints/fish-speech-1.4", +) +@click.option( + "--mode", type=str, default="int8", help="type of quantization to perform" +) +@click.option( + "--groupsize", type=int, default=128, help="Group size for int4 quantization." +) +@click.option("--timestamp", type=str, default="None", help="When to do quantization") +def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None: + + device = "cpu" + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + + model, _ = load_model( + checkpoint_path=checkpoint_path, + device=device, + precision=precision, + compile=False, + ) + vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + now = timestamp if timestamp != "None" else generate_folder_name() + + if mode == "int8": + print( + "Quantizing model weights for int8 weight-only symmetric per-channel quantization" + ) + quant_handler = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path + dst_name = Path(f"checkpoints/fs-1.2-int8-{now}") + shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) + if (dst_name / vq_model).exists(): + (dst_name / vq_model).unlink() + quantize_path = dst_name / "model.pth" + + elif mode == "int4": + print( + "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization" + ) + quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path + dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}") + shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) + if (dst_name / vq_model).exists(): + (dst_name / vq_model).unlink() + quantize_path = dst_name / "model.pth" + + else: + raise ValueError( + f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]" + ) + + print(f"Writing quantized weights to {quantize_path}") + quantize_path.unlink(missing_ok=True) # remove existing file if one already there + torch.save(quantized_state_dict, quantize_path) + print(f"Quantization complete took {time.time() - t0:.02f} seconds") + + +if __name__ == "__main__": + quantize() diff --git a/tools/llama/rebuild_tokenizer.py b/tools/llama/rebuild_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea64fa6788833000c8dc41e3d570dd5b250fb14b --- /dev/null +++ b/tools/llama/rebuild_tokenizer.py @@ -0,0 +1,57 @@ +from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +# Initialize a tokenizer +tokenizer = Tokenizer(models.BPE()) + +# Customize pre-tokenization and decoding +tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) +tokenizer.decoder = decoders.ByteLevel() +tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) + +# Don't train the tokenizer +trainer = trainers.BpeTrainer( + vocab_size=0, + min_frequency=2, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + special_tokens=[ + "<|begin_of_sequence|>", + "<|end_of_sequence|>", + "<|im_start|>", + "<|im_sep|>", # system, user, assistant, etc. + "<|im_end|>", + "<|semantic|>", # audio features + "<|pad|>", + ], +) + +# <|im_start|>user<|im_sep|>...<|im_end|> +# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|> +tokenizer.train_from_iterator([], trainer=trainer) + +print(len(tokenizer.get_vocab())) +x = tokenizer.encode( + "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>" +).ids +print(x, len(x)) +print(tokenizer.decode(x, skip_special_tokens=True)) + + +tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + pad_token="<|pad|>", + bos_token="<|begin_of_sequence|>", + eos_token="<|end_of_sequence|>", +) + +# Try tokenizing a new sequence +sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>" +encoded = tokenizer(sequence).input_ids + +print("Test encoding....") +print(f"\tSentence: {sequence}") +print(f"\tEncoded: {encoded}") +print(f"\tDecoded: {tokenizer.batch_decode(encoded)}") +print(f"\tDecoded: {tokenizer.decode(encoded)}") + +tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True) diff --git a/tools/msgpack_api.py b/tools/msgpack_api.py new file mode 100644 index 0000000000000000000000000000000000000000..67f907bf55283f96f07d89b734403209290421c9 --- /dev/null +++ b/tools/msgpack_api.py @@ -0,0 +1,34 @@ +import httpx +import ormsgpack + +from tools.commons import ServeReferenceAudio, ServeTTSRequest + +# priority: ref_id > references +request = ServeTTSRequest( + text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.", + # reference_id="114514", + references=[ + ServeReferenceAudio( + audio=open("lengyue.wav", "rb").read(), + text=open("lengyue.lab", "r", encoding="utf-8").read(), + ) + ], + streaming=True, +) + +with ( + httpx.Client() as client, + open("hello.wav", "wb") as f, +): + with client.stream( + "POST", + "http://127.0.0.1:8080/v1/tts", + content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), + headers={ + "authorization": "Bearer YOUR_API_KEY", + "content-type": "application/msgpack", + }, + timeout=None, + ) as response: + for chunk in response.iter_bytes(): + f.write(chunk) diff --git a/tools/post_api.py b/tools/post_api.py new file mode 100644 index 0000000000000000000000000000000000000000..c20dc455c3ec5a6c69b879537c57cddb13495ce1 --- /dev/null +++ b/tools/post_api.py @@ -0,0 +1,205 @@ +import argparse +import base64 +import wave + +import ormsgpack +import pyaudio +import requests +from pydub import AudioSegment +from pydub.playback import play + +from tools.commons import ServeReferenceAudio, ServeTTSRequest +from tools.file import audio_to_bytes, read_ref_text + + +def parse_args(): + + parser = argparse.ArgumentParser( + description="Send a WAV file and text to a server and receive synthesized audio." + ) + + parser.add_argument( + "--url", + "-u", + type=str, + default="http://127.0.0.1:8080/v1/tts", + help="URL of the server", + ) + parser.add_argument( + "--text", "-t", type=str, required=True, help="Text to be synthesized" + ) + parser.add_argument( + "--reference_id", + "-id", + type=str, + default=None, + help="ID of the reference model o be used for the speech", + ) + parser.add_argument( + "--reference_audio", + "-ra", + type=str, + nargs="+", + default=None, + help="Path to the WAV file", + ) + parser.add_argument( + "--reference_text", + "-rt", + type=str, + nargs="+", + default=None, + help="Reference text for voice synthesis", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default="generated_audio", + help="Output audio file name", + ) + parser.add_argument( + "--play", + type=bool, + default=True, + help="Whether to play audio after receiving data", + ) + parser.add_argument("--normalize", type=bool, default=True) + parser.add_argument( + "--format", type=str, choices=["wav", "mp3", "flac"], default="wav" + ) + parser.add_argument("--mp3_bitrate", type=int, default=64) + parser.add_argument("--opus_bitrate", type=int, default=-1000) + parser.add_argument("--latency", type=str, default="normal", help="延迟选项") + parser.add_argument( + "--max_new_tokens", + type=int, + default=1024, + help="Maximum new tokens to generate", + ) + parser.add_argument( + "--chunk_length", type=int, default=100, help="Chunk length for synthesis" + ) + parser.add_argument( + "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis" + ) + parser.add_argument( + "--repetition_penalty", + type=float, + default=1.2, + help="Repetition penalty for synthesis", + ) + parser.add_argument( + "--temperature", type=float, default=0.7, help="Temperature for sampling" + ) + parser.add_argument( + "--speaker", type=str, default=None, help="Speaker ID for voice synthesis" + ) + parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion") + parser.add_argument( + "--streaming", type=bool, default=False, help="Enable streaming response" + ) + parser.add_argument( + "--channels", type=int, default=1, help="Number of audio channels" + ) + parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio") + + return parser.parse_args() + + +if __name__ == "__main__": + + args = parse_args() + + idstr: str | None = args.reference_id + # priority: ref_id > [{text, audio},...] + if idstr is None: + ref_audios = args.reference_audio + ref_texts = args.reference_text + if ref_audios is None: + byte_audios = [] + else: + byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios] + if ref_texts is None: + ref_texts = [] + else: + ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts] + else: + byte_audios = [] + ref_texts = [] + pass # in api.py + + data = { + "text": args.text, + "references": [ + ServeReferenceAudio(audio=ref_audio, text=ref_text) + for ref_text, ref_audio in zip(ref_texts, byte_audios) + ], + "reference_id": idstr, + "normalize": args.normalize, + "format": args.format, + "mp3_bitrate": args.mp3_bitrate, + "opus_bitrate": args.opus_bitrate, + "max_new_tokens": args.max_new_tokens, + "chunk_length": args.chunk_length, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "temperature": args.temperature, + "speaker": args.speaker, + "emotion": args.emotion, + "streaming": args.streaming, + } + + pydantic_data = ServeTTSRequest(**data) + + response = requests.post( + args.url, + data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), + stream=args.streaming, + headers={ + "authorization": "Bearer YOUR_API_KEY", + "content-type": "application/msgpack", + }, + ) + + if response.status_code == 200: + if args.streaming: + p = pyaudio.PyAudio() + audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format + stream = p.open( + format=audio_format, channels=args.channels, rate=args.rate, output=True + ) + + wf = wave.open(f"{args.output}.wav", "wb") + wf.setnchannels(args.channels) + wf.setsampwidth(p.get_sample_size(audio_format)) + wf.setframerate(args.rate) + + stream_stopped_flag = False + + try: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + stream.write(chunk) + wf.writeframesraw(chunk) + else: + if not stream_stopped_flag: + stream.stop_stream() + stream_stopped_flag = True + finally: + stream.close() + p.terminate() + wf.close() + else: + audio_content = response.content + audio_path = f"{args.output}.{args.format}" + with open(audio_path, "wb") as audio_file: + audio_file.write(audio_content) + + audio = AudioSegment.from_file(audio_path, format=args.format) + if args.play: + play(audio) + print(f"Audio has been saved to '{audio_path}'.") + else: + print(f"Request failed with status code {response.status_code}") + print(response.json()) diff --git a/tools/sensevoice/README.md b/tools/sensevoice/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9a2078aa2d96dfafb445384316f2041d9e819e63 --- /dev/null +++ b/tools/sensevoice/README.md @@ -0,0 +1,59 @@ +# FunASR Command Line Interface + +This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files. + +## Requirements + +- Python >= 3.10 +- PyTorch <= 2.3.1 +- ffmpeg, pydub, audio-separator[gpu]. + +## Installation + +Install the required packages: + +```bash +pip install -e .[stable] +``` + +Make sure you have `ffmpeg` installed and available in your `PATH`. + +## Usage + +### Basic Usage + +To run the tool with default settings: + +```bash +python tools/sensevoice/fun_asr.py --audio-dir --save-dir +``` + +## Options + +| Option | Description | +| :-----------------------: | :---------------------------------------------------------------------------: | +| --audio-dir | Directory containing audio or video files. | +| --save-dir | Directory to save processed audio files. | +| --device | Device to use for processing. Options: cuda (default) or cpu. | +| --language | Language of the transcription. Default is auto. | +| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. | +| --punc | Enable punctuation prediction. | +| --denoise | Enable noise reduction (vocal separation). | + +## Example + +To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled: + +```bash +python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise +``` + +## Additional Notes + +- The tool supports `both audio and video files`. Videos will be converted to audio automatically. +- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks. +- The script will automatically create necessary directories in the `--save-dir`. + +## Troubleshooting + +If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency. diff --git a/tools/sensevoice/__init__.py b/tools/sensevoice/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2e186617fe889500d01d95eccdafc5c0248b84 --- /dev/null +++ b/tools/sensevoice/auto_model.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import copy +import json +import logging +import os.path +import random +import re +import string +import time + +import numpy as np +import torch +from funasr.download.download_model_from_hub import download_model +from funasr.download.file import download_from_url +from funasr.register import tables +from funasr.train_utils.load_pretrained_model import load_pretrained_model +from funasr.train_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import export_utils, misc +from funasr.utils.load_utils import load_audio_text_image_video, load_bytes +from funasr.utils.misc import deep_update +from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en +from tqdm import tqdm + +from .vad_utils import merge_vad, slice_padding_audio_samples + +try: + from funasr.models.campplus.cluster_backend import ClusterBackend + from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk +except: + pass + + +def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): + """ """ + data_list = [] + key_list = [] + filelist = [".scp", ".txt", ".json", ".jsonl", ".text"] + + chars = string.ascii_letters + string.digits + if isinstance(data_in, str): + if data_in.startswith("http://") or data_in.startswith("https://"): # url + data_in = download_from_url(data_in) + + if isinstance(data_in, str) and os.path.exists( + data_in + ): # wav_path; filelist: wav.scp, file.jsonl;text.txt; + _, file_extension = os.path.splitext(data_in) + file_extension = file_extension.lower() + if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt; + with open(data_in, encoding="utf-8") as fin: + for line in fin: + key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + if data_in.endswith( + ".jsonl" + ): # file.jsonl: json.dumps({"source": data}) + lines = json.loads(line.strip()) + data = lines["source"] + key = data["key"] if "key" in data else key + else: # filelist, wav.scp, text.txt: id \t data or data + lines = line.strip().split(maxsplit=1) + data = lines[1] if len(lines) > 1 else lines[0] + key = lines[0] if len(lines) > 1 else key + + data_list.append(data) + key_list.append(key) + else: + if key is None: + # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + key = misc.extract_filename_without_extension(data_in) + data_list = [data_in] + key_list = [key] + elif isinstance(data_in, (list, tuple)): + if data_type is not None and isinstance( + data_type, (list, tuple) + ): # mutiple inputs + data_list_tmp = [] + for data_in_i, data_type_i in zip(data_in, data_type): + key_list, data_list_i = prepare_data_iterator( + data_in=data_in_i, data_type=data_type_i + ) + data_list_tmp.append(data_list_i) + data_list = [] + for item in zip(*data_list_tmp): + data_list.append(item) + else: + # [audio sample point, fbank, text] + data_list = data_in + key_list = [] + for data_i in data_in: + if isinstance(data_i, str) and os.path.exists(data_i): + key = misc.extract_filename_without_extension(data_i) + else: + if key is None: + key = "rand_key_" + "".join( + random.choice(chars) for _ in range(13) + ) + key_list.append(key) + + else: # raw text; audio sample point, fbank; bytes + if isinstance(data_in, bytes): # audio bytes + data_in = load_bytes(data_in) + if key is None: + key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + data_list = [data_in] + key_list = [key] + + return key_list, data_list + + +class AutoModel: + + def __init__(self, **kwargs): + + try: + from funasr.utils.version_checker import check_for_update + + print( + "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel" + ) + check_for_update(disable=kwargs.get("disable_update", False)) + except: + pass + + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + logging.basicConfig(level=log_level) + + model, kwargs = self.build_model(**kwargs) + + # if vad_model is not None, build vad model else None + vad_model = kwargs.get("vad_model", None) + vad_kwargs = ( + {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {}) + ) + if vad_model is not None: + logging.info("Building VAD model.") + vad_kwargs["model"] = vad_model + vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master") + vad_kwargs["device"] = kwargs["device"] + vad_model, vad_kwargs = self.build_model(**vad_kwargs) + + # if punc_model is not None, build punc model else None + punc_model = kwargs.get("punc_model", None) + punc_kwargs = ( + {} + if kwargs.get("punc_kwargs", {}) is None + else kwargs.get("punc_kwargs", {}) + ) + if punc_model is not None: + logging.info("Building punc model.") + punc_kwargs["model"] = punc_model + punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master") + punc_kwargs["device"] = kwargs["device"] + punc_model, punc_kwargs = self.build_model(**punc_kwargs) + + # if spk_model is not None, build spk model else None + spk_model = kwargs.get("spk_model", None) + spk_kwargs = ( + {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {}) + ) + if spk_model is not None: + logging.info("Building SPK model.") + spk_kwargs["model"] = spk_model + spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master") + spk_kwargs["device"] = kwargs["device"] + spk_model, spk_kwargs = self.build_model(**spk_kwargs) + self.cb_model = ClusterBackend().to(kwargs["device"]) + spk_mode = kwargs.get("spk_mode", "punc_segment") + if spk_mode not in ["default", "vad_segment", "punc_segment"]: + logging.error( + "spk_mode should be one of default, vad_segment and punc_segment." + ) + self.spk_mode = spk_mode + + self.kwargs = kwargs + self.model = model + self.vad_model = vad_model + self.vad_kwargs = vad_kwargs + self.punc_model = punc_model + self.punc_kwargs = punc_kwargs + self.spk_model = spk_model + self.spk_kwargs = spk_kwargs + self.model_path = kwargs.get("model_path") + + @staticmethod + def build_model(**kwargs): + assert "model" in kwargs + if "model_conf" not in kwargs: + logging.info( + "download models from model hub: {}".format(kwargs.get("hub", "ms")) + ) + kwargs = download_model(**kwargs) + + set_all_random_seed(kwargs.get("seed", 0)) + + device = kwargs.get("device", "cuda") + if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0: + device = "cpu" + kwargs["batch_size"] = 1 + kwargs["device"] = device + + torch.set_num_threads(kwargs.get("ncpu", 4)) + + # build tokenizer + tokenizer = kwargs.get("tokenizer", None) + if tokenizer is not None: + tokenizer_class = tables.tokenizer_classes.get(tokenizer) + tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {})) + kwargs["token_list"] = ( + tokenizer.token_list if hasattr(tokenizer, "token_list") else None + ) + kwargs["token_list"] = ( + tokenizer.get_vocab() + if hasattr(tokenizer, "get_vocab") + else kwargs["token_list"] + ) + vocab_size = ( + len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 + ) + if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): + vocab_size = tokenizer.get_vocab_size() + else: + vocab_size = -1 + kwargs["tokenizer"] = tokenizer + + # build frontend + frontend = kwargs.get("frontend", None) + kwargs["input_size"] = None + if frontend is not None: + frontend_class = tables.frontend_classes.get(frontend) + frontend = frontend_class(**kwargs.get("frontend_conf", {})) + kwargs["input_size"] = ( + frontend.output_size() if hasattr(frontend, "output_size") else None + ) + kwargs["frontend"] = frontend + # build model + model_class = tables.model_classes.get(kwargs["model"]) + assert model_class is not None, f'{kwargs["model"]} is not registered' + model_conf = {} + deep_update(model_conf, kwargs.get("model_conf", {})) + deep_update(model_conf, kwargs) + model = model_class(**model_conf, vocab_size=vocab_size) + + # init_param + init_param = kwargs.get("init_param", None) + if init_param is not None: + if os.path.exists(init_param): + logging.info(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=model, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + else: + print(f"error, init_param does not exist!: {init_param}") + + # fp16 + if kwargs.get("fp16", False): + model.to(torch.float16) + elif kwargs.get("bf16", False): + model.to(torch.bfloat16) + model.to(device) + + if not kwargs.get("disable_log", True): + tables.print() + + return model, kwargs + + def __call__(self, *args, **cfg): + kwargs = self.kwargs + deep_update(kwargs, cfg) + res = self.model(*args, kwargs) + return res + + def generate(self, input, input_len=None, **cfg): + if self.vad_model is None: + return self.inference(input, input_len=input_len, **cfg) + + else: + return self.inference_with_vad(input, input_len=input_len, **cfg) + + def inference( + self, input, input_len=None, model=None, kwargs=None, key=None, **cfg + ): + kwargs = self.kwargs if kwargs is None else kwargs + if "cache" in kwargs: + kwargs.pop("cache") + deep_update(kwargs, cfg) + model = self.model if model is None else model + model.eval() + + batch_size = kwargs.get("batch_size", 1) + # if kwargs.get("device", "cpu") == "cpu": + # batch_size = 1 + + key_list, data_list = prepare_data_iterator( + input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key + ) + + speed_stats = {} + asr_result_list = [] + num_samples = len(data_list) + disable_pbar = self.kwargs.get("disable_pbar", False) + pbar = ( + tqdm(colour="blue", total=num_samples, dynamic_ncols=True) + if not disable_pbar + else None + ) + time_speech_total = 0.0 + time_escape_total = 0.0 + for beg_idx in range(0, num_samples, batch_size): + end_idx = min(num_samples, beg_idx + batch_size) + data_batch = data_list[beg_idx:end_idx] + key_batch = key_list[beg_idx:end_idx] + batch = {"data_in": data_batch, "key": key_batch} + + if (end_idx - beg_idx) == 1 and kwargs.get( + "data_type", None + ) == "fbank": # fbank + batch["data_in"] = data_batch[0] + batch["data_lengths"] = input_len + + time1 = time.perf_counter() + with torch.no_grad(): + res = model.inference(**batch, **kwargs) + if isinstance(res, (list, tuple)): + results = res[0] if len(res) > 0 else [{"text": ""}] + meta_data = res[1] if len(res) > 1 else {} + time2 = time.perf_counter() + + asr_result_list.extend(results) + + # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item() + batch_data_time = meta_data.get("batch_data_time", -1) + time_escape = time2 - time1 + speed_stats["load_data"] = meta_data.get("load_data", 0.0) + speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0) + speed_stats["forward"] = f"{time_escape:0.3f}" + speed_stats["batch_size"] = f"{len(results)}" + speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}" + description = f"{speed_stats}, " + if pbar: + pbar.update(end_idx - beg_idx) + pbar.set_description(description) + time_speech_total += batch_data_time + time_escape_total += time_escape + + if pbar: + # pbar.update(1) + pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") + torch.cuda.empty_cache() + return asr_result_list + + def vad(self, input, input_len=None, **cfg): + kwargs = self.kwargs + # step.1: compute the vad model + deep_update(self.vad_kwargs, cfg) + beg_vad = time.time() + res = self.inference( + input, + input_len=input_len, + model=self.vad_model, + kwargs=self.vad_kwargs, + **cfg, + ) + end_vad = time.time() + # FIX(gcf): concat the vad clips for sense vocie model for better aed + if cfg.get("merge_vad", False): + for i in range(len(res)): + res[i]["value"] = merge_vad( + res[i]["value"], kwargs.get("merge_length_s", 15) * 1000 + ) + elapsed = end_vad - beg_vad + return elapsed, res + + def inference_with_vadres(self, input, vad_res, input_len=None, **cfg): + + kwargs = self.kwargs + + # step.2 compute asr model + model = self.model + deep_update(kwargs, cfg) + batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1) + batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000 + kwargs["batch_size"] = batch_size + + key_list, data_list = prepare_data_iterator( + input, input_len=input_len, data_type=kwargs.get("data_type", None) + ) + results_ret_list = [] + time_speech_total_all_samples = 1e-6 + + beg_total = time.time() + pbar_total = ( + tqdm(colour="red", total=len(vad_res), dynamic_ncols=True) + if not kwargs.get("disable_pbar", False) + else None + ) + + for i in range(len(vad_res)): + key = vad_res[i]["key"] + vadsegments = vad_res[i]["value"] + input_i = data_list[i] + fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000 + speech = load_audio_text_image_video( + input_i, fs=fs, audio_fs=kwargs.get("fs", 16000) + ) + speech_lengths = len(speech) + n = len(vadsegments) + data_with_index = [(vadsegments[i], i) for i in range(n)] + sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0]) + results_sorted = [] + + if not len(sorted_data): + results_ret_list.append({"key": key, "text": "", "timestamp": []}) + logging.info("decoding, utt: {}, empty speech".format(key)) + continue + + if len(sorted_data) > 0 and len(sorted_data[0]) > 0: + batch_size = max( + batch_size, sorted_data[0][0][1] - sorted_data[0][0][0] + ) + + if kwargs["device"] == "cpu": + batch_size = 0 + + beg_idx = 0 + beg_asr_total = time.time() + time_speech_total_per_sample = speech_lengths / 16000 + time_speech_total_all_samples += time_speech_total_per_sample + + # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True) + + all_segments = [] + max_len_in_batch = 0 + end_idx = 1 + + for j, _ in enumerate(range(0, n)): + # pbar_sample.update(1) + sample_length = sorted_data[j][0][1] - sorted_data[j][0][0] + potential_batch_length = max(max_len_in_batch, sample_length) * ( + j + 1 - beg_idx + ) + # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0] + if ( + j < n - 1 + and sample_length < batch_size_threshold_ms + and potential_batch_length < batch_size + ): + max_len_in_batch = max(max_len_in_batch, sample_length) + end_idx += 1 + continue + + speech_j, speech_lengths_j, intervals = slice_padding_audio_samples( + speech, speech_lengths, sorted_data[beg_idx:end_idx] + ) + results = self.inference( + speech_j, input_len=None, model=model, kwargs=kwargs, **cfg + ) + + for _b in range(len(speech_j)): + results[_b]["interval"] = intervals[_b] + + if self.spk_model is not None: + # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]] + for _b in range(len(speech_j)): + vad_segments = [ + [ + sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0, + sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0, + np.array(speech_j[_b]), + ] + ] + segments = sv_chunk(vad_segments) + all_segments.extend(segments) + speech_b = [i[2] for i in segments] + spk_res = self.inference( + speech_b, + input_len=None, + model=self.spk_model, + kwargs=kwargs, + **cfg, + ) + results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"] + + beg_idx = end_idx + end_idx += 1 + max_len_in_batch = sample_length + if len(results) < 1: + continue + results_sorted.extend(results) + + # end_asr_total = time.time() + # time_escape_total_per_sample = end_asr_total - beg_asr_total + # pbar_sample.update(1) + # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, " + # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}") + + restored_data = [0] * n + for j in range(n): + index = sorted_data[j][1] + cur = results_sorted[j] + pattern = r"<\|([^|]+)\|>" + emotion_string = re.findall(pattern, cur["text"]) + cur["text"] = re.sub(pattern, "", cur["text"]) + cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string]) + if self.punc_model is not None and len(cur["text"].strip()) > 0: + deep_update(self.punc_kwargs, cfg) + punc_res = self.inference( + cur["text"], + model=self.punc_model, + kwargs=self.punc_kwargs, + **cfg, + ) + cur["text"] = punc_res[0]["text"] + + restored_data[index] = cur + + end_asr_total = time.time() + time_escape_total_per_sample = end_asr_total - beg_asr_total + if pbar_total: + pbar_total.update(1) + pbar_total.set_description( + f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + f"time_speech: {time_speech_total_per_sample: 0.3f}, " + f"time_escape: {time_escape_total_per_sample:0.3f}" + ) + + # end_total = time.time() + # time_escape_total_all_samples = end_total - beg_total + # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, " + # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, " + # f"time_escape_all: {time_escape_total_all_samples:0.3f}") + return restored_data + + def export(self, input=None, **cfg): + """ + + :param input: + :param type: + :param quantize: + :param fallback_num: + :param calib_num: + :param opset_version: + :param cfg: + :return: + """ + + device = cfg.get("device", "cpu") + model = self.model.to(device=device) + kwargs = self.kwargs + deep_update(kwargs, cfg) + kwargs["device"] = device + del kwargs["model"] + model.eval() + + type = kwargs.get("type", "onnx") + + key_list, data_list = prepare_data_iterator( + input, input_len=None, data_type=kwargs.get("data_type", None), key=None + ) + + with torch.no_grad(): + export_dir = export_utils.export(model=model, data_in=data_list, **kwargs) + + return export_dir diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..6789316d5186db69c021758094649553c3638f66 --- /dev/null +++ b/tools/sensevoice/fun_asr.py @@ -0,0 +1,332 @@ +import gc +import os +import re + +from audio_separator.separator import Separator + +os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr" +os.environ["UVR5_CACHE"] = "./.cache/uvr5-models" +import json +import subprocess +from pathlib import Path + +import click +import torch +from loguru import logger +from pydub import AudioSegment +from silero_vad import get_speech_timestamps, load_silero_vad, read_audio +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files +from tools.sensevoice.auto_model import AutoModel + + +def uvr5_cli( + audio_dir: Path, + output_folder: Path, + audio_files: list[Path] | None = None, + output_format: str = "flac", + model: str = "BS-Roformer-Viperx-1297.ckpt", +): + # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"] + sepr = Separator( + model_file_dir=os.environ["UVR5_CACHE"], + output_dir=output_folder, + output_format=output_format, + ) + dictmodel = { + "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt", + "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt", + "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt", + "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt", + } + roformer_model = dictmodel[model] + sepr.load_model(roformer_model) + if audio_files is None: + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + total_files = len(audio_files) + + print(f"{total_files} audio files found") + + res = [] + for audio in tqdm(audio_files, desc="Denoising: "): + file_path = str(audio_dir / audio) + sep_out = sepr.separate(file_path) + if isinstance(sep_out, str): + res.append(sep_out) + elif isinstance(sep_out, list): + res.extend(sep_out) + del sepr + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return res, roformer_model + + +def get_sample_rate(media_path: Path): + result = subprocess.run( + [ + "ffprobe", + "-v", + "quiet", + "-print_format", + "json", + "-show_streams", + str(media_path), + ], + capture_output=True, + text=True, + check=True, + ) + media_info = json.loads(result.stdout) + for stream in media_info.get("streams", []): + if stream.get("codec_type") == "audio": + return stream.get("sample_rate") + return "44100" # Default sample rate if not found + + +def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"): + sr = get_sample_rate(src_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + if src_path.resolve() == out_path.resolve(): + output = str(out_path.with_stem(out_path.stem + f"_{sr}")) + else: + output = str(out_path) + subprocess.run( + [ + "ffmpeg", + "-loglevel", + "error", + "-i", + str(src_path), + "-acodec", + "pcm_s16le" if out_fmt == "wav" else "flac", + "-ar", + sr, + "-ac", + "1", + "-y", + output, + ], + check=True, + ) + return out_path + + +def convert_video_to_audio(video_path: Path, audio_dir: Path): + cur_dir = audio_dir / video_path.relative_to(audio_dir).parent + vocals = [ + p + for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*") + if p.suffix in AUDIO_EXTENSIONS + ] + if len(vocals) > 0: + return vocals[0] + audio_path = cur_dir / f"{video_path.stem}.wav" + convert_to_mono(video_path, audio_path) + return audio_path + + +@click.command() +@click.option("--audio-dir", required=True, help="Directory containing audio files") +@click.option( + "--save-dir", required=True, help="Directory to save processed audio files" +) +@click.option("--device", default="cuda", help="Device to use [cuda / cpu]") +@click.option("--language", default="auto", help="Language of the transcription") +@click.option( + "--max_single_segment_time", + default=20000, + type=int, + help="Maximum of Output single audio duration(ms)", +) +@click.option("--fsmn-vad/--silero-vad", default=False) +@click.option("--punc/--no-punc", default=False) +@click.option("--denoise/--no-denoise", default=False) +@click.option("--save_emo/--no_save_emo", default=False) +def main( + audio_dir: str, + save_dir: str, + device: str, + language: str, + max_single_segment_time: int, + fsmn_vad: bool, + punc: bool, + denoise: bool, + save_emo: bool, +): + + audios_path = Path(audio_dir) + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + video_files = list_files( + path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True + ) + v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files] + + if denoise: + VOCAL = "_(Vocals)" + original_files = [ + p + for p in audios_path.glob("**/*") + if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem + ] + + _, cur_model = uvr5_cli( + audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files + ) + need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")] + need_remove.extend(original_files) + for _ in need_remove: + _.unlink() + vocal_files = [ + p + for p in audios_path.glob("**/*") + if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem + ] + for f in vocal_files: + fn, ext = f.stem, f.suffix + + v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0]) + if v_pos != -1: + new_fn = fn[: v_pos + len(VOCAL)] + new_f = f.with_name(new_fn + ext) + f = f.rename(new_f) + convert_to_mono(f, f, "flac") + f.unlink() + + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + + logger.info("Loading / Downloading Funasr model...") + + model_dir = "iic/SenseVoiceSmall" + + vad_model = "fsmn-vad" if fsmn_vad else None + vad_kwargs = {"max_single_segment_time": max_single_segment_time} + punc_model = "ct-punc" if punc else None + + manager = AutoModel( + model=model_dir, + trust_remote_code=False, + vad_model=vad_model, + vad_kwargs=vad_kwargs, + punc_model=punc_model, + device=device, + ) + + if not fsmn_vad and vad_model is None: + vad_model = load_silero_vad() + + logger.info("Model loaded.") + + pattern = re.compile(r"_\d{3}\.") + + for file_path in tqdm(audio_files, desc="Processing audio file"): + + if pattern.search(file_path.name): + # logger.info(f"Skipping {file_path} as it has already been processed.") + continue + + file_stem = file_path.stem + file_suffix = file_path.suffix + + rel_path = Path(file_path).relative_to(audio_dir) + (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) + + audio = AudioSegment.from_file(file_path) + + cfg = dict( + cache={}, + language=language, # "zh", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + batch_size_s=60, + ) + + if fsmn_vad: + elapsed, vad_res = manager.vad(input=str(file_path), **cfg) + else: + wav = read_audio( + str(file_path) + ) # backend (sox, soundfile, or ffmpeg) required! + audio_key = file_path.stem + audio_val = [] + speech_timestamps = get_speech_timestamps( + wav, + vad_model, + max_speech_duration_s=max_single_segment_time // 1000, + return_seconds=True, + ) + + audio_val = [ + [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)] + for timestamp in speech_timestamps + ] + vad_res = [] + vad_res.append(dict(key=audio_key, value=audio_val)) + + res = manager.inference_with_vadres( + input=str(file_path), vad_res=vad_res, **cfg + ) + + for i, info in enumerate(res): + [start_ms, end_ms] = info["interval"] + text = info["text"] + emo = info["emo"] + sliced_audio = audio[start_ms:end_ms] + audio_save_path = ( + save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}" + ) + sliced_audio.export(audio_save_path, format=file_suffix[1:]) + print(f"Exported {audio_save_path}: {text}") + + transcript_save_path = ( + save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab" + ) + with open( + transcript_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(text) + + if save_emo: + emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo" + with open( + emo_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(emo) + + if audios_path.resolve() == save_path.resolve(): + file_path.unlink() + + +if __name__ == "__main__": + main() + exit(0) + from funasr.utils.postprocess_utils import rich_transcription_postprocess + + # Load the audio file + audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav") + model_dir = "iic/SenseVoiceSmall" + m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") + m.eval() + + res = m.inference( + data_in=f"{kwargs['model_path']}/example/zh.mp3", + language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + ban_emo_unk=False, + **kwargs, + ) + + print(res) + text = rich_transcription_postprocess(res[0][0]["text"]) + print(text) diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3bef75ed8c2841701fff44f7130e91ef8dfdf8cc --- /dev/null +++ b/tools/sensevoice/vad_utils.py @@ -0,0 +1,61 @@ +import torch +from torch.nn.utils.rnn import pad_sequence + + +def slice_padding_fbank(speech, speech_lengths, vad_segments): + speech_list = [] + speech_lengths_list = [] + for i, segment in enumerate(vad_segments): + + bed_idx = int(segment[0][0] * 16) + end_idx = min(int(segment[0][1] * 16), speech_lengths[0]) + speech_i = speech[0, bed_idx:end_idx] + speech_lengths_i = end_idx - bed_idx + speech_list.append(speech_i) + speech_lengths_list.append(speech_lengths_i) + feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0) + speech_lengths_pad = torch.Tensor(speech_lengths_list).int() + return feats_pad, speech_lengths_pad + + +def slice_padding_audio_samples(speech, speech_lengths, vad_segments): + speech_list = [] + speech_lengths_list = [] + intervals = [] + for i, segment in enumerate(vad_segments): + bed_idx = int(segment[0][0] * 16) + end_idx = min(int(segment[0][1] * 16), speech_lengths) + speech_i = speech[bed_idx:end_idx] + speech_lengths_i = end_idx - bed_idx + speech_list.append(speech_i) + speech_lengths_list.append(speech_lengths_i) + intervals.append([bed_idx // 16, end_idx // 16]) + + return speech_list, speech_lengths_list, intervals + + +def merge_vad(vad_result, max_length=15000, min_length=0): + new_result = [] + if len(vad_result) <= 1: + return vad_result + time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result] + time_step = sorted(list(set(time_step))) + if len(time_step) == 0: + return [] + bg = 0 + for i in range(len(time_step) - 1): + time = time_step[i] + if time_step[i + 1] - bg < max_length: + continue + if time - bg > min_length: + new_result.append([bg, time]) + # if time - bg < max_length * 1.5: + # new_result.append([bg, time]) + # else: + # split_num = int(time - bg) // max_length + 1 + # spl_l = int(time - bg) // split_num + # for j in range(split_num): + # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l]) + bg = time + new_result.append([bg, time_step[-1]]) + return new_result diff --git a/tools/smart_pad.py b/tools/smart_pad.py new file mode 100644 index 0000000000000000000000000000000000000000..de9dc154f26b2869a7e34f7d4cd95db741ee4c6a --- /dev/null +++ b/tools/smart_pad.py @@ -0,0 +1,60 @@ +import random +from multiprocessing import Pool +from pathlib import Path + +import click +import librosa +import torch.nn.functional as F +import torchaudio +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, list_files + +threshold = 10 ** (-50 / 20.0) + + +def process(file): + waveform, sample_rate = torchaudio.load(str(file), backend="sox") + if waveform.size(0) > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + loudness = librosa.feature.rms( + y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True + )[0] + + for i in range(len(loudness) - 1, 0, -1): + if loudness[i] > threshold: + break + + end_silent_time = (len(loudness) - i) * 512 / sample_rate + + if end_silent_time <= 0.3: + random_time = random.uniform(0.3, 0.7) - end_silent_time + waveform = F.pad( + waveform, (0, int(random_time * sample_rate)), mode="constant", value=0 + ) + + for i in range(len(loudness)): + if loudness[i] > threshold: + break + + start_silent_time = i * 512 / sample_rate + + if start_silent_time > 0.02: + waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :] + + torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate) + + +@click.command() +@click.argument("source", type=Path) +@click.option("--num-workers", type=int, default=12) +def main(source, num_workers): + files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True)) + + with Pool(num_workers) as p: + list(tqdm(p.imap_unordered(process, files), total=len(files))) + + +if __name__ == "__main__": + main() diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py new file mode 100644 index 0000000000000000000000000000000000000000..d24a5f39566c47ea0cb1fc506d463e9c95c3efbc --- /dev/null +++ b/tools/vqgan/create_train_split.py @@ -0,0 +1,83 @@ +import math +from pathlib import Path +from random import Random + +import click +from loguru import logger +from pydub import AudioSegment +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist + + +@click.command() +@click.argument("root", type=click.Path(exists=True, path_type=Path)) +@click.option("--val-ratio", type=float, default=None) +@click.option("--val-count", type=int, default=None) +@click.option("--filelist", default=None, type=Path) +@click.option("--min-duration", default=None, type=float) +@click.option("--max-duration", default=None, type=float) +def main(root, val_ratio, val_count, filelist, min_duration, max_duration): + if filelist: + files = [i[0] for i in load_filelist(filelist)] + else: + files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) + + if min_duration is None and max_duration is None: + filtered_files = list(map(str, [file.relative_to(root) for file in files])) + else: + filtered_files = [] + for file in tqdm(files): + try: + audio = AudioSegment.from_file(str(file)) + duration = len(audio) / 1000.0 + + if min_duration is not None and duration < min_duration: + logger.info( + f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}" + ) + continue + + if max_duration is not None and duration > max_duration: + logger.info( + f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}" + ) + continue + + filtered_files.append(str(file.relative_to(root))) + except Exception as e: + logger.info(f"Error processing {file}: {e}") + + logger.info( + f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering" + ) + + Random(42).shuffle(filtered_files) + + if val_count is None and val_ratio is None: + logger.info("Validation ratio and count not specified, using min(20%, 100)") + val_size = min(100, math.ceil(len(filtered_files) * 0.2)) + elif val_count is not None and val_ratio is not None: + logger.error("Cannot specify both val_count and val_ratio") + return + elif val_count is not None: + if val_count < 1 or val_count > len(filtered_files): + logger.error("val_count must be between 1 and number of files") + return + val_size = val_count + else: + val_size = math.ceil(len(filtered_files) * val_ratio) + + logger.info(f"Using {val_size} files for validation") + + with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f: + f.write("\n".join(filtered_files[val_size:])) + + with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f: + f.write("\n".join(filtered_files[:val_size])) + + logger.info("Done") + + +if __name__ == "__main__": + main() diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..c24eb3f46ab57fb02930f233a67299cb31c7d7ba --- /dev/null +++ b/tools/vqgan/extract_vq.py @@ -0,0 +1,227 @@ +import os +import subprocess as sp +import sys +import time +from datetime import timedelta +from functools import lru_cache +from pathlib import Path +from random import Random + +import click +import numpy as np +import torch +import torchaudio +from hydra import compose, initialize +from hydra.utils import instantiate +from lightning import LightningModule +from loguru import logger +from omegaconf import OmegaConf + +from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) +# This file is used to convert the audio files to text files using the Whisper model. +# It's mainly used to generate the training data for the VQ model. + + +RANK = int(os.environ.get("SLURM_PROCID", 0)) +WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1)) + +logger_format = ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{name}:{function}:{line} | " + "{extra[rank]} - {message}" +) +logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"}) +logger.remove() +logger.add(sys.stderr, format=logger_format) + + +@lru_cache(maxsize=1) +def get_model( + config_name: str = "firefly_gan_vq", + checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + device: str | torch.device = "cuda", +): + with initialize(version_base="1.3", config_path="../../fish_speech/configs"): + cfg = compose(config_name=config_name) + + model = instantiate(cfg) + state_dict = torch.load( + checkpoint_path, + map_location=device, + ) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + if any("generator" in k for k in state_dict): + state_dict = { + k.replace("generator.", ""): v + for k, v in state_dict.items() + if "generator." in k + } + + model.load_state_dict(state_dict, strict=False) + model.eval() + model.to(device) + + logger.info(f"Loaded model") + return model + + +@torch.inference_mode() +def process_batch(files: list[Path], model) -> float: + wavs = [] + audio_lengths = [] + new_files = [] + max_length = total_time = 0 + + for file in files: + try: + wav, sr = torchaudio.load( + str(file), backend="sox" if sys.platform == "linux" else "soundfile" + ) # Need to install libsox-dev + except Exception as e: + logger.error(f"Error reading {file}: {e}") + continue + + if wav.shape[0] > 1: + wav = wav.mean(dim=0, keepdim=True) + + wav = torchaudio.functional.resample( + wav.cuda(), sr, model.spec_transform.sample_rate + )[0] + total_time += len(wav) / model.spec_transform.sample_rate + max_length = max(max_length, len(wav)) + + wavs.append(wav) + audio_lengths.append(len(wav)) + new_files.append(file) + + files = new_files + + # Pad to max length + for i, wav in enumerate(wavs): + wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant") + + audios = torch.stack(wavs, dim=0)[:, None] + audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long) + + # Calculate lengths + indices, feature_lengths = model.encode(audios, audio_lengths) + + # Save to disk + outputs = indices.cpu().numpy() + + for file, length, feature, audio_length in zip( + files, feature_lengths, outputs, audio_lengths + ): + feature = feature[:, :length] + + # (T,) + with open(file.with_suffix(".npy"), "wb") as f: + np.save(f, feature) + + return total_time + + +@click.command() +@click.argument("folder") +@click.option("--num-workers", default=1) +@click.option("--config-name", default="firefly_gan_vq") +@click.option( + "--checkpoint-path", + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +) +@click.option("--batch-size", default=64) +@click.option("--filelist", default=None, type=Path) +def main( + folder: str, + num_workers: int, + config_name: str, + checkpoint_path: str, + batch_size: int, + filelist: Path, +): + if num_workers > 1 and WORLD_SIZE != num_workers: + assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both" + + logger.info(f"Spawning {num_workers} workers") + + if torch.cuda.is_available(): + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if visible_devices is None: + visible_devices = list(range(torch.cuda.device_count())) + else: + visible_devices = visible_devices.split(",") + else: + # Set to empty string to avoid using GPU + visible_devices = [""] + + processes = [] + for i in range(num_workers): + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)]) + env["SLURM_PROCID"] = str(i) + env["SLURM_NTASKS"] = str(num_workers) + + processes.append( + sp.Popen( + [sys.executable] + sys.argv.copy(), + env=env, + ) + ) + + for p in processes: + p.wait() + + logger.info(f"All workers finished") + return + + # This is a worker + logger.info(f"Starting worker") + if filelist: + files = [i[0] for i in load_filelist(filelist)] + else: + files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False) + + print(f"Found {len(files)} files") + files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()] + + total_files = len(files) + files = files[RANK::WORLD_SIZE] + logger.info(f"Processing {len(files)}/{total_files} files") + + # Batch processing + total_time = 0 + begin_time = time.time() + processed_files = 0 + model = get_model(config_name, checkpoint_path) + + for n_batch, idx in enumerate(range(0, len(files), batch_size)): + batch = files[idx : idx + batch_size] + batch_time = process_batch(batch, model) + + total_time += batch_time + processed_files += len(batch) + + if (n_batch + 1) % 10 == 0: + eta = ( + (time.time() - begin_time) + / processed_files + * (len(files) - processed_files) + ) + logger.info( + f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, " + + f"ETA: {timedelta(seconds=round(eta))}s" + ) + + logger.info( + f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio" + ) + + +if __name__ == "__main__": + main() diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b6bc7531c41455c346109bdaaa43dafc1e3508a4 --- /dev/null +++ b/tools/vqgan/inference.py @@ -0,0 +1,122 @@ +from pathlib import Path + +import click +import hydra +import numpy as np +import soundfile as sf +import torch +import torchaudio +from hydra import compose, initialize +from hydra.utils import instantiate +from loguru import logger +from omegaconf import OmegaConf + +from tools.file import AUDIO_EXTENSIONS + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) + + +def load_model(config_name, checkpoint_path, device="cuda"): + hydra.core.global_hydra.GlobalHydra.instance().clear() + with initialize(version_base="1.3", config_path="../../fish_speech/configs"): + cfg = compose(config_name=config_name) + + model = instantiate(cfg) + state_dict = torch.load( + checkpoint_path, + map_location=device, + ) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + if any("generator" in k for k in state_dict): + state_dict = { + k.replace("generator.", ""): v + for k, v in state_dict.items() + if "generator." in k + } + + result = model.load_state_dict(state_dict, strict=False) + model.eval() + model.to(device) + + logger.info(f"Loaded model: {result}") + return model + + +@torch.no_grad() +@click.command() +@click.option( + "--input-path", + "-i", + default="test.wav", + type=click.Path(exists=True, path_type=Path), +) +@click.option( + "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path) +) +@click.option("--config-name", default="firefly_gan_vq") +@click.option( + "--checkpoint-path", + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +) +@click.option( + "--device", + "-d", + default="cuda", +) +def main(input_path, output_path, config_name, checkpoint_path, device): + model = load_model(config_name, checkpoint_path, device=device) + + if input_path.suffix in AUDIO_EXTENSIONS: + logger.info(f"Processing in-place reconstruction of {input_path}") + + # Load audio + audio, sr = torchaudio.load(str(input_path)) + if audio.shape[0] > 1: + audio = audio.mean(0, keepdim=True) + audio = torchaudio.functional.resample( + audio, sr, model.spec_transform.sample_rate + ) + + audios = audio[None].to(device) + logger.info( + f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds" + ) + + # VQ Encoder + audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long) + indices = model.encode(audios, audio_lengths)[0][0] + + logger.info(f"Generated indices of shape {indices.shape}") + + # Save indices + np.save(output_path.with_suffix(".npy"), indices.cpu().numpy()) + elif input_path.suffix == ".npy": + logger.info(f"Processing precomputed indices from {input_path}") + indices = np.load(input_path) + indices = torch.from_numpy(indices).to(device).long() + assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}" + else: + raise ValueError(f"Unknown input type: {input_path}") + + # Restore + feature_lengths = torch.tensor([indices.shape[1]], device=device) + fake_audios, _ = model.decode( + indices=indices[None], feature_lengths=feature_lengths + ) + audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate + + logger.info( + f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}" + ) + + # Save audio + fake_audio = fake_audios[0, 0].float().cpu().numpy() + sf.write(output_path, fake_audio, model.spec_transform.sample_rate) + logger.info(f"Saved audio to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/webui.py b/tools/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..cff155d48967b4d3980e280096cafc511009a737 --- /dev/null +++ b/tools/webui.py @@ -0,0 +1,485 @@ +import gc +import html +import io +import os +import queue +import wave +from argparse import ArgumentParser +from functools import partial +from pathlib import Path + +import gradio as gr +import librosa +import numpy as np +import pyrootutils +import torch +from loguru import logger +from transformers import AutoTokenizer + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + + +from fish_speech.i18n import i18n +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText +from fish_speech.utils import autocast_exclude_mps +from tools.api import decode_vq_tokens, encode_reference +from tools.llama.generate import ( + GenerateRequest, + GenerateResponse, + WrappedGenerateResponse, + launch_thread_safe_queue, +) +from tools.vqgan.inference import load_model as load_decoder_model + +# Make einx happy +os.environ["EINX_FILTER_TRACEBACK"] = "false" + + +HEADER_MD = f"""# Fish Speech + +{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")} + +{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")} + +{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")} + +{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")} +""" + +TEXTBOX_PLACEHOLDER = i18n("Put your text here.") +SPACE_IMPORTED = False + + +def build_html_error_message(error): + return f""" +
+ {html.escape(str(error))} +
+ """ + + +@torch.inference_mode() +def inference( + text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + streaming=False, +): + if args.max_gradio_length > 0 and len(text) > args.max_gradio_length: + return ( + None, + None, + i18n("Text is too long, please keep it under {} characters.").format( + args.max_gradio_length + ), + ) + + # Parse reference audio aka prompt + prompt_tokens = encode_reference( + decoder_model=decoder_model, + reference_audio=reference_audio, + enable_reference_audio=enable_reference_audio, + ) + + # LLAMA Inference + request = dict( + device=decoder_model.device, + max_new_tokens=max_new_tokens, + text=text, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + compile=args.compile, + iterative_prompt=chunk_length > 0, + chunk_length=chunk_length, + max_length=2048, + prompt_tokens=prompt_tokens if enable_reference_audio else None, + prompt_text=reference_text if enable_reference_audio else None, + ) + + response_queue = queue.Queue() + llama_queue.put( + GenerateRequest( + request=request, + response_queue=response_queue, + ) + ) + + if streaming: + yield wav_chunk_header(), None, None + + segments = [] + + while True: + result: WrappedGenerateResponse = response_queue.get() + if result.status == "error": + yield None, None, build_html_error_message(result.response) + break + + result: GenerateResponse = result.response + if result.action == "next": + break + + with autocast_exclude_mps( + device_type=decoder_model.device.type, dtype=args.precision + ): + fake_audios = decode_vq_tokens( + decoder_model=decoder_model, + codes=result.codes, + ) + + fake_audios = fake_audios.float().cpu().numpy() + segments.append(fake_audios) + + if streaming: + yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None + + if len(segments) == 0: + return ( + None, + None, + build_html_error_message( + i18n("No audio generated, please check the input text.") + ), + ) + + # No matter streaming or not, we need to return the final audio + audio = np.concatenate(segments, axis=0) + yield None, (decoder_model.spec_transform.sample_rate, audio), None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + +inference_stream = partial(inference, streaming=True) + +n_audios = 4 + +global_audio_list = [] +global_error_list = [] + + +def inference_wrapper( + text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + batch_infer_num, +): + audios = [] + errors = [] + + for _ in range(batch_infer_num): + result = inference( + text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + ) + + _, audio_data, error_message = next(result) + + audios.append( + gr.Audio(value=audio_data if audio_data else None, visible=True), + ) + errors.append( + gr.HTML(value=error_message if error_message else None, visible=True), + ) + + for _ in range(batch_infer_num, n_audios): + audios.append( + gr.Audio(value=None, visible=False), + ) + errors.append( + gr.HTML(value=None, visible=False), + ) + + return None, *audios, *errors + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +def normalize_text(user_input, use_normalization): + if use_normalization: + return ChnNormedText(raw_text=user_input).normalize() + else: + return user_input + + +asr_model = None + + +def build_app(): + with gr.Blocks(theme=gr.themes.Base()) as app: + gr.Markdown(HEADER_MD) + + # Use light theme by default + app.load( + None, + None, + js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}" + % args.theme, + ) + + # Inference + with gr.Row(): + with gr.Column(scale=3): + text = gr.Textbox( + label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10 + ) + refined_text = gr.Textbox( + label=i18n("Realtime Transform Text"), + placeholder=i18n( + "Normalization Result Preview (Currently Only Chinese)" + ), + lines=5, + interactive=False, + ) + + with gr.Row(): + if_refine_text = gr.Checkbox( + label=i18n("Text Normalization"), + value=False, + scale=1, + ) + + with gr.Row(): + with gr.Tab(label=i18n("Advanced Config")): + chunk_length = gr.Slider( + label=i18n("Iterative Prompt Length, 0 means off"), + minimum=50, + maximum=300, + value=200, + step=8, + ) + + max_new_tokens = gr.Slider( + label=i18n("Maximum tokens per batch, 0 means no limit"), + minimum=0, + maximum=2048, + value=1024, # 0 means no limit + step=8, + ) + + top_p = gr.Slider( + label="Top-P", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + + repetition_penalty = gr.Slider( + label=i18n("Repetition Penalty"), + minimum=1, + maximum=1.5, + value=1.2, + step=0.01, + ) + + temperature = gr.Slider( + label="Temperature", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + + with gr.Tab(label=i18n("Reference Audio")): + gr.Markdown( + i18n( + "5 to 10 seconds of reference audio, useful for specifying speaker." + ) + ) + + enable_reference_audio = gr.Checkbox( + label=i18n("Enable Reference Audio"), + ) + reference_audio = gr.Audio( + label=i18n("Reference Audio"), + type="filepath", + ) + with gr.Row(): + reference_text = gr.Textbox( + label=i18n("Reference Text"), + lines=1, + placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。", + value="", + ) + with gr.Tab(label=i18n("Batch Inference")): + batch_infer_num = gr.Slider( + label="Batch infer nums", + minimum=1, + maximum=n_audios, + step=1, + value=1, + ) + + with gr.Column(scale=3): + for _ in range(n_audios): + with gr.Row(): + error = gr.HTML( + label=i18n("Error Message"), + visible=True if _ == 0 else False, + ) + global_error_list.append(error) + with gr.Row(): + audio = gr.Audio( + label=i18n("Generated Audio"), + type="numpy", + interactive=False, + visible=True if _ == 0 else False, + ) + global_audio_list.append(audio) + + with gr.Row(): + stream_audio = gr.Audio( + label=i18n("Streaming Audio"), + streaming=True, + autoplay=True, + interactive=False, + show_download_button=True, + ) + with gr.Row(): + with gr.Column(scale=3): + generate = gr.Button( + value="\U0001F3A7 " + i18n("Generate"), variant="primary" + ) + generate_stream = gr.Button( + value="\U0001F3A7 " + i18n("Streaming Generate"), + variant="primary", + ) + + text.input( + fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text] + ) + + # # Submit + generate.click( + inference_wrapper, + [ + refined_text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + batch_infer_num, + ], + [stream_audio, *global_audio_list, *global_error_list], + concurrency_limit=1, + ) + + generate_stream.click( + inference_stream, + [ + refined_text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + ], + [stream_audio, global_audio_list[0], global_error_list[0]], + concurrency_limit=10, + ) + return app + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--llama-checkpoint-path", + type=Path, + default="checkpoints/fish-speech-1.4", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=Path, + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-gradio-length", type=int, default=0) + parser.add_argument("--theme", type=str, default="light") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + args.precision = torch.half if args.half else torch.bfloat16 + + logger.info("Loading Llama model...") + llama_queue = launch_thread_safe_queue( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + logger.info("Llama model loaded, loading VQ-GAN model...") + + decoder_model = load_decoder_model( + config_name=args.decoder_config_name, + checkpoint_path=args.decoder_checkpoint_path, + device=args.device, + ) + + logger.info("Decoder model loaded, warming up...") + + # Dry run to check if the model is loaded correctly and avoid the first-time latency + list( + inference( + text="Hello, world!", + enable_reference_audio=False, + reference_audio=None, + reference_text="", + max_new_tokens=1024, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.2, + temperature=0.7, + ) + ) + + logger.info("Warming up done, launching the web UI...") + + app = build_app() + app.launch(show_api=True) diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..42e7de8a185880d3f2afd368d6df3429488465a4 --- /dev/null +++ b/tools/whisper_asr.py @@ -0,0 +1,176 @@ +""" +Used to transcribe all audio files in one folder into another folder. +e.g. +Directory structure: +--pre_data_root +----SP_1 +------01.wav +------02.wav +------...... +----SP_2 +------01.wav +------02.wav +------...... +Use +python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1 +to transcribe the first speaker. + +Use +python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2 +to transcribe the second speaker. + +Note: Be aware of your audio sample rate, which defaults to 44.1kHz. +""" + +import re +from pathlib import Path + +import click +import soundfile as sf +from faster_whisper import WhisperModel +from loguru import logger +from pydub import AudioSegment +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, list_files + + +@click.command() +@click.option("--model-size", default="large-v3", help="Size of the Whisper model") +@click.option( + "--compute-type", + default="float16", + help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]", +) +@click.option("--audio-dir", required=True, help="Directory containing audio files") +@click.option( + "--save-dir", required=True, help="Directory to save processed audio files" +) +@click.option( + "--sample-rate", + default=44100, + type=int, + help="Output sample rate, default to input sample rate", +) +@click.option("--device", default="cuda", help="Device to use [cuda / cpu]") +@click.option("--language", default="auto", help="Language of the transcription") +@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing") +def main( + model_size, + compute_type, + audio_dir, + save_dir, + sample_rate, + device, + language, + initial_prompt, +): + logger.info("Loading / Downloading Faster Whisper model...") + + model = WhisperModel( + model_size, + device=device, + compute_type=compute_type, + download_root="faster_whisper", + ) + + logger.info("Model loaded.") + + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + + for file_path in tqdm(audio_files, desc="Processing audio file"): + file_stem = file_path.stem + file_suffix = file_path.suffix + + rel_path = Path(file_path).relative_to(audio_dir) + (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) + + audio = AudioSegment.from_file(file_path) + + segments, info = model.transcribe( + file_path, + beam_size=5, + language=None if language == "auto" else language, + initial_prompt=initial_prompt, + ) + + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + print("Total len(ms): ", len(audio)) + + whole_text = None + for segment in segments: + id, start, end, text = ( + segment.id, + segment.start, + segment.end, + segment.text, + ) + print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text)) + if not whole_text: + whole_text = text + else: + whole_text += ", " + text + + whole_text += "." + + audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}" + audio.export(audio_save_path, format=file_suffix[1:]) + print(f"Exported {audio_save_path}") + + transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab" + with open( + transcript_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(whole_text) + + +if __name__ == "__main__": + main() + exit(0) + + audio = AudioSegment.from_wav( + r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav" + ) + + model_size = "large-v3" + + model = WhisperModel( + model_size, + device="cuda", + compute_type="float16", + download_root="faster_whisper", + ) + + segments, info = model.transcribe( + r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav", + beam_size=5, + ) + + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + print("Total len(ms): ", len(audio)) + + for i, segment in enumerate(segments): + print( + "Segment %03d [%.2fs -> %.2fs] %s" + % (i, segment.start, segment.end, segment.text) + ) + start_ms = int(segment.start * 1000) + end_ms = int(segment.end * 1000) + segment_audio = audio[start_ms:end_ms] + segment_audio.export(f"segment_{i:03d}.wav", format="wav") + print(f"Exported segment_{i:03d}.wav") + + print("All segments have been exported.")