kevinwang676 committed on
Commit
8c01f11
1 Parent(s): 5d1e1a3

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +25 -0
  2. .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  3. LICENSE +16 -46
  4. README.md +268 -12
  5. SadTalker/.gitignore +174 -0
  6. SadTalker/.ipynb_checkpoints/requirements-checkpoint.txt +37 -0
  7. SadTalker/checkpoints/SadTalker_V0.0.2_256.safetensors +3 -0
  8. SadTalker/checkpoints/SadTalker_V0.0.2_512.safetensors +3 -0
  9. SadTalker/checkpoints/mapping_00109-model.pth.tar +3 -0
  10. SadTalker/checkpoints/mapping_00229-model.pth.tar +3 -0
  11. SadTalker/cog.yaml +35 -0
  12. SadTalker/gfpgan/weights/GFPGANv1.4.pth +3 -0
  13. SadTalker/gfpgan/weights/alignment_WFLW_4HG.pth +3 -0
  14. SadTalker/gfpgan/weights/detection_Resnet50_Final.pth +3 -0
  15. SadTalker/gfpgan/weights/parsing_parsenet.pth +3 -0
  16. SadTalker/inference.py +145 -0
  17. SadTalker/launcher.py +204 -0
  18. SadTalker/predict.py +192 -0
  19. SadTalker/req.txt +22 -0
  20. SadTalker/requirements.txt +37 -0
  21. SadTalker/scripts/download_models.sh +32 -0
  22. SadTalker/scripts/extension.py +189 -0
  23. SadTalker/scripts/test.sh +21 -0
  24. SadTalker/src/audio2exp_models/audio2exp.py +41 -0
  25. SadTalker/src/audio2exp_models/networks.py +74 -0
  26. SadTalker/src/audio2pose_models/audio2pose.py +94 -0
  27. SadTalker/src/audio2pose_models/audio_encoder.py +64 -0
  28. SadTalker/src/audio2pose_models/cvae.py +149 -0
  29. SadTalker/src/audio2pose_models/discriminator.py +76 -0
  30. SadTalker/src/audio2pose_models/networks.py +140 -0
  31. SadTalker/src/audio2pose_models/res_unet.py +65 -0
  32. SadTalker/src/config/auido2exp.yaml +58 -0
  33. SadTalker/src/config/auido2pose.yaml +49 -0
  34. SadTalker/src/config/facerender.yaml +45 -0
  35. SadTalker/src/config/facerender_still.yaml +45 -0
  36. SadTalker/src/config/similarity_Lm3D_all.mat +0 -0
  37. SadTalker/src/face3d/data/__init__.py +116 -0
  38. SadTalker/src/face3d/data/base_dataset.py +125 -0
  39. SadTalker/src/face3d/data/flist_dataset.py +125 -0
  40. SadTalker/src/face3d/data/image_folder.py +66 -0
  41. SadTalker/src/face3d/data/template_dataset.py +75 -0
  42. SadTalker/src/face3d/extract_kp_videos.py +108 -0
  43. SadTalker/src/face3d/extract_kp_videos_safe.py +151 -0
  44. SadTalker/src/face3d/models/__init__.py +67 -0
  45. SadTalker/src/face3d/models/arcface_torch/README.md +164 -0
  46. SadTalker/src/face3d/models/arcface_torch/backbones/__init__.py +25 -0
  47. SadTalker/src/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
  48. SadTalker/src/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
  49. SadTalker/src/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
  50. SadTalker/src/face3d/models/arcface_torch/configs/3millions.py +23 -0
.gitattributes CHANGED
@@ -39,3 +39,28 @@ model/gambino/Hamza.png filter=lfs diff=lfs merge=lfs -text
39
  model/angele/Angele.png filter=lfs diff=lfs merge=lfs -text
40
  model/leto/Leto.png filter=lfs diff=lfs merge=lfs -text
41
  NotoSansSC-Regular.otf filter=lfs diff=lfs merge=lfs -text
42
+ SadTalker/checkpoints/mapping_00109-model.pth.tar filter=lfs diff=lfs merge=lfs -text
43
+ SadTalker/checkpoints/mapping_00229-model.pth.tar filter=lfs diff=lfs merge=lfs -text
44
+ docs/example_crop.gif filter=lfs diff=lfs merge=lfs -text
45
+ docs/example_crop_still.gif filter=lfs diff=lfs merge=lfs -text
46
+ docs/example_full.gif filter=lfs diff=lfs merge=lfs -text
47
+ docs/example_full_enhanced.gif filter=lfs diff=lfs merge=lfs -text
48
+ docs/free_view_result.gif filter=lfs diff=lfs merge=lfs -text
49
+ docs/resize_good.gif filter=lfs diff=lfs merge=lfs -text
50
+ docs/resize_no.gif filter=lfs diff=lfs merge=lfs -text
51
+ docs/using_ref_video.gif filter=lfs diff=lfs merge=lfs -text
52
+ examples/driven_audio/chinese_news.wav filter=lfs diff=lfs merge=lfs -text
53
+ examples/driven_audio/deyu.wav filter=lfs diff=lfs merge=lfs -text
54
+ examples/driven_audio/eluosi.wav filter=lfs diff=lfs merge=lfs -text
55
+ examples/driven_audio/fayu.wav filter=lfs diff=lfs merge=lfs -text
56
+ examples/driven_audio/imagine.wav filter=lfs diff=lfs merge=lfs -text
57
+ examples/driven_audio/japanese.wav filter=lfs diff=lfs merge=lfs -text
58
+ examples/ref_video/WDA_AlexandriaOcasioCortez_000.mp4 filter=lfs diff=lfs merge=lfs -text
59
+ examples/ref_video/WDA_KatieHill_000.mp4 filter=lfs diff=lfs merge=lfs -text
60
+ examples/source_image/art_16.png filter=lfs diff=lfs merge=lfs -text
61
+ examples/source_image/art_17.png filter=lfs diff=lfs merge=lfs -text
62
+ examples/source_image/art_3.png filter=lfs diff=lfs merge=lfs -text
63
+ examples/source_image/art_4.png filter=lfs diff=lfs merge=lfs -text
64
+ examples/source_image/art_5.png filter=lfs diff=lfs merge=lfs -text
65
+ examples/source_image/art_8.png filter=lfs diff=lfs merge=lfs -text
66
+ examples/source_image/art_9.png filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
LICENSE CHANGED
@@ -1,51 +1,21 @@
1
  MIT License
2
 
3
- Copyright (c) 2023 liujing04
4
- Copyright (c) 2023 源文雨
5
- Copyright (c) 2023 on9.moe Webslaves
6
 
7
- 本软件及其相关代码以MIT协议开源,作者不对软件具备任何控制力,使用软件者、传播软件导出的声音者自负全责。
8
- 如不认可该条款,则不能使用或引用软件包内任何代码和文件。
9
 
10
- Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
 
11
 
12
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
13
-
14
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15
-
16
- 特此授予任何获得本软件和相关文档文件(以下简称“软件”)副本的人免费使用、复制、修改、合并、出版、分发、再授权和/或销售本软件的权利,以及授予本软件所提供的人使用本软件的权利,但须符合以下条件:
17
- 上述版权声明和本许可声明应包含在软件的所有副本或实质部分中。
18
- 软件是“按原样”提供的,没有任何明示或暗示的保证,包括但不限于适销性、适用于特定目的和不侵权的保证。在任何情况下,作者或版权持有人均不承担因软件或软件的使用或其他交易而产生、产生或与之相关的任何索赔、损害赔偿或其他责任,无论是在合同诉讼、侵权诉讼还是其他诉讼中。
19
-
20
- 相关引用库协议如下:
21
- #################
22
- ContentVec
23
- https://github.com/auspicious3000/contentvec/blob/main/LICENSE
24
- MIT License
25
- #################
26
- VITS
27
- https://github.com/jaywalnut310/vits/blob/main/LICENSE
28
- MIT License
29
- #################
30
- HIFIGAN
31
- https://github.com/jik876/hifi-gan/blob/master/LICENSE
32
- MIT License
33
- #################
34
- gradio
35
- https://github.com/gradio-app/gradio/blob/main/LICENSE
36
- Apache License 2.0
37
- #################
38
- ffmpeg
39
- https://github.com/FFmpeg/FFmpeg/blob/master/COPYING.LGPLv3
40
- https://github.com/BtbN/FFmpeg-Builds/releases/download/autobuild-2021-02-28-12-32/ffmpeg-n4.3.2-160-gfbb9368226-win64-lgpl-4.3.zip
41
- LPGLv3 License
42
- MIT License
43
- #################
44
- ultimatevocalremovergui
45
- https://github.com/Anjok07/ultimatevocalremovergui/blob/master/LICENSE
46
- https://github.com/yang123qwe/vocal_separation_by_uvr5
47
- MIT License
48
- #################
49
- audio-slicer
50
- https://github.com/openvpi/audio-slicer/blob/main/LICENSE
51
- MIT License
 
1
  MIT License
2
 
3
+ Copyright (c) 2023 Tencent AI Lab
 
 
4
 
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
 
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
 
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,268 @@
1
- ---
2
- title: VoiceChange
3
- emoji: 👀
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 3.28.3
8
- app_file: app_multi.py
9
- pinned: false
10
- license: mit
11
- duplicated_from: BartPoint/VoiceChange
12
- ---
 
1
+ <div align="center">
2
+
3
+ <img src='https://user-images.githubusercontent.com/4397546/229094115-862c747e-7397-4b54-ba4a-bd368bfe2e0f.png' width='500px'/>
4
+
5
+
6
+ <!--<h2> 😭 SadTalker: <span style="font-size:12px">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation </span> </h2> -->
7
+
8
+ <a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-PDF-red'></a> &nbsp; <a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp; [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) &nbsp; [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker) &nbsp; [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) &nbsp; [![Replicate](https://replicate.com/cjwbw/sadtalker/badge)](https://replicate.com/cjwbw/sadtalker)
9
+
10
+ <div>
11
+ <a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup> </a>&emsp;
12
+ <a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</a>&emsp;
13
+ <a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a>&emsp;
14
+ <a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a>&emsp;
15
+ <a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a>&emsp; </br>
16
+ <a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo<sup>1</sup> </a>&emsp;
17
+ <a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan <sup>2</sup> </a>&emsp;
18
+ <a target='_blank'>Fei Wang <sup>1</sup> </a>&emsp;
19
+ </div>
20
+ <br>
21
+ <div>
22
+ <sup>1</sup> Xi'an Jiaotong University &emsp; <sup>2</sup> Tencent AI Lab &emsp; <sup>3</sup> Ant Group &emsp;
23
+ </div>
24
+ <br>
25
+ <i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>
26
+ <br>
27
+ <br>
28
+
29
+
30
+ ![sadtalker](https://user-images.githubusercontent.com/4397546/222490039-b1f6156b-bf00-405b-9fda-0c9a9156f991.gif)
31
+
32
+ <b>TL;DR: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; single portrait image 🙎‍♂️ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; audio 🎤 &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; = &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; talking head video 🎞.</b>
33
+
34
+ <br>
35
+
36
+ </div>
37
+
38
+
39
+
40
+ ## 🔥 Highlight
41
+
42
+ - 🔥 The extension for the [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is online. Check out more details [here](docs/webui_extension.md).
43
+
44
+ https://user-images.githubusercontent.com/4397546/231495639-5d4bb925-ea64-4a36-a519-6389917dac29.mp4
45
+
46
+ - 🔥 `full image mode` is online! Check out [here](https://github.com/Winfredy/SadTalker#full-bodyimage-generation) for more details.
47
+
48
+ | still+enhancer in v0.0.1 | still + enhancer in v0.0.2 | [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) |
49
+ |:--------------------: |:--------------------: | :----: |
50
+ | <video src="https://user-images.githubusercontent.com/48216707/229484996-5d7be64f-2553-4c9e-a452-c5cf0b8ebafe.mp4" type="video/mp4"> </video> | <video src="https://user-images.githubusercontent.com/4397546/230717873-355b7bf3-d3de-49f9-a439-9220e623fce7.mp4" type="video/mp4"> </video> | <img src='./examples/source_image/full_body_2.png' width='380'>
51
+
52
+ - 🔥 Several new modes, e.g. `still mode`, `reference mode`, and `resize mode`, are online for better and more customized applications.
53
+
54
+ - 🔥 Happy to see more community demos on [bilibili](https://search.bilibili.com/all?keyword=sadtalker&from_source=webtop_search&spm_id_from=333.1007&search_source=3), [YouTube](https://www.youtube.com/results?search_query=sadtalker&sp=CAM%253D) and [Twitter #sadtalker](https://twitter.com/search?q=%23sadtalker&src=typed_query).
56
+
57
+ ## 📋 Changelog (the previous changelog can be found [here](docs/changlelog.md))
58
+
59
+ - __[2023.06.12]__: Added more new features to the WebUI extension; see the discussion [here](https://github.com/OpenTalker/SadTalker/discussions/386).
60
+
61
+ - __[2023.06.05]__: Released a new 512 beta face model, fixed some bugs, and improved the performance.
62
+
63
+ - __[2023.04.15]__: Added an automatic1111 colab by @camenduru; thanks for this awesome colab: [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb).
64
+
65
+ - __[2023.04.12]__: Added a more detailed sd-webui installation document and fixed the reinstallation problem.
66
+
67
+ - __[2023.04.12]__: Fixed the sd-webui safety issues caused by third-party packages and optimized the output path in `sd-webui-extension`.
68
+
69
+ - __[2023.04.08]__: ❗️❗️❗️ In v0.0.2, we add a logo watermark to the generated video to prevent abuse, since the results are very realistic.
70
+
71
+ - __[2023.04.08]__: v0.0.2: full image animation, added a Baidu Netdisk link for downloading checkpoints, and optimized the enhancer logic.
72
+
73
+
74
+ ## 🚧 TODO: See the Discussion https://github.com/OpenTalker/SadTalker/issues/280
75
+
76
+ ## If you have any problems, please check our [FAQ](docs/FAQ.md) before opening an issue.
77
+
78
+
79
+
80
+ ## ⚙️ 1. Installation.
81
+
82
+ Tutorials from the community: [Chinese Windows tutorial (中文教程)](https://www.bilibili.com/video/BV1Dc411W7V6/) | [Japanese tutorial (日本語コース)](https://br-d.fanbox.cc/posts/5685086?utm_campaign=manage_post_page&utm_medium=share&utm_source=twitter)
83
+
84
+ ### Linux:
85
+
86
+ 1. Install [anaconda](https://www.anaconda.com/), Python, and git.
87
+
88
+ 2. Create the env and install the requirements:
89
+ ```bash
90
+ git clone https://github.com/Winfredy/SadTalker.git
91
+
92
+ cd SadTalker
93
+
94
+ conda create -n sadtalker python=3.8
95
+
96
+ conda activate sadtalker
97
+
98
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
99
+
100
+ conda install ffmpeg
101
+
102
+ pip install -r requirements.txt
103
+
104
+ ### tts is optional for gradio demo.
105
+ ### pip install TTS
106
+
107
+ ```
108
+ ### Windows ([Chinese video tutorial / 中文教程](https://www.bilibili.com/video/BV1Dc411W7V6/)):
109
+
110
+ 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
111
+ 2. Install [git](https://git-scm.com/download/win) manually (OR `scoop install git` via [scoop](https://scoop.sh/)).
112
+ 3. Install `ffmpeg`, following [this instruction](https://www.wikihow.com/Install-FFmpeg-on-Windows) (OR using `scoop install ffmpeg` via [scoop](https://scoop.sh/)).
113
+ 4. Download our SadTalker repository, for example by running `git clone https://github.com/Winfredy/SadTalker.git`.
114
+ 5. Download the `checkpoint` and `gfpgan` [below↓](https://github.com/Winfredy/SadTalker#-2-download-trained-models).
115
+ 6. Run `start.bat` from Windows Explorer as a normal, non-administrator user; a Gradio WebUI demo will start.
116
+
117
+ ### Macbook:
118
+
119
+ More tips about installation on a MacBook and the Docker file can be found [here](docs/install.md).
120
+
121
+ ## 📥 2. Download Trained Models.
122
+
123
+ You can run the following script to put all the models in the right place.
124
+
125
+ ```bash
126
+ bash scripts/download_models.sh
127
+ ```
128
+
129
+ Other alternatives:
130
+ > We also provide an offline patch (`gfpgan/`), so no model will be downloaded when generating.
131
+
132
+ **Google Drive**: download our pre-trained model from [this link (main checkpoints)](https://drive.google.com/file/d/1gwWh45pF7aelNP_P78uDJL8Sycep-K7j/view?usp=sharing) and [gfpgan (offline patch)](https://drive.google.com/file/d/19AIBsmfcHW6BRJmeqSFlG5fL445Xmsyi?usp=sharing)
133
+
134
+ **GitHub Release Page**: download all the files from the [latest GitHub release page](https://github.com/Winfredy/SadTalker/releases), and then put them in `./checkpoints`.
135
+
136
+ **Baidu Netdisk (百度云盘)**: we also provide the models at [checkpoints, 提取码 (access code): sadt](https://pan.baidu.com/s/1P4fRgk9gaSutZnn8YW034Q?pwd=sadt) and [gfpgan, 提取码 (access code): sadt](https://pan.baidu.com/s/1kb1BCPaLOWX1JJb9Czbn6w?pwd=sadt)
137
+
138
+
139
+
140
+ <details><summary>Model Details</summary>
141
+
142
+
143
+ Model explanations:
144
+
145
+ ##### New version
146
+ | Model | Description
147
+ | :--- | :----------
148
+ |checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.
149
+ |checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker.
150
+ |checkpoints/SadTalker_V0.0.2_256.safetensors | Packaged SadTalker checkpoints of the old version (256 face render).
151
+ |checkpoints/SadTalker_V0.0.2_512.safetensors | Packaged SadTalker checkpoints of the old version (512 face render).
152
+ |gfpgan/weights | Face detection and enhanced models used in `facexlib` and `gfpgan`.
153
+
154
+
155
+ ##### Old version
156
+ | Model | Description
157
+ | :--- | :----------
158
+ |checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in Sadtalker.
159
+ |checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in Sadtalker.
160
+ |checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.
161
+ |checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker.
162
+ |checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reappearance of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
163
+ |checkpoints/epoch_20.pth | Pre-trained 3DMM extractor in [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).
164
+ |checkpoints/wav2lip.pth | Highly accurate lip-sync model in [Wav2lip](https://github.com/Rudrabha/Wav2Lip).
165
+ |checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/).
166
+ |checkpoints/BFM | 3DMM library file.
167
+ |checkpoints/hub | Face detection models used in [face alignment](https://github.com/1adrianb/face-alignment).
168
+ |gfpgan/weights | Face detection and enhanced models used in `facexlib` and `gfpgan`.
169
+
170
+ The final folder will be shown as:
171
+
172
+ <img width="331" alt="image" src="https://user-images.githubusercontent.com/4397546/232511411-4ca75cbf-a434-48c5-9ae0-9009e8316484.png">
173
+
174
+
175
+ </details>
176
+
177
+ ## 🔮 3. Quick Start ([Best Practice](docs/best_practice.md)).
178
+
179
+ ### WebUI Demos:
180
+
181
+ **Online**: [Huggingface](https://huggingface.co/spaces/vinthony/SadTalker) | [SDWebUI-Colab](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) | [Colab](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)
182
+
183
+ **Local Automatic1111 stable-diffusion webui extension**: please refer to the [Automatic1111 stable-diffusion webui docs](docs/webui_extension.md).
184
+
185
+ **Local gradio demo (highly recommended!)**: a demo similar to our [Hugging Face demo](https://huggingface.co/spaces/vinthony/SadTalker) can be run by:
186
+
187
+ ```bash
188
+ ## you need to manually install TTS (https://github.com/coqui-ai/TTS) via `pip install TTS` in advance.
189
+ python app_sadtalker.py
190
+ ```
191
+
192
+ **Local gradio demo (highly recommended!)**:
193
+
194
+ - Windows: just double-click `webui.bat`; the requirements will be installed automatically.
195
+ - Linux/Mac OS: run `bash webui.sh` to start the webui.
196
+
197
+
198
+ ### Manual usage:
199
+
200
+ ##### Animating a portrait image from default config:
201
+ ```bash
202
+ python inference.py --driven_audio <audio.wav> \
203
+ --source_image <video.mp4 or picture.png> \
204
+ --enhancer gfpgan
205
+ ```
206
+ The results will be saved in `results/$SOME_TIMESTAMP/*.mp4`.
207
+
208
+ ##### Full body/image Generation:
209
+
210
+ Use `--still` to generate a natural full-body video. You can add `--enhancer` to improve the quality of the generated video.
211
+
212
+ ```bash
213
+ python inference.py --driven_audio <audio.wav> \
214
+ --source_image <video.mp4 or picture.png> \
215
+ --result_dir <a folder to store results> \
216
+ --still \
217
+ --preprocess full \
218
+ --enhancer gfpgan
219
+ ```
220
+
221
+ More examples, configurations, and tips can be found in the [ >>> best practice documents <<<](docs/best_practice.md).
222
+
223
+ ## 🛎 Citation
224
+
225
+ If you find our work useful in your research, please consider citing:
226
+
227
+ ```bibtex
228
+ @article{zhang2022sadtalker,
229
+ title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},
230
+ author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},
231
+ journal={arXiv preprint arXiv:2211.12194},
232
+ year={2022}
233
+ }
234
+ ```
235
+
236
+
237
+
238
+ ## 💗 Acknowledgements
239
+
240
+ Facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. In the training process, we also use the models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip), and we thank the authors for their wonderful work.
241
+
242
+ See also these wonderful third-party libraries we use:
243
+
244
+ - **Face Utils**: https://github.com/xinntao/facexlib
245
+ - **Face Enhancement**: https://github.com/TencentARC/GFPGAN
246
+ - **Image/Video Enhancement**: https://github.com/xinntao/Real-ESRGAN
247
+
248
+ ## 🥂 Extensions:
249
+
250
+ - [SadTalker-Video-Lip-Sync](https://github.com/Zz-ww/SadTalker-Video-Lip-Sync) from [@Zz-ww](https://github.com/Zz-ww): SadTalker for Video Lip Editing
251
+
252
+ ## 🥂 Related Works
253
+ - [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)
254
+ - [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)
255
+ - [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)
256
+ - [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)
257
+ - [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)
258
+ - [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)
259
+
260
+ ## 📢 Disclaimer
261
+
262
+ This is not an official product of Tencent. This repository can only be used for personal/research/non-commercial purposes.
263
+
264
+ LOGO: color and font suggestion: [ChatGPT](ai.com); logo font: [Montserrat Alternates](https://fonts.google.com/specimen/Montserrat+Alternates?preview.text=SadTalker&preview.text_type=custom&query=mont).
266
+
267
+ All the copyright of the demo images and audio belongs to community users or to generations from Stable Diffusion. Feel free to contact us if you feel uncomfortable.
268
+
SadTalker/.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
161
+
162
+ examples/results/*
163
+ gfpgan/*
164
+ checkpoints/*
165
+ assets/*
166
+ results/*
167
+ Dockerfile
168
+ start_docker.sh
169
+ start.sh
170
+
171
+ checkpoints
172
+
173
+ # Mac
174
+ .DS_Store
SadTalker/.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,37 @@
1
+ numpy==1.23.5
2
+ matplotlib
3
+ moviepy
4
+ yt-dlp
5
+ pydub
6
+ demucs
7
+ gradio
8
+ torch
9
+ flask
10
+ flask-cors
11
+ torchaudio
12
+ fairseq==0.12.2
13
+ scipy==1.10.1
14
+ pyworld>=0.3.2
15
+ faiss-cpu==1.7.3
16
+ praat-parselmouth>=0.4.2
17
+ librosa==0.9.2
18
+ edge-tts
19
+ torchcrepe
20
+ Pillow==9.5.0
21
+
22
+ face_alignment==1.3.5
23
+ imageio==2.19.3
24
+ imageio-ffmpeg==0.4.7
25
+ numba
26
+ resampy==0.3.1
27
+ kornia==0.6.8
28
+ tqdm
29
+ yacs==0.1.8
30
+ pyyaml
31
+ joblib==1.1.0
32
+ scikit-image==0.19.3
33
+ basicsr==1.4.2
34
+ facexlib==0.3.0
35
+ gfpgan
36
+ av
37
+ safetensors
SadTalker/checkpoints/SadTalker_V0.0.2_256.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c211f5d6de003516bf1bbda9f47049a4c9c99133b1ab565c6961e5af16477bff
3
+ size 725066984
SadTalker/checkpoints/SadTalker_V0.0.2_512.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e063f7ff5258240bdb0f7690783a7b1374e6a4a81ce8fa33456f4cd49694340
3
+ size 725066984
SadTalker/checkpoints/mapping_00109-model.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84a8642468a3fcfdd9ab6be955267043116c2bec2284686a5262f1eaf017f64c
3
+ size 155779231
SadTalker/checkpoints/mapping_00229-model.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
3
+ size 155521183
SadTalker/cog.yaml ADDED
@@ -0,0 +1,35 @@
1
+ build:
2
+ gpu: true
3
+ cuda: "11.3"
4
+ python_version: "3.8"
5
+ system_packages:
6
+ - "ffmpeg"
7
+ - "libgl1-mesa-glx"
8
+ - "libglib2.0-0"
9
+ python_packages:
10
+ - "torch==1.12.1"
11
+ - "torchvision==0.13.1"
12
+ - "torchaudio==0.12.1"
13
+ - "joblib==1.1.0"
14
+ - "scikit-image==0.19.3"
15
+ - "basicsr==1.4.2"
16
+ - "facexlib==0.3.0"
17
+ - "resampy==0.3.1"
18
+ - "pydub==0.25.1"
19
+ - "scipy==1.10.1"
20
+ - "kornia==0.6.8"
21
+ - "face_alignment==1.3.5"
22
+ - "imageio==2.19.3"
23
+ - "imageio-ffmpeg==0.4.7"
24
+ - "librosa==0.9.2" #
25
+ - "tqdm==4.65.0"
26
+ - "yacs==0.1.8"
27
+ - "gfpgan==1.3.8"
28
+ - "dlib-bin==19.24.1"
29
+ - "av==10.0.0"
30
+ - "trimesh==3.9.20"
31
+ run:
32
+ - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"
33
+ - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip"
34
+
35
+ predict: "predict.py:Predictor"
SadTalker/gfpgan/weights/GFPGANv1.4.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2cd4703ab14f4d01fd1383a8a8b266f9a5833dacee8e6a79d3bf21a1b6be5ad
3
+ size 348632874
SadTalker/gfpgan/weights/alignment_WFLW_4HG.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbfd137307a4c7debd5c283b9b0ce539466cee417ac0a155e184d857f9f2899c
3
+ size 193670248
SadTalker/gfpgan/weights/detection_Resnet50_Final.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
3
+ size 109497761
SadTalker/gfpgan/weights/parsing_parsenet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
3
+ size 85331193
SadTalker/inference.py ADDED
@@ -0,0 +1,145 @@
1
+ from glob import glob
2
+ import shutil
3
+ import torch
4
+ from time import strftime
5
+ import os, sys, time
6
+ from argparse import ArgumentParser
7
+
8
+ from src.utils.preprocess import CropAndExtract
9
+ from src.test_audio2coeff import Audio2Coeff
10
+ from src.facerender.animate import AnimateFromCoeff
11
+ from src.generate_batch import get_data
12
+ from src.generate_facerender_batch import get_facerender_data
13
+ from src.utils.init_path import init_path
14
+
15
+ def main(args):
16
+ #torch.backends.cudnn.enabled = False
17
+
18
+ pic_path = args.source_image
19
+ audio_path = args.driven_audio
20
+ save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
21
+ os.makedirs(save_dir, exist_ok=True)
22
+ pose_style = args.pose_style
23
+ device = args.device
24
+ batch_size = args.batch_size
25
+ input_yaw_list = args.input_yaw
26
+ input_pitch_list = args.input_pitch
27
+ input_roll_list = args.input_roll
28
+ ref_eyeblink = args.ref_eyeblink
29
+ ref_pose = args.ref_pose
30
+
31
+ current_root_path = os.path.split(sys.argv[0])[0]
32
+
33
+ sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
34
+
35
+ #init model
36
+ preprocess_model = CropAndExtract(sadtalker_paths, device)
37
+
38
+ audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
39
+
40
+ animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
41
+
42
+ #crop image and extract 3dmm from image
43
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
44
+ os.makedirs(first_frame_dir, exist_ok=True)
45
+ print('3DMM Extraction for source image')
46
+ first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
47
+ source_image_flag=True, pic_size=args.size)
48
+ if first_coeff_path is None:
49
+ print("Can't get the coeffs of the input")
50
+ return
51
+
52
+ if ref_eyeblink is not None:
53
+ ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
54
+ ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
55
+ os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
56
+ print('3DMM Extraction for the reference video providing eye blinking')
57
+ ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
58
+ else:
59
+ ref_eyeblink_coeff_path=None
60
+
61
+ if ref_pose is not None:
62
+ if ref_pose == ref_eyeblink:
63
+ ref_pose_coeff_path = ref_eyeblink_coeff_path
64
+ else:
65
+ ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
66
+ ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
67
+ os.makedirs(ref_pose_frame_dir, exist_ok=True)
68
+ print('3DMM Extraction for the reference video providing pose')
69
+ ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
70
+ else:
71
+ ref_pose_coeff_path=None
72
+
73
+ #audio2ceoff
74
+ batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
75
+ coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
76
+
77
+ # 3dface render
78
+ if args.face3dvis:
79
+ from src.face3d.visualize import gen_composed_video
80
+ gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
81
+
82
+ #coeff2video
83
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
84
+ batch_size, input_yaw_list, input_pitch_list, input_roll_list,
85
+ expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
86
+
87
+ result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
88
+ enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
89
+
90
+ shutil.move(result, save_dir+'.mp4')
91
+ print('The generated video is named:', save_dir+'.mp4')
92
+
93
+ if not args.verbose:
94
+ shutil.rmtree(save_dir)
95
+
96
+
97
+ if __name__ == '__main__':
98
+
99
+ parser = ArgumentParser()
100
+ parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
101
+ parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
102
+ parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
103
+ parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
104
+ parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to the checkpoints")
105
+ parser.add_argument("--result_dir", default='./results', help="path to output")
106
+ parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
107
+ parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
108
+ parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
109
+ parser.add_argument("--expression_scale", type=float, default=1., help="the expression scale of the facerender")
110
+ parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
111
+ parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
112
+ parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
113
+ parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
114
+ parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
115
+ parser.add_argument("--cpu", dest="cpu", action="store_true")
116
+ parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
117
+ parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body animation")
118
+ parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
119
+ parser.add_argument("--verbose", action="store_true", help="save the intermediate outputs or not")
120
+ parser.add_argument("--old_version", action="store_true", help="use the pth checkpoints rather than the safetensors version")
121
+
122
+
123
+ # net structure and parameters
124
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
125
+ parser.add_argument('--init_path', type=str, default=None, help='Useless')
126
+ parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
127
+ parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
128
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
129
+
130
+ # default renderer parameters
131
+ parser.add_argument('--focal', type=float, default=1015.)
132
+ parser.add_argument('--center', type=float, default=112.)
133
+ parser.add_argument('--camera_d', type=float, default=10.)
134
+ parser.add_argument('--z_near', type=float, default=5.)
135
+ parser.add_argument('--z_far', type=float, default=15.)
136
+
137
+ args = parser.parse_args()
138
+
139
+ if torch.cuda.is_available() and not args.cpu:
140
+ args.device = "cuda"
141
+ else:
142
+ args.device = "cpu"
143
+
144
+ main(args)
145
+
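A minimal sketch of how this CLI is typically driven, mirroring the arguments defined in the `ArgumentParser` above; the example paths come from the bundled `examples/` folder and are only illustrative:

```python
# Hedged sketch: invoke SadTalker/inference.py through its CLI from Python.
# Assumes the working directory is the SadTalker checkout and that the
# checkpoints have been fetched with scripts/download_models.sh.
import subprocess
import sys

cmd = [
    sys.executable, "inference.py",
    "--driven_audio", "./examples/driven_audio/bus_chinese.wav",
    "--source_image", "./examples/source_image/full_body_1.png",
    "--result_dir", "./results",
    "--preprocess", "full",
    "--still",               # keep the original full-body framing
    "--enhancer", "gfpgan",  # optional face enhancer
]
# inference.py writes results/<timestamp>.mp4 and, unless --verbose is given,
# removes the intermediate folder afterwards.
subprocess.run(cmd, check=True)
```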
SadTalker/launcher.py ADDED
@@ -0,0 +1,204 @@
1
+ # this scripts installs necessary requirements and launches main program in webui.py
2
+ # borrow from : https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py
3
+ import subprocess
4
+ import os
5
+ import sys
6
+ import importlib.util
7
+ import shlex
8
+ import platform
9
+ import json
10
+
11
+ python = sys.executable
12
+ git = os.environ.get('GIT', "git")
13
+ index_url = os.environ.get('INDEX_URL', "")
14
+ stored_commit_hash = None
15
+ skip_install = False
16
+ dir_repos = "repositories"
17
+ script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
18
+
19
+ if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
20
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
21
+
22
+
23
+ def check_python_version():
24
+ is_windows = platform.system() == "Windows"
25
+ major = sys.version_info.major
26
+ minor = sys.version_info.minor
27
+ micro = sys.version_info.micro
28
+
29
+ if is_windows:
30
+ supported_minors = [10]
31
+ else:
32
+ supported_minors = [7, 8, 9, 10, 11]
33
+
34
+ if not (major == 3 and minor in supported_minors):
35
+
36
+ raise (f"""
37
+ INCOMPATIBLE PYTHON VERSION
38
+ This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
39
+ If you encounter an error with "RuntimeError: Couldn't install torch." message,
40
+ or any other error regarding unsuccessful package (library) installation,
41
+ please downgrade (or upgrade) to the latest version of 3.10 Python
42
+ and delete current Python and "venv" folder in WebUI's directory.
43
+ You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
44
+ {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
45
+ Use --skip-python-version-check to suppress this warning.
46
+ """)
47
+
48
+
49
+ def commit_hash():
50
+ global stored_commit_hash
51
+
52
+ if stored_commit_hash is not None:
53
+ return stored_commit_hash
54
+
55
+ try:
56
+ stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
57
+ except Exception:
58
+ stored_commit_hash = "<none>"
59
+
60
+ return stored_commit_hash
61
+
62
+
63
+ def run(command, desc=None, errdesc=None, custom_env=None, live=False):
64
+ if desc is not None:
65
+ print(desc)
66
+
67
+ if live:
68
+ result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
69
+ if result.returncode != 0:
70
+ raise RuntimeError(f"""{errdesc or 'Error running command'}.
71
+ Command: {command}
72
+ Error code: {result.returncode}""")
73
+
74
+ return ""
75
+
76
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
77
+
78
+ if result.returncode != 0:
79
+
80
+ message = f"""{errdesc or 'Error running command'}.
81
+ Command: {command}
82
+ Error code: {result.returncode}
83
+ stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
84
+ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
85
+ """
86
+ raise RuntimeError(message)
87
+
88
+ return result.stdout.decode(encoding="utf8", errors="ignore")
89
+
90
+
91
+ def check_run(command):
92
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
93
+ return result.returncode == 0
94
+
95
+
96
+ def is_installed(package):
97
+ try:
98
+ spec = importlib.util.find_spec(package)
99
+ except ModuleNotFoundError:
100
+ return False
101
+
102
+ return spec is not None
103
+
104
+
105
+ def repo_dir(name):
106
+ return os.path.join(script_path, dir_repos, name)
107
+
108
+
109
+ def run_python(code, desc=None, errdesc=None):
110
+ return run(f'"{python}" -c "{code}"', desc, errdesc)
111
+
112
+
113
+ def run_pip(args, desc=None):
114
+ if skip_install:
115
+ return
116
+
117
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
118
+ return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
119
+
120
+
121
+ def check_run_python(code):
122
+ return check_run(f'"{python}" -c "{code}"')
123
+
124
+
125
+ def git_clone(url, dir, name, commithash=None):
126
+ # TODO clone into temporary dir and move if successful
127
+
128
+ if os.path.exists(dir):
129
+ if commithash is None:
130
+ return
131
+
132
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
133
+ if current_hash == commithash:
134
+ return
135
+
136
+ run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
137
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
138
+ return
139
+
140
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
141
+
142
+ if commithash is not None:
143
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
144
+
145
+
146
+ def git_pull_recursive(dir):
147
+ for subdir, _, _ in os.walk(dir):
148
+ if os.path.exists(os.path.join(subdir, '.git')):
149
+ try:
150
+ output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
151
+ print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
152
+ except subprocess.CalledProcessError as e:
153
+ print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
154
+
155
+
156
+ def run_extension_installer(extension_dir):
157
+ path_installer = os.path.join(extension_dir, "install.py")
158
+ if not os.path.isfile(path_installer):
159
+ return
160
+
161
+ try:
162
+ env = os.environ.copy()
163
+ env['PYTHONPATH'] = os.path.abspath(".")
164
+
165
+ print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
166
+ except Exception as e:
167
+ print(e, file=sys.stderr)
168
+
169
+
170
+ def prepare_environment():
171
+ global skip_install
172
+
173
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")
174
+
175
+ ## check windows
176
+ if sys.platform != 'win32':
177
+ requirements_file = os.environ.get('REQS_FILE', "req.txt")
178
+ else:
179
+ requirements_file = os.environ.get('REQS_FILE', "requirements.txt")
180
+
181
+ commit = commit_hash()
182
+
183
+ print(f"Python {sys.version}")
184
+ print(f"Commit hash: {commit}")
185
+
186
+ if not is_installed("torch") or not is_installed("torchvision"):
187
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
188
+
189
+ run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)")
190
+
191
+ if sys.platform != 'win32' and not is_installed('tts'):
192
+ run_pip(f"install TTS", "install TTS individually in SadTalker, which might not work on windows.")
193
+
194
+
195
+ def start():
196
+ print(f"Launching SadTalker Web UI")
197
+ from app_sadtalker import sadtalker_demo
198
+ demo = sadtalker_demo()
199
+ demo.queue()
200
+ demo.launch()
201
+
202
+ if __name__ == "__main__":
203
+ prepare_environment()
204
+ start()
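As a side note, `prepare_environment()` reads its behaviour from environment variables, so the torch command and the requirements file can be overridden before launching. A minimal sketch, assuming `launcher.py` is imported from the SadTalker root:

```python
# Hedged sketch: override launcher.py's defaults via environment variables
# (TORCH_COMMAND and REQS_FILE are the names read in prepare_environment()).
import os

os.environ["TORCH_COMMAND"] = (
    "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 "
    "--extra-index-url https://download.pytorch.org/whl/cu113"
)
os.environ["REQS_FILE"] = "requirements.txt"  # force the Windows-style requirements file

import launcher  # assumes the current working directory is the SadTalker checkout

launcher.prepare_environment()  # installs torch/requirements if missing
launcher.start()                # launches the Gradio demo from app_sadtalker.py
```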
SadTalker/predict.py ADDED
@@ -0,0 +1,192 @@
1
+ """run bash scripts/download_models.sh first to prepare the weights file"""
2
+ import os
3
+ import shutil
4
+ from argparse import Namespace
5
+ from src.utils.preprocess import CropAndExtract
6
+ from src.test_audio2coeff import Audio2Coeff
7
+ from src.facerender.animate import AnimateFromCoeff
8
+ from src.generate_batch import get_data
9
+ from src.generate_facerender_batch import get_facerender_data
10
+ from src.utils.init_path import init_path
11
+ from cog import BasePredictor, Input, Path
12
+
13
+ checkpoints = "checkpoints"
14
+
15
+
16
+ class Predictor(BasePredictor):
17
+ def setup(self):
18
+ """Load the model into memory to make running multiple predictions efficient"""
19
+ device = "cuda"
20
+
21
+
22
+ sadtalker_paths = init_path(checkpoints,os.path.join("src","config"))
23
+
24
+ # init model
25
+ self.preprocess_model = CropAndExtract(sadtalker_paths, device
26
+ )
27
+
28
+ self.audio_to_coeff = Audio2Coeff(
29
+ sadtalker_paths,
30
+ device,
31
+ )
32
+
33
+ self.animate_from_coeff = {
34
+ "full": AnimateFromCoeff(
35
+ sadtalker_paths,
36
+ device,
37
+ ),
38
+ "others": AnimateFromCoeff(
39
+ sadtalker_paths,
40
+ device,
41
+ ),
42
+ }
43
+
44
+ def predict(
45
+ self,
46
+ source_image: Path = Input(
47
+ description="Upload the source image, it can be video.mp4 or picture.png",
48
+ ),
49
+ driven_audio: Path = Input(
50
+ description="Upload the driven audio, accepts .wav and .mp4 file",
51
+ ),
52
+ enhancer: str = Input(
53
+ description="Choose a face enhancer",
54
+ choices=["gfpgan", "RestoreFormer"],
55
+ default="gfpgan",
56
+ ),
57
+ preprocess: str = Input(
58
+ description="how to preprocess the images",
59
+ choices=["crop", "resize", "full"],
60
+ default="full",
61
+ ),
62
+ ref_eyeblink: Path = Input(
63
+ description="path to reference video providing eye blinking",
64
+ default=None,
65
+ ),
66
+ ref_pose: Path = Input(
67
+ description="path to reference video providing pose",
68
+ default=None,
69
+ ),
70
+ still: bool = Input(
71
+ description="can crop back to the original videos for the full body animation when preprocess is full",
72
+ default=True,
73
+ ),
74
+ ) -> Path:
75
+ """Run a single prediction on the model"""
76
+
77
+ animate_from_coeff = (
78
+ self.animate_from_coeff["full"]
79
+ if preprocess == "full"
80
+ else self.animate_from_coeff["others"]
81
+ )
82
+
83
+ args = load_default()
84
+ args.pic_path = str(source_image)
85
+ args.audio_path = str(driven_audio)
86
+ device = "cuda"
87
+ args.still = still
88
+ args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
89
+ args.ref_pose = None if ref_pose is None else str(ref_pose)
90
+
91
+ # crop image and extract 3dmm from image
92
+ results_dir = "results"
93
+ if os.path.exists(results_dir):
94
+ shutil.rmtree(results_dir)
95
+ os.makedirs(results_dir)
96
+ first_frame_dir = os.path.join(results_dir, "first_frame_dir")
97
+ os.makedirs(first_frame_dir)
98
+
99
+ print("3DMM Extraction for source image")
100
+ first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
101
+ args.pic_path, first_frame_dir, preprocess, source_image_flag=True
102
+ )
103
+ if first_coeff_path is None:
104
+ print("Can't get the coeffs of the input")
105
+ return
106
+
107
+ if ref_eyeblink is not None:
108
+ ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[
109
+ 0
110
+ ]
111
+ ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
112
+ os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
113
+ print("3DMM Extraction for the reference video providing eye blinking")
114
+ ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
115
+ ref_eyeblink, ref_eyeblink_frame_dir
116
+ )
117
+ else:
118
+ ref_eyeblink_coeff_path = None
119
+
120
+ if ref_pose is not None:
121
+ if ref_pose == ref_eyeblink:
122
+ ref_pose_coeff_path = ref_eyeblink_coeff_path
123
+ else:
124
+ ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
125
+ ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
126
+ os.makedirs(ref_pose_frame_dir, exist_ok=True)
127
+ print("3DMM Extraction for the reference video providing pose")
128
+ ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
129
+ ref_pose, ref_pose_frame_dir
130
+ )
131
+ else:
132
+ ref_pose_coeff_path = None
133
+
134
+ # audio2coeff
135
+ batch = get_data(
136
+ first_coeff_path,
137
+ args.audio_path,
138
+ device,
139
+ ref_eyeblink_coeff_path,
140
+ still=still,
141
+ )
142
+ coeff_path = self.audio_to_coeff.generate(
143
+ batch, results_dir, args.pose_style, ref_pose_coeff_path
144
+ )
145
+ # coeff2video
146
+ print("coeff2video")
147
+ data = get_facerender_data(
148
+ coeff_path,
149
+ crop_pic_path,
150
+ first_coeff_path,
151
+ args.audio_path,
152
+ args.batch_size,
153
+ args.input_yaw,
154
+ args.input_pitch,
155
+ args.input_roll,
156
+ expression_scale=args.expression_scale,
157
+ still_mode=still,
158
+ preprocess=preprocess,
159
+ )
160
+ animate_from_coeff.generate(
161
+ data, results_dir, args.pic_path, crop_info,
162
+ enhancer=enhancer, background_enhancer=args.background_enhancer,
163
+ preprocess=preprocess)
164
+
165
+ output = "/tmp/out.mp4"
166
+ mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0])
167
+ shutil.copy(mp4_path, output)
168
+
169
+ return Path(output)
170
+
171
+
172
+ def load_default():
173
+ return Namespace(
174
+ pose_style=0,
175
+ batch_size=2,
176
+ expression_scale=1.0,
177
+ input_yaw=None,
178
+ input_pitch=None,
179
+ input_roll=None,
180
+ background_enhancer=None,
181
+ face3dvis=False,
182
+ net_recon="resnet50",
183
+ init_path=None,
184
+ use_last_fc=False,
185
+ bfm_folder="./src/config/",
186
+ bfm_model="BFM_model_front.mat",
187
+ focal=1015.0,
188
+ center=112.0,
189
+ camera_d=10.0,
190
+ z_near=5.0,
191
+ z_far=15.0,
192
+ )
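For reference, a minimal sketch of exercising this `Predictor` outside of the Cog runtime; it assumes a CUDA device and that `checkpoints/` has already been populated by `scripts/download_models.sh` (the inputs are the repository's own samples):

```python
# Hedged sketch: run the Cog Predictor class directly from Python.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads CropAndExtract, Audio2Coeff and the AnimateFromCoeff renderers

output = predictor.predict(
    source_image="examples/source_image/full_body_1.png",
    driven_audio="examples/driven_audio/bus_chinese.wav",
    enhancer="gfpgan",
    preprocess="full",
    ref_eyeblink=None,
    ref_pose=None,
    still=True,
)
print(output)  # the enhanced result copied to /tmp/out.mp4
```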
SadTalker/req.txt ADDED
@@ -0,0 +1,22 @@
1
+ llvmlite==0.38.1
2
+ numpy==1.21.6
3
+ face_alignment==1.3.5
4
+ imageio==2.19.3
5
+ imageio-ffmpeg==0.4.7
6
+ librosa==0.10.0.post2
7
+ numba==0.55.1
8
+ resampy==0.3.1
9
+ pydub==0.25.1
10
+ scipy==1.10.1
11
+ kornia==0.6.8
12
+ tqdm
13
+ yacs==0.1.8
14
+ pyyaml
15
+ joblib==1.1.0
16
+ scikit-image==0.19.3
17
+ basicsr==1.4.2
18
+ facexlib==0.3.0
19
+ gradio
20
+ gfpgan
21
+ av
22
+ safetensors
SadTalker/requirements.txt ADDED
@@ -0,0 +1,37 @@
1
+ numpy==1.23.5
2
+ matplotlib
3
+ moviepy
4
+ yt-dlp
5
+ pydub
6
+ demucs
7
+ gradio
8
+ torch
9
+ flask
10
+ flask-cors
11
+ torchaudio
12
+ fairseq==0.12.2
13
+ scipy==1.10.1
14
+ pyworld>=0.3.2
15
+ faiss-cpu==1.7.3
16
+ praat-parselmouth>=0.4.2
17
+ librosa==0.9.2
18
+ edge-tts
19
+ torchcrepe
20
+ Pillow==9.5.0
21
+
22
+ face_alignment==1.3.5
23
+ imageio==2.19.3
24
+ imageio-ffmpeg==0.4.7
25
+ numba
26
+ resampy==0.3.1
27
+ kornia==0.6.8
28
+ tqdm
29
+ yacs==0.1.8
30
+ pyyaml
31
+ joblib==1.1.0
32
+ scikit-image==0.19.3
33
+ basicsr==1.4.2
34
+ facexlib==0.3.0
35
+ gfpgan
36
+ av
37
+ safetensors
SadTalker/scripts/download_models.sh ADDED
@@ -0,0 +1,32 @@
1
+ mkdir ./checkpoints
2
+
3
+ # legacy download links
4
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
5
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
6
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
7
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
8
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
9
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
10
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
11
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
12
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
13
+ # unzip -n ./checkpoints/hub.zip -d ./checkpoints/
14
+
15
+
16
+ #### download the new links.
17
+ wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
18
+ wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
19
+ wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors
20
+ wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors
21
+
22
+
23
+ # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
24
+ # unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/
25
+
26
+ ### enhancer
27
+ mkdir -p ./gfpgan/weights
28
+ wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth
29
+ wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth
30
+ wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth
31
+ wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth
32
+
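This script is meant to be run from the repository root, e.g. `bash scripts/download_models.sh`; `wget -nc` skips files that already exist, so re-running it is safe. An optional sanity check in Python, a sketch that only mirrors the file names downloaded above:

import os

expected = [
    "checkpoints/mapping_00109-model.pth.tar",
    "checkpoints/mapping_00229-model.pth.tar",
    "checkpoints/SadTalker_V0.0.2_256.safetensors",
    "checkpoints/SadTalker_V0.0.2_512.safetensors",
    "gfpgan/weights/alignment_WFLW_4HG.pth",
    "gfpgan/weights/detection_Resnet50_Final.pth",
    "gfpgan/weights/GFPGANv1.4.pth",
    "gfpgan/weights/parsing_parsenet.pth",
]
missing = [f for f in expected if not os.path.isfile(f)]
print("all model files present" if not missing else "missing: %s" % missing)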
SadTalker/scripts/extension.py ADDED
@@ -0,0 +1,189 @@
1
+ import os, sys
2
+ from pathlib import Path
3
+ import tempfile
4
+ import gradio as gr
5
+ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
6
+ from modules.shared import opts, OptionInfo
7
+ from modules import shared, paths, script_callbacks
8
+ import launch
9
+ import glob
10
+ from huggingface_hub import snapshot_download
11
+
12
+
13
+
14
+ def check_all_files_safetensor(current_dir):
15
+ kv = {
16
+ "SadTalker_V0.0.2_256.safetensors": "sadtalker-256",
17
+ "SadTalker_V0.0.2_512.safetensors": "sadtalker-512",
18
+ "mapping_00109-model.pth.tar" : "mapping-109" ,
19
+ "mapping_00229-model.pth.tar" : "mapping-229" ,
20
+ }
21
+
22
+ if not os.path.isdir(current_dir):
23
+ return False
24
+
25
+ dirs = os.listdir(current_dir)
26
+
27
+ for f in dirs:
28
+ if f in kv.keys():
29
+ del kv[f]
30
+
31
+ return len(kv.keys()) == 0
32
+
33
+ def check_all_files(current_dir):
34
+ kv = {
35
+ "auido2exp_00300-model.pth": "audio2exp",
36
+ "auido2pose_00140-model.pth": "audio2pose",
37
+ "epoch_20.pth": "face_recon",
38
+ "facevid2vid_00189-model.pth.tar": "face-render",
39
+ "mapping_00109-model.pth.tar" : "mapping-109" ,
40
+ "mapping_00229-model.pth.tar" : "mapping-229" ,
41
+ "wav2lip.pth": "wav2lip",
42
+ "shape_predictor_68_face_landmarks.dat": "dlib",
43
+ }
44
+
45
+ if not os.path.isdir(current_dir):
46
+ return False
47
+
48
+ dirs = os.listdir(current_dir)
49
+
50
+ for f in dirs:
51
+ if f in kv.keys():
52
+ del kv[f]
53
+
54
+ return len(kv.keys()) == 0
55
+
56
+
57
+
58
+ def download_model(local_dir='./checkpoints'):
59
+ REPO_ID = 'vinthony/SadTalker'
60
+ snapshot_download(repo_id=REPO_ID, local_dir=local_dir, local_dir_use_symlinks=False)
61
+
62
+ def get_source_image(image):
63
+ return image
64
+
65
+ def get_img_from_txt2img(x):
66
+ talker_path = Path(paths.script_path) / "outputs"
67
+ imgs_from_txt_dir = str(talker_path / "txt2img-images/")
68
+ imgs = glob.glob(imgs_from_txt_dir+'/*/*.png')
69
+ imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_txt_dir, x)))
70
+ img_from_txt_path = os.path.join(imgs_from_txt_dir, imgs[-1])
71
+ return img_from_txt_path, img_from_txt_path
72
+
73
+ def get_img_from_img2img(x):
74
+ talker_path = Path(paths.script_path) / "outputs"
75
+ imgs_from_img_dir = str(talker_path / "img2img-images/")
76
+ imgs = glob.glob(imgs_from_img_dir+'/*/*.png')
77
+ imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_img_dir, x)))
78
+ img_from_img_path = os.path.join(imgs_from_img_dir, imgs[-1])
79
+ return img_from_img_path, img_from_img_path
80
+
81
+ def get_default_checkpoint_path():
82
+ # check the path of models/checkpoints and extensions/
83
+ checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker"
84
+ extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints"
85
+
86
+ if check_all_files_safetensor(checkpoint_path):
87
+ # print('found SadTalker checkpoint in ' + str(checkpoint_path))
88
+ return checkpoint_path
89
+
90
+ if check_all_files_safetensor(extension_checkpoint_path):
91
+ # print('found SadTalker checkpoint in ' + str(extension_checkpoint_path))
92
+ return extension_checkpoint_path
93
+
94
+ if check_all_files(checkpoint_path):
95
+ # print('found SadTalker checkpoint in ' + str(checkpoint_path))
96
+ return checkpoint_path
97
+
98
+ if check_all_files(extension_checkpoint_path):
99
+ # print('found SadTalker checkpoint in ' + str(extension_checkpoint_path))
100
+ return extension_checkpoint_path
101
+
102
+ return None
103
+
104
+
105
+
106
+ def install():
107
+
108
+ kv = {
109
+ "face_alignment": "face-alignment==1.3.5",
110
+ "imageio": "imageio==2.19.3",
111
+ "imageio_ffmpeg": "imageio-ffmpeg==0.4.7",
112
+ "librosa":"librosa==0.8.0",
113
+ "pydub":"pydub==0.25.1",
114
+ "scipy":"scipy==1.8.1",
115
+ "tqdm": "tqdm",
116
+ "yacs":"yacs==0.1.8",
117
+ "yaml": "pyyaml",
118
+ "av":"av",
119
+ "gfpgan": "gfpgan",
120
+ }
121
+
122
+ # # dlib is not necessary currently
123
+ # if 'darwin' in sys.platform:
124
+ # kv['dlib'] = "dlib"
125
+ # else:
126
+ # kv['dlib'] = 'dlib-bin'
127
+
128
+ # #### we need to have a newer version of imageio for our method.
129
+ # launch.run_pip("install imageio==2.19.3", "requirements for SadTalker")
130
+
131
+ for k,v in kv.items():
132
+ if not launch.is_installed(k):
133
+ print(k, launch.is_installed(k))
134
+ launch.run_pip("install "+ v, "requirements for SadTalker")
135
+
136
+ if os.getenv('SADTALKER_CHECKPOINTS'):
137
+ print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS'))
138
+
139
+ elif get_default_checkpoint_path() is not None:
140
+ os.environ['SADTALKER_CHECKPOINTS'] = str(get_default_checkpoint_path())
141
+ else:
142
+
143
+ print(
144
+ """"
145
+ SadTalker will not automatically download all of the model files from Hugging Face, since that would take a long time.
146
+
147
+ Please manually set SADTALKER_CHECKPOINTS in `webui_user.bat` (Windows) or `webui_user.sh` (Linux).
148
+ """
149
+ )
150
+
151
+ # python = sys.executable
152
+
153
+ # launch.run(f'"{python}" -m pip uninstall -y huggingface_hub', live=True)
154
+ # launch.run(f'"{python}" -m pip install --upgrade git+https://github.com/huggingface/huggingface_hub@main', live=True)
155
+ # ### run the scripts to download models to the correct location.
156
+ # # print('download models for SadTalker')
157
+ # # launch.run("cd " + paths.script_path+"/extensions/SadTalker && bash ./scripts/download_models.sh", live=True)
158
+ # # print('SadTalker is successfully installed!')
159
+ # download_model(paths.script_path+'/extensions/SadTalker/checkpoints')
160
+
161
+
162
+ def on_ui_tabs():
163
+ install()
164
+
165
+ sys.path.extend([paths.script_path+'/extensions/SadTalker'])
166
+
167
+ repo_dir = paths.script_path+'/extensions/SadTalker/'
168
+
169
+ result_dir = opts.sadtalker_result_dir
170
+ os.makedirs(result_dir, exist_ok=True)
171
+
172
+ from app_sadtalker import sadtalker_demo
173
+
174
+ if os.getenv('SADTALKER_CHECKPOINTS'):
175
+ checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
176
+ else:
177
+ checkpoint_path = repo_dir+'checkpoints/'
178
+
179
+ audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn = wrap_queued_call)
180
+
181
+ return [(audio_to_video, "SadTalker", "extension")]
182
+
183
+ def on_ui_settings():
184
+ talker_path = Path(paths.script_path) / "outputs"
185
+ section = ('extension', "SadTalker")
186
+ opts.add_option("sadtalker_result_dir", OptionInfo(str(talker_path / "SadTalker/"), "Path to save results of sadtalker", section=section))
187
+
188
+ script_callbacks.on_ui_settings(on_ui_settings)
189
+ script_callbacks.on_ui_tabs(on_ui_tabs)
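Since install() only consults os.getenv('SADTALKER_CHECKPOINTS'), pointing that variable at any folder that passes check_all_files_safetensor() is enough; equivalently, export it in `webui_user.sh` or set it in `webui_user.bat` as the message above suggests. A minimal sketch with a placeholder path:

import os

# hypothetical location; any directory containing the safetensors/mapping files works
os.environ["SADTALKER_CHECKPOINTS"] = "/path/to/stable-diffusion-webui/models/SadTalker"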
SadTalker/scripts/test.sh ADDED
@@ -0,0 +1,21 @@
1
+ # ### some test command before commit.
2
+ # python inference.py --preprocess crop --size 256
3
+ # python inference.py --preprocess crop --size 512
4
+
5
+ # python inference.py --preprocess extcrop --size 256
6
+ # python inference.py --preprocess extcrop --size 512
7
+
8
+ # python inference.py --preprocess resize --size 256
9
+ # python inference.py --preprocess resize --size 512
10
+
11
+ # python inference.py --preprocess full --size 256
12
+ # python inference.py --preprocess full --size 512
13
+
14
+ # python inference.py --preprocess extfull --size 256
15
+ # python inference.py --preprocess extfull --size 512
16
+
17
+ python inference.py --preprocess full --size 256 --enhancer gfpgan
18
+ python inference.py --preprocess full --size 512 --enhancer gfpgan
19
+
20
+ python inference.py --preprocess full --size 256 --enhancer gfpgan --still
21
+ python inference.py --preprocess full --size 512 --enhancer gfpgan --still
SadTalker/src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,41 @@
1
+ from tqdm import tqdm
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class Audio2Exp(nn.Module):
7
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
+ super(Audio2Exp, self).__init__()
9
+ self.cfg = cfg
10
+ self.device = device
11
+ self.netG = netG.to(device)
12
+
13
+ def test(self, batch):
14
+
15
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
16
+ bs = mel_input.shape[0]
17
+ T = mel_input.shape[1]
18
+
19
+ exp_coeff_pred = []
20
+
21
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
+
23
+ current_mel_input = mel_input[:,i:i+10]
24
+
25
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
+ ref = batch['ref'][:, :, :64][:, i:i+10]
27
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
+
29
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
+
31
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
+
33
+ exp_coeff_pred += [curr_exp_coeff_pred]
34
+
35
+ # BS x T x 64
36
+ results_dict = {
37
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
+ }
39
+ return results_dict
40
+
41
+
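A shape-only sketch of the batch Audio2Exp.test consumes, inferred from the slicing above. In the real pipeline the batch comes from get_data() (see predict.py); the tensors below are random placeholders, and the 70-dim 'ref' layout (64 expression + 6 pose coefficients) is an assumption consistent with the rest of the code:

import torch

bs, T = 1, 200
batch = {
    "indiv_mels": torch.randn(bs, T, 1, 80, 16),  # per-frame mel chunks
    "ref":        torch.randn(bs, T, 70),         # only the first 64 dims are read here
    "ratio_gt":   torch.rand(bs, T, 1),           # per-frame eye-blink ratio
}
# model = Audio2Exp(SimpleWrapperV2(), cfg, device="cpu")  # SimpleWrapperV2 is defined in networks.py below
# model.test(batch)["exp_coeff_pred"].shape                # -> torch.Size([1, 200, 64])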
SadTalker/src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+ self.use_act = use_act
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+
21
+ if self.use_act:
22
+ return self.act(out)
23
+ else:
24
+ return out
25
+
26
+ class SimpleWrapperV2(nn.Module):
27
+ def __init__(self) -> None:
28
+ super().__init__()
29
+ self.audio_encoder = nn.Sequential(
30
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
+
42
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
+
45
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
+ )
48
+
49
+ #### load the pre-trained audio_encoder
50
+ #self.audio_encoder = self.audio_encoder.to(device)
51
+ '''
52
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
+ state_dict = self.audio_encoder.state_dict()
54
+
55
+ for k,v in wav2lip_state_dict.items():
56
+ if 'audio_encoder' in k:
57
+ print('init:', k)
58
+ state_dict[k.replace('module.audio_encoder.', '')] = v
59
+ self.audio_encoder.load_state_dict(state_dict)
60
+ '''
61
+
62
+ self.mapping1 = nn.Linear(512+64+1, 64)
63
+ #self.mapping2 = nn.Linear(30, 64)
64
+ #nn.init.constant_(self.mapping1.weight, 0.)
65
+ nn.init.constant_(self.mapping1.bias, 0.)
66
+
67
+ def forward(self, x, ref, ratio):
68
+ x = self.audio_encoder(x).view(x.size(0), -1)
69
+ ref_reshape = ref.reshape(x.size(0), -1)
70
+ ratio = ratio.reshape(x.size(0), -1)
71
+
72
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
74
+ return out
SadTalker/src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from torch import nn
3
+ from src.audio2pose_models.cvae import CVAE
4
+ from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
+ from src.audio2pose_models.audio_encoder import AudioEncoder
6
+
7
+ class Audio2Pose(nn.Module):
8
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
+ super().__init__()
10
+ self.cfg = cfg
11
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
+ self.device = device
14
+
15
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
+ self.audio_encoder.eval()
17
+ for param in self.audio_encoder.parameters():
18
+ param.requires_grad = False
19
+
20
+ self.netG = CVAE(cfg)
21
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
22
+
23
+
24
+ def forward(self, x):
25
+
26
+ batch = {}
27
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
31
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
+
33
+ # forward
34
+ audio_emb_list = []
35
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
+ batch['audio_emb'] = audio_emb
37
+ batch = self.netG(batch)
38
+
39
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
+
43
+ batch['pose_pred'] = pose_pred
44
+ batch['pose_gt'] = pose_gt
45
+
46
+ return batch
47
+
48
+ def test(self, x):
49
+
50
+ batch = {}
51
+ ref = x['ref'] #bs 1 70
52
+ batch['ref'] = x['ref'][:,0,-6:]
53
+ batch['class'] = x['class']
54
+ bs = ref.shape[0]
55
+
56
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
+ num_frames = x['num_frames']
59
+ num_frames = int(num_frames) - 1
60
+
61
+ #
62
+ div = num_frames//self.seq_len
63
+ re = num_frames%self.seq_len
64
+ audio_emb_list = []
65
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
+ device=batch['ref'].device)]
67
+
68
+ for i in range(div):
69
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
70
+ batch['z'] = z
71
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
+ batch['audio_emb'] = audio_emb
73
+ batch = self.netG.test(batch)
74
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
+
76
+ if re != 0:
77
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
78
+ batch['z'] = z
79
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
+ if audio_emb.shape[1] != self.seq_len:
81
+ pad_dim = self.seq_len-audio_emb.shape[1]
82
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
+ batch['audio_emb'] = audio_emb
85
+ batch = self.netG.test(batch)
86
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
+
88
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
+ batch['pose_motion_pred'] = pose_motion_pred
90
+
91
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
+
93
+ batch['pose_pred'] = pose_pred
94
+ return batch
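The same kind of shape-only sketch for Audio2Pose.test, with keys and shapes read off the indexing above; again the real batch is assembled by the data pipeline, so everything below is a placeholder:

import torch

bs, num_frames = 1, 200
batch = {
    "ref":        torch.randn(bs, 1, 70),             # last 6 dims = first-frame pose
    "class":      torch.zeros(bs, dtype=torch.long),  # pose style id in [0, cfg.DATASET.NUM_CLASSES)
    "indiv_mels": torch.randn(bs, num_frames, 1, 80, 16),
    "num_frames": num_frames,
}
# audio2pose = Audio2Pose(cfg_pose, wav2lip_checkpoint=None, device="cpu")
# audio2pose.test(batch)["pose_pred"].shape           # -> (bs, num_frames, 6)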
SadTalker/src/audio2pose_models/audio_encoder.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class AudioEncoder(nn.Module):
22
+ def __init__(self, wav2lip_checkpoint, device):
23
+ super(AudioEncoder, self).__init__()
24
+
25
+ self.audio_encoder = nn.Sequential(
26
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
+
30
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
+
41
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
+
44
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
+ # state_dict = self.audio_encoder.state_dict()
47
+
48
+ # for k,v in wav2lip_state_dict.items():
49
+ # if 'audio_encoder' in k:
50
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
51
+ # self.audio_encoder.load_state_dict(state_dict)
52
+
53
+
54
+ def forward(self, audio_sequences):
55
+ # audio_sequences = (B, T, 1, 80, 16)
56
+ B = audio_sequences.size(0)
57
+
58
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
+
60
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
+ dim = audio_embedding.shape[1]
62
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
+
64
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
SadTalker/src/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from src.audio2pose_models.res_unet import ResUnet
5
+
6
+ def class2onehot(idx, class_num):
7
+
8
+ assert torch.max(idx).item() < class_num
9
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
+ onehot.scatter_(1, idx, 1)
11
+ return onehot
12
+
13
+ class CVAE(nn.Module):
14
+ def __init__(self, cfg):
15
+ super().__init__()
16
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
+ num_classes = cfg.DATASET.NUM_CLASSES
20
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
+
24
+ self.latent_size = latent_size
25
+
26
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
+ audio_emb_in_size, audio_emb_out_size, seq_len)
28
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
+ audio_emb_in_size, audio_emb_out_size, seq_len)
30
+ def reparameterize(self, mu, logvar):
31
+ std = torch.exp(0.5 * logvar)
32
+ eps = torch.randn_like(std)
33
+ return mu + eps * std
34
+
35
+ def forward(self, batch):
36
+ batch = self.encoder(batch)
37
+ mu = batch['mu']
38
+ logvar = batch['logvar']
39
+ z = self.reparameterize(mu, logvar)
40
+ batch['z'] = z
41
+ return self.decoder(batch)
42
+
43
+ def test(self, batch):
44
+ '''
45
+ class_id = batch['class']
46
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
+ batch['z'] = z
48
+ '''
49
+ return self.decoder(batch)
50
+
51
+ class ENCODER(nn.Module):
52
+ def __init__(self, layer_sizes, latent_size, num_classes,
53
+ audio_emb_in_size, audio_emb_out_size, seq_len):
54
+ super().__init__()
55
+
56
+ self.resunet = ResUnet()
57
+ self.num_classes = num_classes
58
+ self.seq_len = seq_len
59
+
60
+ self.MLP = nn.Sequential()
61
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
+ self.MLP.add_module(
64
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
+
67
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
+
71
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
+
73
+ def forward(self, batch):
74
+ class_id = batch['class']
75
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
+ ref = batch['ref'] #bs 6
77
+ bs = pose_motion_gt.shape[0]
78
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
+
80
+ #pose encode
81
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
+
84
+ #audio mapping
85
+ # print(audio_in.shape)  # debug shape check
86
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
+ audio_out = audio_out.reshape(bs, -1)
88
+
89
+ class_bias = self.classbias[class_id] #bs latent_size
90
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
+ x_out = self.MLP(x_in)
92
+
93
+ mu = self.linear_means(x_out)
94
+ logvar = self.linear_logvar(x_out) #bs latent_size
95
+
96
+ batch.update({'mu':mu, 'logvar':logvar})
97
+ return batch
98
+
99
+ class DECODER(nn.Module):
100
+ def __init__(self, layer_sizes, latent_size, num_classes,
101
+ audio_emb_in_size, audio_emb_out_size, seq_len):
102
+ super().__init__()
103
+
104
+ self.resunet = ResUnet()
105
+ self.num_classes = num_classes
106
+ self.seq_len = seq_len
107
+
108
+ self.MLP = nn.Sequential()
109
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
110
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
+ self.MLP.add_module(
112
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
+ if i+1 < len(layer_sizes):
114
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
+ else:
116
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
+
118
+ self.pose_linear = nn.Linear(6, 6)
119
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
+
121
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
+
123
+ def forward(self, batch):
124
+
125
+ z = batch['z'] #bs latent_size
126
+ bs = z.shape[0]
127
+ class_id = batch['class']
128
+ ref = batch['ref'] #bs 6
129
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
+ #print('audio_in: ', audio_in[:, :, :10])
131
+
132
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
+ #print('audio_out: ', audio_out[:, :, :10])
134
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
+ class_bias = self.classbias[class_id] #bs latent_size
136
+
137
+ z = z + class_bias
138
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
139
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
+ x_out = x_out.reshape((bs, self.seq_len, -1))
141
+
142
+ #print('x_out: ', x_out)
143
+
144
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
+
146
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
+
148
+ batch.update({'pose_motion_pred':pose_motion_pred})
149
+ return batch
SadTalker/src/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class ConvNormRelu(nn.Module):
6
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8
+ super().__init__()
9
+ if kernel_size is None:
10
+ if downsample:
11
+ kernel_size, stride, padding = 4, 2, 1
12
+ else:
13
+ kernel_size, stride, padding = 3, 1, 1
14
+
15
+ if conv_type == '2d':
16
+ self.conv = nn.Conv2d(
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size,
20
+ stride,
21
+ padding,
22
+ bias=False,
23
+ )
24
+ if norm == 'BN':
25
+ self.norm = nn.BatchNorm2d(out_channels)
26
+ elif norm == 'IN':
27
+ self.norm = nn.InstanceNorm2d(out_channels)
28
+ else:
29
+ raise NotImplementedError
30
+ elif conv_type == '1d':
31
+ self.conv = nn.Conv1d(
32
+ in_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ stride,
36
+ padding,
37
+ bias=False,
38
+ )
39
+ if norm == 'BN':
40
+ self.norm = nn.BatchNorm1d(out_channels)
41
+ elif norm == 'IN':
42
+ self.norm = nn.InstanceNorm1d(out_channels)
43
+ else:
44
+ raise NotImplementedError
45
+ nn.init.kaiming_normal_(self.conv.weight)
46
+
47
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48
+
49
+ def forward(self, x):
50
+ x = self.conv(x)
51
+ if isinstance(self.norm, nn.InstanceNorm1d):
52
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53
+ else:
54
+ x = self.norm(x)
55
+ x = self.act(x)
56
+ return x
57
+
58
+
59
+ class PoseSequenceDiscriminator(nn.Module):
60
+ def __init__(self, cfg):
61
+ super().__init__()
62
+ self.cfg = cfg
63
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64
+
65
+ self.seq = nn.Sequential(
66
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70
+ )
71
+
72
+ def forward(self, x):
73
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74
+ x = self.seq(x)
75
+ x = x.squeeze(1)
76
+ return x
SadTalker/src/audio2pose_models/networks.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+
5
+ class ResidualConv(nn.Module):
6
+ def __init__(self, input_dim, output_dim, stride, padding):
7
+ super(ResidualConv, self).__init__()
8
+
9
+ self.conv_block = nn.Sequential(
10
+ nn.BatchNorm2d(input_dim),
11
+ nn.ReLU(),
12
+ nn.Conv2d(
13
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
14
+ ),
15
+ nn.BatchNorm2d(output_dim),
16
+ nn.ReLU(),
17
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
18
+ )
19
+ self.conv_skip = nn.Sequential(
20
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
21
+ nn.BatchNorm2d(output_dim),
22
+ )
23
+
24
+ def forward(self, x):
25
+
26
+ return self.conv_block(x) + self.conv_skip(x)
27
+
28
+
29
+ class Upsample(nn.Module):
30
+ def __init__(self, input_dim, output_dim, kernel, stride):
31
+ super(Upsample, self).__init__()
32
+
33
+ self.upsample = nn.ConvTranspose2d(
34
+ input_dim, output_dim, kernel_size=kernel, stride=stride
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.upsample(x)
39
+
40
+
41
+ class Squeeze_Excite_Block(nn.Module):
42
+ def __init__(self, channel, reduction=16):
43
+ super(Squeeze_Excite_Block, self).__init__()
44
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
45
+ self.fc = nn.Sequential(
46
+ nn.Linear(channel, channel // reduction, bias=False),
47
+ nn.ReLU(inplace=True),
48
+ nn.Linear(channel // reduction, channel, bias=False),
49
+ nn.Sigmoid(),
50
+ )
51
+
52
+ def forward(self, x):
53
+ b, c, _, _ = x.size()
54
+ y = self.avg_pool(x).view(b, c)
55
+ y = self.fc(y).view(b, c, 1, 1)
56
+ return x * y.expand_as(x)
57
+
58
+
59
+ class ASPP(nn.Module):
60
+ def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
61
+ super(ASPP, self).__init__()
62
+
63
+ self.aspp_block1 = nn.Sequential(
64
+ nn.Conv2d(
65
+ in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
66
+ ),
67
+ nn.ReLU(inplace=True),
68
+ nn.BatchNorm2d(out_dims),
69
+ )
70
+ self.aspp_block2 = nn.Sequential(
71
+ nn.Conv2d(
72
+ in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
73
+ ),
74
+ nn.ReLU(inplace=True),
75
+ nn.BatchNorm2d(out_dims),
76
+ )
77
+ self.aspp_block3 = nn.Sequential(
78
+ nn.Conv2d(
79
+ in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
80
+ ),
81
+ nn.ReLU(inplace=True),
82
+ nn.BatchNorm2d(out_dims),
83
+ )
84
+
85
+ self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
86
+ self._init_weights()
87
+
88
+ def forward(self, x):
89
+ x1 = self.aspp_block1(x)
90
+ x2 = self.aspp_block2(x)
91
+ x3 = self.aspp_block3(x)
92
+ out = torch.cat([x1, x2, x3], dim=1)
93
+ return self.output(out)
94
+
95
+ def _init_weights(self):
96
+ for m in self.modules():
97
+ if isinstance(m, nn.Conv2d):
98
+ nn.init.kaiming_normal_(m.weight)
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+
104
+ class Upsample_(nn.Module):
105
+ def __init__(self, scale=2):
106
+ super(Upsample_, self).__init__()
107
+
108
+ self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
109
+
110
+ def forward(self, x):
111
+ return self.upsample(x)
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+ def __init__(self, input_encoder, input_decoder, output_dim):
116
+ super(AttentionBlock, self).__init__()
117
+
118
+ self.conv_encoder = nn.Sequential(
119
+ nn.BatchNorm2d(input_encoder),
120
+ nn.ReLU(),
121
+ nn.Conv2d(input_encoder, output_dim, 3, padding=1),
122
+ nn.MaxPool2d(2, 2),
123
+ )
124
+
125
+ self.conv_decoder = nn.Sequential(
126
+ nn.BatchNorm2d(input_decoder),
127
+ nn.ReLU(),
128
+ nn.Conv2d(input_decoder, output_dim, 3, padding=1),
129
+ )
130
+
131
+ self.conv_attn = nn.Sequential(
132
+ nn.BatchNorm2d(output_dim),
133
+ nn.ReLU(),
134
+ nn.Conv2d(output_dim, 1, 1),
135
+ )
136
+
137
+ def forward(self, x1, x2):
138
+ out = self.conv_encoder(x1) + self.conv_decoder(x2)
139
+ out = self.conv_attn(out)
140
+ return out * x2
SadTalker/src/audio2pose_models/res_unet.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.audio2pose_models.networks import ResidualConv, Upsample
4
+
5
+
6
+ class ResUnet(nn.Module):
7
+ def __init__(self, channel=1, filters=[32, 64, 128, 256]):
8
+ super(ResUnet, self).__init__()
9
+
10
+ self.input_layer = nn.Sequential(
11
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
12
+ nn.BatchNorm2d(filters[0]),
13
+ nn.ReLU(),
14
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15
+ )
16
+ self.input_skip = nn.Sequential(
17
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
18
+ )
19
+
20
+ self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
21
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
22
+
23
+ self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
24
+
25
+ self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
26
+ self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
27
+
28
+ self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
29
+ self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
30
+
31
+ self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
32
+ self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
33
+
34
+ self.output_layer = nn.Sequential(
35
+ nn.Conv2d(filters[0], 1, 1, 1),
36
+ nn.Sigmoid(),
37
+ )
38
+
39
+ def forward(self, x):
40
+ # Encode
41
+ x1 = self.input_layer(x) + self.input_skip(x)
42
+ x2 = self.residual_conv_1(x1)
43
+ x3 = self.residual_conv_2(x2)
44
+ # Bridge
45
+ x4 = self.bridge(x3)
46
+
47
+ # Decode
48
+ x4 = self.upsample_1(x4)
49
+ x5 = torch.cat([x4, x3], dim=1)
50
+
51
+ x6 = self.up_residual_conv1(x5)
52
+
53
+ x6 = self.upsample_2(x6)
54
+ x7 = torch.cat([x6, x2], dim=1)
55
+
56
+ x8 = self.up_residual_conv2(x7)
57
+
58
+ x8 = self.upsample_3(x8)
59
+ x9 = torch.cat([x8, x1], dim=1)
60
+
61
+ x10 = self.up_residual_conv3(x9)
62
+
63
+ output = self.output_layer(x10)
64
+
65
+ return output
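A quick shape check of the pose-motion U-Net above. The time axis is halved by each of the three stride-(2,1) stages and restored by the transposed convolutions, so the sequence length must be divisible by 8 for the skip concatenations to line up; SEQ_LEN = 32 in the configs below satisfies this:

import torch
# assumes this file is importable as src.audio2pose_models.res_unet

net = ResUnet(channel=1)
x = torch.randn(2, 1, 32, 6)   # (bs, 1, SEQ_LEN, 6 pose dims)
print(net(x).shape)            # torch.Size([2, 1, 32, 6])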
SadTalker/src/config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
+ TRAIN_BATCH_SIZE: 32
5
+ EVAL_BATCH_SIZE: 32
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
+ DEBUG: True
15
+ NUM_REPEATS: 2
16
+ T: 40
17
+
18
+
19
+ MODEL:
20
+ FRAMEWORK: V2
21
+ AUDIOENCODER:
22
+ LEAKY_RELU: True
23
+ NORM: 'IN'
24
+ DISCRIMINATOR:
25
+ LEAKY_RELU: False
26
+ INPUT_CHANNELS: 6
27
+ CVAE:
28
+ AUDIO_EMB_IN_SIZE: 512
29
+ AUDIO_EMB_OUT_SIZE: 128
30
+ SEQ_LEN: 32
31
+ LATENT_SIZE: 256
32
+ ENCODER_LAYER_SIZES: [192, 1024]
33
+ DECODER_LAYER_SIZES: [1024, 192]
34
+
35
+
36
+ TRAIN:
37
+ MAX_EPOCH: 300
38
+ GENERATOR:
39
+ LR: 2.0e-5
40
+ DISCRIMINATOR:
41
+ LR: 1.0e-5
42
+ LOSS:
43
+ W_FEAT: 0
44
+ W_COEFF_EXP: 2
45
+ W_LM: 1.0e-2
46
+ W_LM_MOUTH: 0
47
+ W_REG: 0
48
+ W_SYNC: 0
49
+ W_COLOR: 0
50
+ W_EXPRESSION: 0
51
+ W_LIPREADING: 0.01
52
+ W_LIPREADING_VV: 0
53
+ W_EYE_BLINK: 4
54
+
55
+ TAG:
56
+ NAME: small_dataset
57
+
58
+
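These YAML files are plain yacs configs; one way to load them that matches how the models above read cfg.MODEL.CVAE.* (the relative path assumes the repository root as the working directory):

from yacs.config import CfgNode as CN

with open("src/config/auido2exp.yaml") as f:
    cfg_exp = CN.load_cfg(f)

print(cfg_exp.MODEL.CVAE.SEQ_LEN)      # 32
print(cfg_exp.TRAIN.LOSS.W_EYE_BLINK)  # 4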
SadTalker/src/config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
+ TRAIN_BATCH_SIZE: 64
5
+ EVAL_BATCH_SIZE: 1
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
+ DEBUG: True
14
+
15
+
16
+ MODEL:
17
+ AUDIOENCODER:
18
+ LEAKY_RELU: True
19
+ NORM: 'IN'
20
+ DISCRIMINATOR:
21
+ LEAKY_RELU: False
22
+ INPUT_CHANNELS: 6
23
+ CVAE:
24
+ AUDIO_EMB_IN_SIZE: 512
25
+ AUDIO_EMB_OUT_SIZE: 6
26
+ SEQ_LEN: 32
27
+ LATENT_SIZE: 64
28
+ ENCODER_LAYER_SIZES: [192, 128]
29
+ DECODER_LAYER_SIZES: [128, 192]
30
+
31
+
32
+ TRAIN:
33
+ MAX_EPOCH: 150
34
+ GENERATOR:
35
+ LR: 1.0e-4
36
+ DISCRIMINATOR:
37
+ LR: 1.0e-4
38
+ LOSS:
39
+ LAMBDA_REG: 1
40
+ LAMBDA_LANDMARKS: 0
41
+ LAMBDA_VERTICES: 0
42
+ LAMBDA_GAN_MOTION: 0.7
43
+ LAMBDA_GAN_COEFF: 0
44
+ LAMBDA_KL: 1
45
+
46
+ TAG:
47
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
+
49
+
SadTalker/src/config/facerender.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 70
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
SadTalker/src/config/facerender_still.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 73
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
SadTalker/src/config/similarity_Lm3D_all.mat ADDED
Binary file (994 Bytes).
SadTalker/src/face3d/data/__init__.py ADDED
@@ -0,0 +1,116 @@
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import numpy as np
14
+ import importlib
15
+ import torch.utils.data
16
+ from face3d.data.base_dataset import BaseDataset
17
+
18
+
19
+ def find_dataset_using_name(dataset_name):
20
+ """Import the module "data/[dataset_name]_dataset.py".
21
+
22
+ In the file, the class called DatasetNameDataset() will
23
+ be instantiated. It has to be a subclass of BaseDataset,
24
+ and it is case-insensitive.
25
+ """
26
+ dataset_filename = "data." + dataset_name + "_dataset"
27
+ datasetlib = importlib.import_module(dataset_filename)
28
+
29
+ dataset = None
30
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31
+ for name, cls in datasetlib.__dict__.items():
32
+ if name.lower() == target_dataset_name.lower() \
33
+ and issubclass(cls, BaseDataset):
34
+ dataset = cls
35
+
36
+ if dataset is None:
37
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38
+
39
+ return dataset
40
+
41
+
42
+ def get_option_setter(dataset_name):
43
+ """Return the static method <modify_commandline_options> of the dataset class."""
44
+ dataset_class = find_dataset_using_name(dataset_name)
45
+ return dataset_class.modify_commandline_options
46
+
47
+
48
+ def create_dataset(opt, rank=0):
49
+ """Create a dataset given the option.
50
+
51
+ This function wraps the class CustomDatasetDataLoader.
52
+ This is the main interface between this package and 'train.py'/'test.py'
53
+
54
+ Example:
55
+ >>> from data import create_dataset
56
+ >>> dataset = create_dataset(opt)
57
+ """
58
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
59
+ dataset = data_loader.load_data()
60
+ return dataset
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt, rank=0):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ self.sampler = None
75
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76
+ if opt.use_ddp and opt.isTrain:
77
+ world_size = opt.world_size
78
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
79
+ self.dataset,
80
+ num_replicas=world_size,
81
+ rank=rank,
82
+ shuffle=not opt.serial_batches
83
+ )
84
+ self.dataloader = torch.utils.data.DataLoader(
85
+ self.dataset,
86
+ sampler=self.sampler,
87
+ num_workers=int(opt.num_threads / world_size),
88
+ batch_size=int(opt.batch_size / world_size),
89
+ drop_last=True)
90
+ else:
91
+ self.dataloader = torch.utils.data.DataLoader(
92
+ self.dataset,
93
+ batch_size=opt.batch_size,
94
+ shuffle=(not opt.serial_batches) and opt.isTrain,
95
+ num_workers=int(opt.num_threads),
96
+ drop_last=True
97
+ )
98
+
99
+ def set_epoch(self, epoch):
100
+ self.dataset.current_epoch = epoch
101
+ if self.sampler is not None:
102
+ self.sampler.set_epoch(epoch)
103
+
104
+ def load_data(self):
105
+ return self
106
+
107
+ def __len__(self):
108
+ """Return the number of data in the dataset"""
109
+ return min(len(self.dataset), self.opt.max_dataset_size)
110
+
111
+ def __iter__(self):
112
+ """Return a batch of data"""
113
+ for i, data in enumerate(self.dataloader):
114
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
115
+ break
116
+ yield data
SadTalker/src/face3d/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ # self.root = opt.dataroot
31
+ self.current_epoch = 0
32
+
33
+ @staticmethod
34
+ def modify_commandline_options(parser, is_train):
35
+ """Add new dataset-specific options, and rewrite default values for existing options.
36
+
37
+ Parameters:
38
+ parser -- original option parser
39
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
+
41
+ Returns:
42
+ the modified parser.
43
+ """
44
+ return parser
45
+
46
+ @abstractmethod
47
+ def __len__(self):
48
+ """Return the total number of images in the dataset."""
49
+ return 0
50
+
51
+ @abstractmethod
52
+ def __getitem__(self, index):
53
+ """Return a data point and its metadata information.
54
+
55
+ Parameters:
56
+ index - - a random integer for data indexing
57
+
58
+ Returns:
59
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
60
+ """
61
+ pass
62
+
63
+
64
+ def get_transform(grayscale=False):
65
+ transform_list = []
66
+ if grayscale:
67
+ transform_list.append(transforms.Grayscale(1))
68
+ transform_list += [transforms.ToTensor()]
69
+ return transforms.Compose(transform_list)
70
+
71
+ def get_affine_mat(opt, size):
72
+ shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73
+ w, h = size
74
+
75
+ if 'shift' in opt.preprocess:
76
+ shift_pixs = int(opt.shift_pixs)
77
+ shift_x = random.randint(-shift_pixs, shift_pixs)
78
+ shift_y = random.randint(-shift_pixs, shift_pixs)
79
+ if 'scale' in opt.preprocess:
80
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81
+ if 'rot' in opt.preprocess:
82
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
83
+ rot_rad = -rot_angle * np.pi/180
84
+ if 'flip' in opt.preprocess:
85
+ flip = random.random() > 0.5
86
+
87
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93
+
94
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95
+ affine_inv = np.linalg.inv(affine)
96
+ return affine, affine_inv, flip
97
+
98
+ def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
100
+
101
+ def apply_lm_affine(landmark, affine, flip, size):
102
+ _, h = size
103
+ lm = landmark.copy()
104
+ lm[:, 1] = h - 1 - lm[:, 1]
105
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106
+ lm = lm @ np.transpose(affine)
107
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
108
+ lm = lm[:, :2]
109
+ lm[:, 1] = h - 1 - lm[:, 1]
110
+ if flip:
111
+ lm_ = lm.copy()
112
+ lm_[:17] = lm[16::-1]
113
+ lm_[17:22] = lm[26:21:-1]
114
+ lm_[22:27] = lm[21:16:-1]
115
+ lm_[31:36] = lm[35:30:-1]
116
+ lm_[36:40] = lm[45:41:-1]
117
+ lm_[40:42] = lm[47:45:-1]
118
+ lm_[42:46] = lm[39:35:-1]
119
+ lm_[46:48] = lm[41:39:-1]
120
+ lm_[48:55] = lm[54:47:-1]
121
+ lm_[55:60] = lm[59:54:-1]
122
+ lm_[60:65] = lm[64:59:-1]
123
+ lm_[65:68] = lm[67:64:-1]
124
+ lm = lm_
125
+ return lm
SadTalker/src/face3d/data/flist_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import os.path
5
+ from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
+ from data.image_folder import make_dataset
7
+ from PIL import Image
8
+ import random
9
+ import util.util as util
10
+ import numpy as np
11
+ import json
12
+ import torch
13
+ from scipy.io import loadmat, savemat
14
+ import pickle
15
+ from util.preprocess import align_img, estimate_norm
16
+ from util.load_mats import load_lm3d
17
+
18
+
19
+ def default_flist_reader(flist):
20
+ """
21
+ flist format: impath label\nimpath label\n ...(same as caffe's filelist)
22
+ """
23
+ imlist = []
24
+ with open(flist, 'r') as rf:
25
+ for line in rf.readlines():
26
+ impath = line.strip()
27
+ imlist.append(impath)
28
+
29
+ return imlist
30
+
31
+ def jason_flist_reader(flist):
32
+ with open(flist, 'r') as fp:
33
+ info = json.load(fp)
34
+ return info
35
+
36
+ def parse_label(label):
37
+ return torch.tensor(np.array(label).astype(np.float32))
38
+
39
+
40
+ class FlistDataset(BaseDataset):
41
+ """
42
+ It requires one directory to host training images, e.g. '/path/to/data/train'.
43
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
44
+ """
45
+
46
+ def __init__(self, opt):
47
+ """Initialize this dataset class.
48
+
49
+ Parameters:
50
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
+ """
52
+ BaseDataset.__init__(self, opt)
53
+
54
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
55
+
56
+ msk_names = default_flist_reader(opt.flist)
57
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
+
59
+ self.size = len(self.msk_paths)
60
+ self.opt = opt
61
+
62
+ self.name = 'train' if opt.isTrain else 'val'
63
+ if '_' in opt.flist:
64
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
+
66
+
67
+ def __getitem__(self, index):
68
+ """Return a data point and its metadata information.
69
+
70
+ Parameters:
71
+ index (int) -- a random integer for data indexing
72
+
73
+ Returns a dictionary that contains A, B, A_paths and B_paths
74
+ img (tensor) -- an image in the input domain
75
+ msk (tensor) -- its corresponding attention mask
76
+ lm (tensor) -- its corresponding 3d landmarks
77
+ im_paths (str) -- image paths
78
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
+ """
80
+ msk_path = self.msk_paths[index % self.size] # make sure index is within the range
81
+ img_path = msk_path.replace('mask/', '')
82
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
+
84
+ raw_img = Image.open(img_path).convert('RGB')
85
+ raw_msk = Image.open(msk_path).convert('RGB')
86
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
+
88
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
+
90
+ aug_flag = self.opt.use_aug and self.opt.isTrain
91
+ if aug_flag:
92
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
+
94
+ _, H = img.size
95
+ M = estimate_norm(lm, H)
96
+ transform = get_transform()
97
+ img_tensor = transform(img)
98
+ msk_tensor = transform(msk)[:1, ...]
99
+ lm_tensor = parse_label(lm)
100
+ M_tensor = parse_label(M)
101
+
102
+
103
+ return {'imgs': img_tensor,
104
+ 'lms': lm_tensor,
105
+ 'msks': msk_tensor,
106
+ 'M': M_tensor,
107
+ 'im_paths': img_path,
108
+ 'aug_flag': aug_flag,
109
+ 'dataset': self.name}
110
+
111
+ def _augmentation(self, img, lm, opt, msk=None):
112
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
+ img = apply_img_affine(img, affine_inv)
114
+ lm = apply_lm_affine(lm, affine, flip, img.size)
115
+ if msk is not None:
116
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
+ return img, lm, msk
118
+
119
+
120
+
121
+
122
+ def __len__(self):
123
+ """Return the total number of images in the dataset.
124
+ """
125
+ return self.size
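A minimal sketch of the file-list convention that `default_flist_reader` and `FlistDataset` above rely on: one mask path per line, with the image and landmark paths derived from it. The file names and folder layout below are illustrative assumptions, not files shipped with the repository.

```python
# Hypothetical flist example; only the one-path-per-line format and the
# mask/ -> image / landmarks/ path derivation mirror the code above.
flist_text = "mask/person_0001.png\nmask/person_0002.png\n"
with open("example_flist.txt", "w") as f:
    f.write(flist_text)

# Read it back the same way default_flist_reader does.
with open("example_flist.txt", "r") as rf:
    imlist = [line.strip() for line in rf.readlines()]

# FlistDataset then derives the image and landmark paths from each mask path.
for msk_path in imlist:
    img_path = msk_path.replace("mask/", "")
    lm_path = ".".join(msk_path.replace("mask", "landmarks").split(".")[:-1]) + ".txt"
    print(msk_path, img_path, lm_path)
```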
SadTalker/src/face3d/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
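A short, hypothetical usage sketch of the `ImageFolder` class defined above; the directory path and transform are placeholder assumptions, and `ImageFolder` is assumed to be imported from this module.

```python
# Hypothetical usage of the ImageFolder defined above.
# "./some_images" is a placeholder directory; any folder tree of images works.
import torchvision.transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
dataset = ImageFolder("./some_images", transform=transform, return_paths=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, paths in loader:
    # images come from "./some_images" and all of its subdirectories
    print(images.shape, list(paths))
    break
```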
SadTalker/src/face3d/data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset.
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform(opt)
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+ Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
SadTalker/src/face3d/extract_kp_videos.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import face_alignment
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+
12
+ from torch.multiprocessing import Pool, Process, set_start_method
13
+
14
+ class KeypointExtractor():
15
+ def __init__(self, device):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
+ device=device)
18
+
19
+ def extract_keypoint(self, images, name=None, info=True):
20
+ if isinstance(images, list):
21
+ keypoints = []
22
+ if info:
23
+ i_range = tqdm(images,desc='landmark Det:')
24
+ else:
25
+ i_range = images
26
+
27
+ for image in i_range:
28
+ current_kp = self.extract_keypoint(image)
29
+ if np.mean(current_kp) == -1 and keypoints:
30
+ keypoints.append(keypoints[-1])
31
+ else:
32
+ keypoints.append(current_kp[None])
33
+
34
+ keypoints = np.concatenate(keypoints, 0)
35
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
+ return keypoints
37
+ else:
38
+ while True:
39
+ try:
40
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
+ break
42
+ except RuntimeError as e:
43
+ if str(e).startswith('CUDA'):
44
+ print("Warning: out of memory, sleep for 1s")
45
+ time.sleep(1)
46
+ else:
47
+ print(e)
48
+ break
49
+ except TypeError:
50
+ print('No face detected in this image')
51
+ shape = [68, 2]
52
+ keypoints = -1. * np.ones(shape)
53
+ break
54
+ if name is not None:
55
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
+ return keypoints
57
+
58
+ def read_video(filename):
59
+ frames = []
60
+ cap = cv2.VideoCapture(filename)
61
+ while cap.isOpened():
62
+ ret, frame = cap.read()
63
+ if ret:
64
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ frame = Image.fromarray(frame)
66
+ frames.append(frame)
67
+ else:
68
+ break
69
+ cap.release()
70
+ return frames
71
+
72
+ def run(data):
73
+ filename, opt, device = data
74
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
75
+ kp_extractor = KeypointExtractor(device='cuda')  # __init__ requires a device; CUDA_VISIBLE_DEVICES above already selects the GPU
76
+ images = read_video(filename)
77
+ name = filename.split('/')[-2:]
78
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
+ kp_extractor.extract_keypoint(
80
+ images,
81
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
82
+ )
83
+
84
+ if __name__ == '__main__':
85
+ set_start_method('spawn')
86
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
+ parser.add_argument('--device_ids', type=str, default='0,1')
90
+ parser.add_argument('--workers', type=int, default=4)
91
+
92
+ opt = parser.parse_args()
93
+ filenames = list()
94
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
+ extensions = VIDEO_EXTENSIONS
97
+
98
+ for ext in extensions:
99
+ filenames += sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))  # collect files for every extension
102
+ print('Total number of videos:', len(filenames))
103
+ pool = Pool(opt.workers)
104
+ args_list = cycle([opt])
105
+ device_ids = opt.device_ids.split(",")
106
+ device_ids = cycle(device_ids)
107
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
+ pass
SadTalker/src/face3d/extract_kp_videos_safe.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+ from torch.multiprocessing import Pool, Process, set_start_method
12
+
13
+ from facexlib.alignment import landmark_98_to_68
14
+ from facexlib.detection import init_detection_model
15
+
16
+ from facexlib.utils import load_file_from_url
17
+ from src.face3d.util.my_awing_arch import FAN
18
+
19
+ def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
20
+ if model_name == 'awing_fan':
21
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
22
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
23
+ else:
24
+ raise NotImplementedError(f'{model_name} is not implemented.')
25
+
26
+ model_path = load_file_from_url(
27
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
28
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
29
+ model.eval()
30
+ model = model.to(device)
31
+ return model
32
+
33
+
34
+ class KeypointExtractor():
35
+ def __init__(self, device='cuda'):
36
+
37
+ ### gfpgan/weights
38
+ try:
39
+ import webui # in webui
40
+ root_path = 'extensions/SadTalker/gfpgan/weights'
41
+
42
+ except ImportError:
43
+ root_path = 'gfpgan/weights'
44
+
45
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
46
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
47
+
48
+ def extract_keypoint(self, images, name=None, info=True):
49
+ if isinstance(images, list):
50
+ keypoints = []
51
+ if info:
52
+ i_range = tqdm(images,desc='landmark Det:')
53
+ else:
54
+ i_range = images
55
+
56
+ for image in i_range:
57
+ current_kp = self.extract_keypoint(image)
58
+ # current_kp = self.detector.get_landmarks(np.array(image))
59
+ if np.mean(current_kp) == -1 and keypoints:
60
+ keypoints.append(keypoints[-1])
61
+ else:
62
+ keypoints.append(current_kp[None])
63
+
64
+ keypoints = np.concatenate(keypoints, 0)
65
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
66
+ return keypoints
67
+ else:
68
+ while True:
69
+ try:
70
+ with torch.no_grad():
71
+ # face detection -> face alignment.
72
+ img = np.array(images)
73
+ bboxes = self.det_net.detect_faces(images, 0.97)
74
+
75
+ bboxes = bboxes[0]
76
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
77
+
78
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
79
+
80
+ #### keypoints to the original location
81
+ keypoints[:,0] += int(bboxes[0])
82
+ keypoints[:,1] += int(bboxes[1])
83
+
84
+ break
85
+ except RuntimeError as e:
86
+ if str(e).startswith('CUDA'):
87
+ print("Warning: out of memory, sleep for 1s")
88
+ time.sleep(1)
89
+ else:
90
+ print(e)
91
+ break
92
+ except TypeError:
93
+ print('No face detected in this image')
94
+ shape = [68, 2]
95
+ keypoints = -1. * np.ones(shape)
96
+ break
97
+ if name is not None:
98
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
99
+ return keypoints
100
+
101
+ def read_video(filename):
102
+ frames = []
103
+ cap = cv2.VideoCapture(filename)
104
+ while cap.isOpened():
105
+ ret, frame = cap.read()
106
+ if ret:
107
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
+ frame = Image.fromarray(frame)
109
+ frames.append(frame)
110
+ else:
111
+ break
112
+ cap.release()
113
+ return frames
114
+
115
+ def run(data):
116
+ filename, opt, device = data
117
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
118
+ kp_extractor = KeypointExtractor()
119
+ images = read_video(filename)
120
+ name = filename.split('/')[-2:]
121
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
122
+ kp_extractor.extract_keypoint(
123
+ images,
124
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
125
+ )
126
+
127
+ if __name__ == '__main__':
128
+ set_start_method('spawn')
129
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
130
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
131
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
132
+ parser.add_argument('--device_ids', type=str, default='0,1')
133
+ parser.add_argument('--workers', type=int, default=4)
134
+
135
+ opt = parser.parse_args()
136
+ filenames = list()
137
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
138
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
139
+ extensions = VIDEO_EXTENSIONS
140
+
141
+ for ext in extensions:
142
+ filenames += sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))  # collect files for every extension
145
+ print('Total number of videos:', len(filenames))
146
+ pool = Pool(opt.workers)
147
+ args_list = cycle([opt])
148
+ device_ids = opt.device_ids.split(",")
149
+ device_ids = cycle(device_ids)
150
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
151
+ pass
SadTalker/src/face3d/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel that inherits from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from src.face3d.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called <ModelName>Model() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "face3d.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function instantiates the model class specified by opt.model.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
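As a companion to the package docstring above, here is a minimal, hypothetical `dummy_model.py` skeleton showing the five functions and four lists it asks for. The placeholder network, loss, and `--lambda_dummy` option are invented for illustration and are not part of the repository; `BaseModel.__init__` is assumed to store `opt` as in the usual pix2pix-style base class.

```python
# Hypothetical src/face3d/models/dummy_model.py sketch, selected with '--model dummy'.
import torch
from src.face3d.models.base_model import BaseModel


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.add_argument('--lambda_dummy', type=float, default=1.0)  # placeholder option
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['dummy']      # losses to plot and save
        self.model_names = ['net']       # networks used in training
        self.visual_names = ['output']   # images to display and save
        self.net = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # placeholder network
        self.optimizers = [torch.optim.Adam(self.net.parameters(), lr=1e-4)]

    def set_input(self, data):
        self.input = data['imgs']        # unpack data from the dataloader

    def forward(self):
        self.output = self.net(self.input)

    def optimize_parameters(self):
        self.forward()
        self.loss_dummy = self.output.mean() * self.opt.lambda_dummy
        self.optimizers[0].zero_grad()
        self.loss_dummy.backward()
        self.optimizers[0].step()
```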
SadTalker/src/face3d/models/arcface_torch/README.md ADDED
@@ -0,0 +1,164 @@
1
+ # Distributed Arcface Training in Pytorch
2
+
3
+ This is a deep learning library that makes face recognition training efficient and effective, and that can train tens of millions of
4
+ identities on a single server.
5
+
6
+ ## Requirements
7
+
8
+ - Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
9
+ - `pip install -r requirements.txt`.
10
+ - Download the dataset
11
+ from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
13
+
14
+ ## How to Train
15
+
16
+ To train a model, run `train.py` with the path to the configs:
17
+
18
+ ### 1. Single node, 8 GPUs:
19
+
20
+ ```shell
21
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
+ ```
23
+
24
+ ### 2. Multiple nodes, each node 8 GPUs:
25
+
26
+ Node 0:
27
+
28
+ ```shell
29
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
30
+ ```
31
+
32
+ Node 1:
33
+
34
+ ```shell
35
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
36
+ ```
37
+
38
+ ### 3. Training resnet2060 with 8 GPUs:
39
+
40
+ ```shell
41
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
+ ```
43
+
44
+ ## Model Zoo
45
+
46
+ - The models are available for non-commercial research purposes only.
47
+ - All models can be found here:
48
+ - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
+ - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
+
51
+ ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
+
53
+ The ICCV2021-MFR testset consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
+ recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
+ As a result, we can fairly evaluate the performance of different algorithms.
56
+
57
+ For the **ICCV2021-MFR-ALL** set, TAR is measured on the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
58
+ globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
59
+
60
+ For the **ICCV2021-MFR-MASK** set, TAR is measured on the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
61
+ Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
62
+ In total, there are 13,928 positive pairs and 96,983,824 negative pairs.
63
+
64
+ | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
+ | :---: | :--- | :--- | :--- |:--- |:--- |
66
+ | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
+ | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
+ | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
+ | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
+ | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
+ | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
+ | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
+ | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
+ | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
+ | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
+
77
+ ### Performance on IJB-C and Verification Datasets
78
+
79
+ | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
+ | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
+ | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
+ | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
+ | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
+ | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
+ | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
+ | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
+ | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
+ | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
+ | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
+
91
+ [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
+
93
+
94
+ ## [Speed Benchmark](docs/speed_benchmark.md)
95
+
96
+ **Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
+ classes in the training set is greater than 300K and training is sufficient, the partial FC sampling strategy reaches the same
+ accuracy with several times faster training and a smaller GPU memory footprint.
+ Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
+ sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
+ sparse part of the parameters is updated, which greatly reduces GPU memory usage and computation. With Partial FC,
+ we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
+ training and mixed precision training.
104
+
105
+ ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
106
+
107
+ More details see
108
+ [speed_benchmark.md](docs/speed_benchmark.md) in docs.
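The Partial FC description above boils down to one step: in every iteration, compute the softmax over only a sampled subset of class centers (all positives for the batch plus random negatives). Below is a minimal, self-contained sketch of that sampling step with made-up sizes; it leaves out the model-parallel sharding, margin, and distributed softmax of the real implementation.

```python
# Minimal sketch of partial-FC-style class-center sampling (illustrative sizes).
import torch
import torch.nn.functional as F

num_classes, emb_dim, sample_rate = 10_000, 512, 0.1
centers = torch.randn(num_classes, emb_dim)  # full class-center matrix W

def partial_fc_logits(embeddings, labels):
    positive = labels.unique()                            # centers that must be kept
    num_sample = int(sample_rate * num_classes)
    perm = torch.randperm(num_classes)
    negatives = perm[~torch.isin(perm, positive)]         # candidate negative centers
    negatives = negatives[: max(num_sample - positive.numel(), 0)]
    index = torch.cat([positive, negatives])
    sub_centers = centers[index]                          # only this subset is touched
    remap = {c.item(): i for i, c in enumerate(positive)} # labels -> sampled index space
    sub_labels = torch.tensor([remap[l.item()] for l in labels])
    logits = F.linear(F.normalize(embeddings), F.normalize(sub_centers))
    return logits, sub_labels

emb = torch.randn(8, emb_dim)
lbl = torch.randint(0, num_classes, (8,))
logits, sub_lbl = partial_fc_logits(emb, lbl)
loss = F.cross_entropy(logits * 64.0, sub_lbl)  # 64.0: a typical logit scale, no margin here
print(logits.shape, loss.item())
```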
109
+
110
+ ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
+
112
+ `-` means training failed because of GPU memory limitations.
113
+
114
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
+ | :--- | :--- | :--- | :--- |
116
+ |125000 | 4681 | 4824 | 5004 |
117
+ |1400000 | **1672** | 3043 | 4738 |
118
+ |5500000 | **-** | **1389** | 3975 |
119
+ |8000000 | **-** | **-** | 3565 |
120
+ |16000000 | **-** | **-** | 2679 |
121
+ |29000000 | **-** | **-** | **1855** |
122
+
123
+ ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
+
125
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
+ | :--- | :--- | :--- | :--- |
127
+ |125000 | 7358 | 5306 | 4868 |
128
+ |1400000 | 32252 | 11178 | 6056 |
129
+ |5500000 | **-** | 32188 | 9854 |
130
+ |8000000 | **-** | **-** | 12310 |
131
+ |16000000 | **-** | **-** | 19950 |
132
+ |29000000 | **-** | **-** | 32324 |
133
+
134
+ ## Evaluation ICCV2021-MFR and IJB-C
135
+
136
+ More details see [eval.md](docs/eval.md) in docs.
137
+
138
+ ## Test
139
+
140
+ We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
+
142
+ - [x] torch 1.6.0
143
+ - [x] torch 1.7.1
144
+ - [x] torch 1.8.0
145
+ - [x] torch 1.9.0
146
+
147
+ ## Citation
148
+
149
+ ```
150
+ @inproceedings{deng2019arcface,
151
+ title={Arcface: Additive angular margin loss for deep face recognition},
152
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
+ pages={4690--4699},
155
+ year={2019}
156
+ }
157
+ @inproceedings{an2020partical_fc,
158
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
159
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
+ Zhang, Debing and Fu, Ying},
161
+ booktitle={Arxiv 2010.05222},
162
+ year={2020}
163
+ }
164
+ ```
SadTalker/src/face3d/models/arcface_torch/backbones/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
+ from .mobilefacenet import get_mbf
3
+
4
+
5
+ def get_model(name, **kwargs):
6
+ # resnet
7
+ if name == "r18":
8
+ return iresnet18(False, **kwargs)
9
+ elif name == "r34":
10
+ return iresnet34(False, **kwargs)
11
+ elif name == "r50":
12
+ return iresnet50(False, **kwargs)
13
+ elif name == "r100":
14
+ return iresnet100(False, **kwargs)
15
+ elif name == "r200":
16
+ return iresnet200(False, **kwargs)
17
+ elif name == "r2060":
18
+ from .iresnet2060 import iresnet2060
19
+ return iresnet2060(False, **kwargs)
20
+ elif name == "mbf":
21
+ fp16 = kwargs.get("fp16", False)
22
+ num_features = kwargs.get("num_features", 512)
23
+ return get_mbf(fp16=fp16, num_features=num_features)
24
+ else:
25
+ raise ValueError()
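A quick usage sketch of the `get_model` factory above. The 112x112 input is the aligned-face crop these backbones expect (the `fc_scale = 7 * 7` in `iresnet.py` corresponds to it); the batch size and feature dimension are example values, and `get_model` is assumed to have been imported from this package.

```python
# Hypothetical usage of the backbone factory defined above.
import torch

net = get_model("r50", fp16=False, num_features=512)
net.eval()
with torch.no_grad():
    feats = net(torch.randn(1, 3, 112, 112))  # one 112x112 aligned face crop
print(feats.shape)  # expected: torch.Size([1, 512])
```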
SadTalker/src/face3d/models/arcface_torch/backbones/iresnet.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
+ """3x3 convolution with padding"""
9
+ return nn.Conv2d(in_planes,
10
+ out_planes,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=dilation,
14
+ groups=groups,
15
+ bias=False,
16
+ dilation=dilation)
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes,
22
+ out_planes,
23
+ kernel_size=1,
24
+ stride=stride,
25
+ bias=False)
26
+
27
+
28
+ class IBasicBlock(nn.Module):
29
+ expansion = 1
30
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
31
+ groups=1, base_width=64, dilation=1):
32
+ super(IBasicBlock, self).__init__()
33
+ if groups != 1 or base_width != 64:
34
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
+ if dilation > 1:
36
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
+ self.conv1 = conv3x3(inplanes, planes)
39
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
+ self.prelu = nn.PReLU(planes)
41
+ self.conv2 = conv3x3(planes, planes, stride)
42
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ identity = x
48
+ out = self.bn1(x)
49
+ out = self.conv1(out)
50
+ out = self.bn2(out)
51
+ out = self.prelu(out)
52
+ out = self.conv2(out)
53
+ out = self.bn3(out)
54
+ if self.downsample is not None:
55
+ identity = self.downsample(x)
56
+ out += identity
57
+ return out
58
+
59
+
60
+ class IResNet(nn.Module):
61
+ fc_scale = 7 * 7
62
+ def __init__(self,
63
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
+ super(IResNet, self).__init__()
66
+ self.fp16 = fp16
67
+ self.inplanes = 64
68
+ self.dilation = 1
69
+ if replace_stride_with_dilation is None:
70
+ replace_stride_with_dilation = [False, False, False]
71
+ if len(replace_stride_with_dilation) != 3:
72
+ raise ValueError("replace_stride_with_dilation should be None "
73
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
+ self.groups = groups
75
+ self.base_width = width_per_group
76
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
+ self.prelu = nn.PReLU(self.inplanes)
79
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
+ self.layer2 = self._make_layer(block,
81
+ 128,
82
+ layers[1],
83
+ stride=2,
84
+ dilate=replace_stride_with_dilation[0])
85
+ self.layer3 = self._make_layer(block,
86
+ 256,
87
+ layers[2],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[1])
90
+ self.layer4 = self._make_layer(block,
91
+ 512,
92
+ layers[3],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[2])
95
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
97
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
+ nn.init.constant_(self.features.weight, 1.0)
100
+ self.features.weight.requires_grad = False
101
+
102
+ for m in self.modules():
103
+ if isinstance(m, nn.Conv2d):
104
+ nn.init.normal_(m.weight, 0, 0.1)
105
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
+ nn.init.constant_(m.weight, 1)
107
+ nn.init.constant_(m.bias, 0)
108
+
109
+ if zero_init_residual:
110
+ for m in self.modules():
111
+ if isinstance(m, IBasicBlock):
112
+ nn.init.constant_(m.bn2.weight, 0)
113
+
114
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
+ downsample = None
116
+ previous_dilation = self.dilation
117
+ if dilate:
118
+ self.dilation *= stride
119
+ stride = 1
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ conv1x1(self.inplanes, planes * block.expansion, stride),
123
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
+ )
125
+ layers = []
126
+ layers.append(
127
+ block(self.inplanes, planes, stride, downsample, self.groups,
128
+ self.base_width, previous_dilation))
129
+ self.inplanes = planes * block.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(
132
+ block(self.inplanes,
133
+ planes,
134
+ groups=self.groups,
135
+ base_width=self.base_width,
136
+ dilation=self.dilation))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ with torch.cuda.amp.autocast(self.fp16):
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.prelu(x)
145
+ x = self.layer1(x)
146
+ x = self.layer2(x)
147
+ x = self.layer3(x)
148
+ x = self.layer4(x)
149
+ x = self.bn2(x)
150
+ x = torch.flatten(x, 1)
151
+ x = self.dropout(x)
152
+ x = self.fc(x.float() if self.fp16 else x)
153
+ x = self.features(x)
154
+ return x
155
+
156
+
157
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
+ model = IResNet(block, layers, **kwargs)
159
+ if pretrained:
160
+ raise ValueError()
161
+ return model
162
+
163
+
164
+ def iresnet18(pretrained=False, progress=True, **kwargs):
165
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
+ progress, **kwargs)
167
+
168
+
169
+ def iresnet34(pretrained=False, progress=True, **kwargs):
170
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
+ progress, **kwargs)
172
+
173
+
174
+ def iresnet50(pretrained=False, progress=True, **kwargs):
175
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
+ progress, **kwargs)
177
+
178
+
179
+ def iresnet100(pretrained=False, progress=True, **kwargs):
180
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
+ progress, **kwargs)
182
+
183
+
184
+ def iresnet200(pretrained=False, progress=True, **kwargs):
185
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
+ progress, **kwargs)
187
+
SadTalker/src/face3d/models/arcface_torch/backbones/iresnet2060.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ assert torch.__version__ >= "1.8.1"
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ __all__ = ['iresnet2060']
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes,
13
+ out_planes,
14
+ kernel_size=3,
15
+ stride=stride,
16
+ padding=dilation,
17
+ groups=groups,
18
+ bias=False,
19
+ dilation=dilation)
20
+
21
+
22
+ def conv1x1(in_planes, out_planes, stride=1):
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes,
25
+ out_planes,
26
+ kernel_size=1,
27
+ stride=stride,
28
+ bias=False)
29
+
30
+
31
+ class IBasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
35
+ groups=1, base_width=64, dilation=1):
36
+ super(IBasicBlock, self).__init__()
37
+ if groups != 1 or base_width != 64:
38
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
+ if dilation > 1:
40
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
+ self.conv1 = conv3x3(inplanes, planes)
43
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
+ self.prelu = nn.PReLU(planes)
45
+ self.conv2 = conv3x3(planes, planes, stride)
46
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
+ self.downsample = downsample
48
+ self.stride = stride
49
+
50
+ def forward(self, x):
51
+ identity = x
52
+ out = self.bn1(x)
53
+ out = self.conv1(out)
54
+ out = self.bn2(out)
55
+ out = self.prelu(out)
56
+ out = self.conv2(out)
57
+ out = self.bn3(out)
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+ out += identity
61
+ return out
62
+
63
+
64
+ class IResNet(nn.Module):
65
+ fc_scale = 7 * 7
66
+
67
+ def __init__(self,
68
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
+ super(IResNet, self).__init__()
71
+ self.fp16 = fp16
72
+ self.inplanes = 64
73
+ self.dilation = 1
74
+ if replace_stride_with_dilation is None:
75
+ replace_stride_with_dilation = [False, False, False]
76
+ if len(replace_stride_with_dilation) != 3:
77
+ raise ValueError("replace_stride_with_dilation should be None "
78
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
+ self.groups = groups
80
+ self.base_width = width_per_group
81
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
+ self.prelu = nn.PReLU(self.inplanes)
84
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
+ self.layer2 = self._make_layer(block,
86
+ 128,
87
+ layers[1],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[0])
90
+ self.layer3 = self._make_layer(block,
91
+ 256,
92
+ layers[2],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[1])
95
+ self.layer4 = self._make_layer(block,
96
+ 512,
97
+ layers[3],
98
+ stride=2,
99
+ dilate=replace_stride_with_dilation[2])
100
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
102
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
+ nn.init.constant_(self.features.weight, 1.0)
105
+ self.features.weight.requires_grad = False
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.normal_(m.weight, 0, 0.1)
110
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
+ nn.init.constant_(m.weight, 1)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ if zero_init_residual:
115
+ for m in self.modules():
116
+ if isinstance(m, IBasicBlock):
117
+ nn.init.constant_(m.bn2.weight, 0)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
+ downsample = None
121
+ previous_dilation = self.dilation
122
+ if dilate:
123
+ self.dilation *= stride
124
+ stride = 1
125
+ if stride != 1 or self.inplanes != planes * block.expansion:
126
+ downsample = nn.Sequential(
127
+ conv1x1(self.inplanes, planes * block.expansion, stride),
128
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
+ )
130
+ layers = []
131
+ layers.append(
132
+ block(self.inplanes, planes, stride, downsample, self.groups,
133
+ self.base_width, previous_dilation))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(
137
+ block(self.inplanes,
138
+ planes,
139
+ groups=self.groups,
140
+ base_width=self.base_width,
141
+ dilation=self.dilation))
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def checkpoint(self, func, num_seg, x):
146
+ if self.training:
147
+ return checkpoint_sequential(func, num_seg, x)
148
+ else:
149
+ return func(x)
150
+
151
+ def forward(self, x):
152
+ with torch.cuda.amp.autocast(self.fp16):
153
+ x = self.conv1(x)
154
+ x = self.bn1(x)
155
+ x = self.prelu(x)
156
+ x = self.layer1(x)
157
+ x = self.checkpoint(self.layer2, 20, x)
158
+ x = self.checkpoint(self.layer3, 100, x)
159
+ x = self.layer4(x)
160
+ x = self.bn2(x)
161
+ x = torch.flatten(x, 1)
162
+ x = self.dropout(x)
163
+ x = self.fc(x.float() if self.fp16 else x)
164
+ x = self.features(x)
165
+ return x
166
+
167
+
168
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
+ model = IResNet(block, layers, **kwargs)
170
+ if pretrained:
171
+ raise ValueError()
172
+ return model
173
+
174
+
175
+ def iresnet2060(pretrained=False, progress=True, **kwargs):
176
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
SadTalker/src/face3d/models/arcface_torch/backbones/mobilefacenet.py ADDED
@@ -0,0 +1,130 @@
1
+ '''
2
+ Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
+ Original author cavalleria
4
+ '''
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
+ import torch
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, x):
13
+ return x.view(x.size(0), -1)
14
+
15
+
16
+ class ConvBlock(Module):
17
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
+ super(ConvBlock, self).__init__()
19
+ self.layers = nn.Sequential(
20
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
+ BatchNorm2d(num_features=out_c),
22
+ PReLU(num_parameters=out_c)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.layers(x)
27
+
28
+
29
+ class LinearBlock(Module):
30
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
+ super(LinearBlock, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
+ BatchNorm2d(num_features=out_c)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
39
+
40
+
41
+ class DepthWise(Module):
42
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
+ super(DepthWise, self).__init__()
44
+ self.residual = residual
45
+ self.layers = nn.Sequential(
46
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
+ )
50
+
51
+ def forward(self, x):
52
+ short_cut = None
53
+ if self.residual:
54
+ short_cut = x
55
+ x = self.layers(x)
56
+ if self.residual:
57
+ output = short_cut + x
58
+ else:
59
+ output = x
60
+ return output
61
+
62
+
63
+ class Residual(Module):
64
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
+ super(Residual, self).__init__()
66
+ modules = []
67
+ for _ in range(num_block):
68
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
+ self.layers = Sequential(*modules)
70
+
71
+ def forward(self, x):
72
+ return self.layers(x)
73
+
74
+
75
+ class GDC(Module):
76
+ def __init__(self, embedding_size):
77
+ super(GDC, self).__init__()
78
+ self.layers = nn.Sequential(
79
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
+ Flatten(),
81
+ Linear(512, embedding_size, bias=False),
82
+ BatchNorm1d(embedding_size))
83
+
84
+ def forward(self, x):
85
+ return self.layers(x)
86
+
87
+
88
+ class MobileFaceNet(Module):
89
+ def __init__(self, fp16=False, num_features=512):
90
+ super(MobileFaceNet, self).__init__()
91
+ scale = 2
92
+ self.fp16 = fp16
93
+ self.layers = nn.Sequential(
94
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
+ )
103
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
+ self.features = GDC(num_features)
105
+ self._initialize_weights()
106
+
107
+ def _initialize_weights(self):
108
+ for m in self.modules():
109
+ if isinstance(m, nn.Conv2d):
110
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
+ if m.bias is not None:
112
+ m.bias.data.zero_()
113
+ elif isinstance(m, nn.BatchNorm2d):
114
+ m.weight.data.fill_(1)
115
+ m.bias.data.zero_()
116
+ elif isinstance(m, nn.Linear):
117
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
+ if m.bias is not None:
119
+ m.bias.data.zero_()
120
+
121
+ def forward(self, x):
122
+ with torch.cuda.amp.autocast(self.fp16):
123
+ x = self.layers(x)
124
+ x = self.conv_sep(x.float() if self.fp16 else x)
125
+ x = self.features(x)
126
+ return x
127
+
128
+
129
+ def get_mbf(fp16, num_features):
130
+ return MobileFaceNet(fp16, num_features)
SadTalker/src/face3d/models/arcface_torch/configs/3millions.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 1.0
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []