diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..23e927f618477bf5819e47366f40bc4ad6a47a59 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,26 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +docs/example_crop.gif filter=lfs diff=lfs merge=lfs -text +docs/example_crop_still.gif filter=lfs diff=lfs merge=lfs -text +docs/example_full.gif filter=lfs diff=lfs merge=lfs -text +docs/example_full_enhanced.gif filter=lfs diff=lfs merge=lfs -text +docs/free_view_result.gif filter=lfs diff=lfs merge=lfs -text +docs/resize_good.gif filter=lfs diff=lfs merge=lfs -text +docs/resize_no.gif filter=lfs diff=lfs merge=lfs -text +docs/using_ref_video.gif filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/chinese_news.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/deyu.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/eluosi.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/fayu.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/imagine.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/japanese.wav filter=lfs diff=lfs merge=lfs -text +examples/ref_video/WDA_AlexandriaOcasioCortez_000.mp4 filter=lfs diff=lfs merge=lfs -text +examples/ref_video/WDA_KatieHill_000.mp4 filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_16.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_17.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_3.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_4.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_5.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_8.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_9.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..73c66b605152ba0e143900f8140e24655358d940 --- /dev/null +++ b/.gitignore @@ -0,0 +1,174 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +examples/results/* +gfpgan/* +checkpoints/* +assets/* +results/* +Dockerfile +start_docker.sh +start.sh + +checkpoints + +# Mac +.DS_Store diff --git a/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..363fcab7ed6e9634e198cf5555ceb88932c9a245 --- /dev/null +++ b/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/.ipynb_checkpoints/app-checkpoint.py b/.ipynb_checkpoints/app-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..926031c37dc21f6d341da9f050e35f39c87eddf0 --- /dev/null +++ b/.ipynb_checkpoints/app-checkpoint.py @@ -0,0 +1,608 @@ +import os, sys +import tempfile +import gradio as gr +from src.gradio_demo import SadTalker +# from src.utils.text2speech import TTSTalker +from huggingface_hub import snapshot_download + +import torch +import librosa +from scipy.io.wavfile import write +from transformers import WavLMModel + +import utils +from models import SynthesizerTrn +from mel_processing import mel_spectrogram_torch +from speaker_encoder.voice_encoder import SpeakerEncoder + +import time +from textwrap import dedent + +import mdtex2html +from loguru import logger +from transformers import AutoModel, AutoTokenizer + +from tts_voice import tts_order_voice +import edge_tts +import tempfile +import anyio + + +def get_source_image(image): + return image + +try: + import webui # in webui + in_webui = True +except: + in_webui = False + + +def toggle_audio_file(choice): + if choice == False: + return gr.update(visible=True), gr.update(visible=False) + else: + return gr.update(visible=False), gr.update(visible=True) + +def ref_video_fn(path_of_ref_video): + if path_of_ref_video is not None: + return gr.update(value=True) + else: + return gr.update(value=False) + +def download_model(): + REPO_ID = 'vinthony/SadTalker-V002rc' + snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True) + +def sadtalker_demo(): + + download_model() + + sad_talker = SadTalker(lazy_load=True) + # tts_talker = TTSTalker() + +download_model() +sad_talker = SadTalker(lazy_load=True) + + +# ChatGLM2 & FreeVC + +''' +def get_wavlm(): + os.system('gdown https://drive.google.com/uc?id=12-cB34qCTvByWT-QtOcZaqwwO21FLSqU') + shutil.move('WavLM-Large.pt', 'wavlm') +''' + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt') + +print("Loading FreeVC(24k)...") +hps = utils.get_hparams_from_file("configs/freevc-24.json") +freevc_24 = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model).to(device) +_ = freevc_24.eval() +_ = utils.load_checkpoint("checkpoint/freevc-24.pth", freevc_24, None) + +print("Loading WavLM for content...") +cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device) + +def convert(model, src, tgt): + with torch.no_grad(): + # tgt + wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate) + wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) + if model == "FreeVC" or model == "FreeVC (24kHz)": + g_tgt = smodel.embed_utterance(wav_tgt) + g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device) + else: + wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device) + mel_tgt = mel_spectrogram_torch( + wav_tgt, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax + ) + # src + wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate) + wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device) + c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device) + # infer + if model == "FreeVC": + audio = freevc.infer(c, g=g_tgt) + elif model == "FreeVC-s": + audio = freevc_s.infer(c, mel=mel_tgt) + else: + audio = freevc_24.infer(c, g=g_tgt) + audio = audio[0][0].data.cpu().float().numpy() + if model == "FreeVC" or model == "FreeVC-s": + write("out.wav", hps.data.sampling_rate, audio) + else: + write("out.wav", 24000, audio) + out = "out.wav" + return out + +# GLM2 + +language_dict = tts_order_voice + +# fix timezone in Linux +os.environ["TZ"] = "Asia/Shanghai" +try: + time.tzset() # type: ignore # pylint: disable=no-member +except Exception: + # Windows + logger.warning("Windows, cant run time.tzset()") + +# model_name = "THUDM/chatglm2-6b" +model_name = "THUDM/chatglm2-6b-int4" + +RETRY_FLAG = False + +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + +# model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda() + +# 4/8 bit +# model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda() + +has_cuda = torch.cuda.is_available() + +# has_cuda = False # force cpu + +if has_cuda: + model_glm = ( + AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half() + ) # 3.92G +else: + model_glm = AutoModel.from_pretrained( + model_name, trust_remote_code=True + ).float() # .float() .half().float() + +model_glm = model_glm.eval() + +_ = """Override Chatbot.postprocess""" + + +def postprocess(self, y): + if y is None: + return [] + for i, (message, response) in enumerate(y): + y[i] = ( + None if message is None else mdtex2html.convert((message)), + None if response is None else mdtex2html.convert(response), + ) + return y + + +gr.Chatbot.postprocess = postprocess + + +def parse_text(text): + """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" + lines = text.split("\n") + lines = [line for line in lines if line != ""] + count = 0 + for i, line in enumerate(lines): + if "```" in line: + count += 1 + items = line.split("`") + if count % 2 == 1: + lines[i] = f'
'
+ else:
+ lines[i] = "
"
+ else:
+ if i > 0:
+ if count % 2 == 1:
+ line = line.replace("`", r"\`")
+ line = line.replace("<", "<")
+ line = line.replace(">", ">")
+ line = line.replace(" ", " ")
+ line = line.replace("*", "*")
+ line = line.replace("_", "_")
+ line = line.replace("-", "-")
+ line = line.replace(".", ".")
+ line = line.replace("!", "!")
+ line = line.replace("(", "(")
+ line = line.replace(")", ")")
+ line = line.replace("$", "$")
+ lines[i] = "'\n",
+ " else:\n",
+ " lines[i] = \"
\"\n",
+ " else:\n",
+ " if i > 0:\n",
+ " if count % 2 == 1:\n",
+ " line = line.replace(\"`\", r\"\\`\")\n",
+ " line = line.replace(\"<\", \"<\")\n",
+ " line = line.replace(\">\", \">\")\n",
+ " line = line.replace(\" \", \" \")\n",
+ " line = line.replace(\"*\", \"*\")\n",
+ " line = line.replace(\"_\", \"_\")\n",
+ " line = line.replace(\"-\", \"-\")\n",
+ " line = line.replace(\".\", \".\")\n",
+ " line = line.replace(\"!\", \"!\")\n",
+ " line = line.replace(\"(\", \"(\")\n",
+ " line = line.replace(\")\", \")\")\n",
+ " line = line.replace(\"$\", \"$\")\n",
+ " lines[i] = \"'
+ else:
+ lines[i] = "
"
+ else:
+ if i > 0:
+ if count % 2 == 1:
+ line = line.replace("`", r"\`")
+ line = line.replace("<", "<")
+ line = line.replace(">", ">")
+ line = line.replace(" ", " ")
+ line = line.replace("*", "*")
+ line = line.replace("_", "_")
+ line = line.replace("-", "-")
+ line = line.replace(".", ".")
+ line = line.replace("!", "!")
+ line = line.replace("(", "(")
+ line = line.replace(")", ")")
+ line = line.replace("$", "$")
+ lines[i] = "