HighCWu committed on
Commit
c68160d
·
1 Parent(s): deb2950

init commit.

Browse files
.gitignore ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ *.db
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 艾梦
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -10,4 +10,5 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
10
  license: mit
11
  ---
12
 
13
+ # beat-interpolator
14
+ Interpolate the latents of your DL model to follow the beat of the music
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import os
7
+ import glob
8
+ import pickle
9
+ import sys
10
+ import importlib
11
+ from typing import List, Tuple
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from beat_interpolator import beat_interpolator
19
+
20
+
21
def build_models():
    """Import every example model package and call its ``create`` factory.

    Each sub-directory of ``examples/models`` (relative to the current working
    directory) whose name does not contain ``'__'`` is imported as the package
    ``examples.models.<dir>``; its ``create`` attribute is called and must
    return a dict that has at least a ``'name'`` key.

    Returns:
        A list of ``(name, attrs)`` tuples. Entries are ordered by directory
        name so the result is reproducible, except that an entry named
        ``'MNIST'`` is moved to the front when present.
    """
    model_dirs = sorted(
        path
        for path in glob.glob('examples/models/*')
        # Only real package directories; skips __init__.py, __pycache__ and
        # any stray data files that are not importable.
        if os.path.isdir(path) and '__' not in os.path.basename(path)
    )

    attrs = []
    for path in model_dirs:
        # glob returns OS-specific separators, so split on os.sep rather than
        # assuming '/'.
        module_name = '.'.join(os.path.normpath(path).split(os.sep))
        create = getattr(importlib.import_module(module_name), 'create')
        model = create()
        attrs.append((model['name'], model))

    # Stable sort: False sorts before True, so 'MNIST' moves first while the
    # relative order of every other model is kept.
    attrs.sort(key=lambda item: item[0] != 'MNIST')
    return attrs
46
+
47
+
48
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the demo app.

    Options: ``--device`` (str, default ``'cpu'``), ``--theme`` (str),
    ``--share`` (flag), ``--port`` (int) and ``--disable-queue``, which stores
    ``False`` into ``enable_queue`` (``True`` when the flag is absent).
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ('--device', {'type': str, 'default': 'cpu'}),
        ('--theme', {'type': str}),
        ('--share', {'action': 'store_true'}),
        ('--port', {'type': int}),
        ('--disable-queue', {'dest': 'enable_queue', 'action': 'store_false'}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
58
+
59
+
60
def main():
    """Build the Gradio Blocks demo (one tab per example model) and launch it.

    Every tab gets its own input widgets and a click handler bound to that
    model's generator. When the queue is disabled the FPS and duration sliders
    allow larger maxima (60 fps / 360) than the per-model defaults.
    """
    args = parse_args()
    queue_enabled = args.enable_queue
    model_attrs = build_models()

    def build_interpolate(model_attr):
        """Return a click handler bound to one model's settings."""
        generator = model_attr['generator']
        latent_dim = model_attr['latent_dim']
        batch_size = model_attr['batch_size']

        def interpolate(
                wave_path,
                seed,
                fps=model_attr['fps'],
                strength=model_attr['strength'],
                max_duration=model_attr['max_duration'],
                use_peak=model_attr['use_peak']):
            # Number/Slider widgets hand over floats; seed and fps are cast
            # to int before the call.
            return beat_interpolator(
                wave_path,
                generator,
                latent_dim,
                int(seed),
                int(fps),
                batch_size,
                strength,
                max_duration,
                use_peak)

        return interpolate

    with gr.Blocks(theme=args.theme) as demo:
        gr.Markdown('''<center><h1>Beat-Interpolator</h1></center>
<h2>Play DL models with music beats.</h2><br />
This is a Gradio Blocks app of <a href="https://github.com/HighCWu/beat-interpolator">HighCWu/beat-interpolator</a>.
''')
        with gr.Tabs():
            for name, model_attr in model_attrs:
                with gr.TabItem(name):
                    fps_default = model_attr['fps']
                    fps_limit = fps_default if queue_enabled else 60
                    duration_default = model_attr['max_duration']
                    duration_limit = duration_default if queue_enabled else 360

                    interpolate = build_interpolate(model_attr)

                    with gr.Row():
                        # Left box: inputs.
                        with gr.Box():
                            with gr.Column():
                                with gr.Row():
                                    wave_in = gr.Audio(type="filepath", label="Music")
                                # wave example not supported currently, so the
                                # sample track is offered as a download instead.
                                with gr.Row():
                                    gr.File(
                                        value='examples/example.mp3',
                                        interactive=False,
                                        label='Example'
                                    )
                                with gr.Row():
                                    seed_in = gr.Number(value=128, label='Seed')
                                with gr.Row():
                                    fps_in = gr.Slider(
                                        value=fps_default,
                                        minimum=4,
                                        maximum=fps_limit,
                                        label="FPS"
                                    )
                                with gr.Row():
                                    strength_in = gr.Slider(
                                        value=model_attr['strength'],
                                        maximum=1,
                                        label="Strength"
                                    )
                                with gr.Row():
                                    max_duration_in = gr.Slider(
                                        value=duration_default,
                                        minimum=5,
                                        maximum=duration_limit,
                                        label="Max Duration"
                                    )
                                with gr.Row():
                                    peak_in = gr.Checkbox(
                                        value=model_attr['use_peak'],
                                        label="Use peak"
                                    )
                                with gr.Row():
                                    generate_button = gr.Button('Generate')

                        # Right box: result.
                        with gr.Box():
                            with gr.Column():
                                with gr.Row():
                                    interpolated_video = gr.Video(label='Output Video')

                    click_inputs = [
                        wave_in,
                        seed_in,
                        fps_in,
                        strength_in,
                        max_duration_in,
                        peak_in,
                    ]
                    generate_button.click(
                        interpolate,
                        inputs=click_inputs,
                        outputs=[interpolated_video])

        gr.Markdown(
            '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.beat-interpolator" alt="visitor badge"/></center>'
        )

    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
192
+
193
+
194
+ if __name__ == '__main__':
195
+ main()
beat_interpolator.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import gradio as gr
4
+ import soundfile as sf
5
+
6
+ from moviepy.editor import *
7
+
8
+
9
# Fixed pools of 50 scratch file names under /tmp. The module-level iterators
# hand them out one at a time; beat_interpolator restarts an iterator once it
# is exhausted, so old files get overwritten.
_CACHE_SLOTS = range(50)
cache_wav_path = ['/tmp/%02d.wav' % slot for slot in _CACHE_SLOTS]
wave_path_iter = iter(cache_wav_path)
cache_mp4_path = ['/tmp/%02d.mp4' % slot for slot in _CACHE_SLOTS]
path_iter = iter(cache_mp4_path)
13
+
14
def merge_times(times, times2):
    """Merge two 1-D arrays of timestamps, letting ``times2`` win on clashes.

    Every entry of ``times`` that lies within 0.2 (exclusive) of any entry of
    ``times2`` is dropped; the survivors are concatenated with ``times2`` and
    returned in ascending order.

    Args:
        times: 1-D NumPy array of timestamps.
        times2: 1-D NumPy array of timestamps that take precedence.

    Returns:
        Sorted 1-D NumPy array containing the merged timestamps.
    """
    # Pairwise distance matrix of shape (len(times2), len(times)); column
    # indices of close pairs identify the entries of `times` to discard.
    close_ids = np.unique(np.where(abs(times2[..., None] - times[None]) < 0.2)[1])
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the same dtype.
    keep = np.ones_like(times, dtype=bool)
    keep[close_ids] = False
    return np.sort(np.concatenate([times[keep], times2]))
23
+
24
+
25
def beat_interpolator(wave_path, generator, latent_dim, seed, fps=30, batch_size=1, strength=1, max_duration=None, use_peak=False):
    """Render a video whose latent walk follows the beats of an audio file.

    Beat times (optionally merged with onset peaks) become key frames. A new
    random latent is drawn at each key frame, pulled towards the previous one
    according to how loud the track is at that moment, and the frames between
    key frames are linear interpolations.

    Args:
        wave_path: Path of the input audio file (loaded with librosa at 24 kHz).
        generator: Callable taking a list of latent vectors (NumPy arrays of
            length ``latent_dim``) and returning a list of frames.
        latent_dim: Length of each latent vector.
        seed: Seed for ``np.random.RandomState``.
        fps: Frames per second; values below 10 are raised to 10.
        batch_size: Number of latents passed to ``generator`` per call.
        strength: Clipped to ``[0, 1]``; scales how far each segment moves
            towards the new latent.
        max_duration: If not ``None``, the audio is truncated proportionally
            to this duration and rewritten to a cached wav file.
        use_peak: If true, onset-strength peaks are merged into the beat times.

    Returns:
        Path of the written mp4 file (one of the cached ``/tmp`` slots).

    Raises:
        ValueError: If the audio yields no frames to render.
    """
    global wave_path_iter, path_iter

    fps = max(10, fps)
    strength = np.clip(strength, 0, 1)
    hop_length = 512
    y, sr = librosa.load(wave_path, sr=24000)
    duration = librosa.get_duration(y=y, sr=sr)

    # `duration > 0` guards the proportional slice against empty audio.
    if max_duration is not None and duration > 0:
        y = y[:int(y.shape[0] * max_duration / duration)]

        # Persist the truncated signal so the video gets the matching sound
        # track; the path pool is recycled once exhausted.
        try:
            wave_path = next(wave_path_iter)
        except StopIteration:
            wave_path_iter = iter(cache_wav_path)
            wave_path = next(wave_path_iter)
        sf.write(wave_path, y, sr, subtype='PCM_24')
        y, sr = librosa.load(wave_path, sr=24000)
        duration = librosa.get_duration(y=y, sr=sr)

    # Per-frame loudness (max over frequency bins), used to decide how far
    # each key frame may move away from the previous latent.
    S = np.abs(librosa.stft(y, hop_length=hop_length))
    db = librosa.power_to_db(S**2, ref=np.median).max(0)
    db_mean = np.mean(db)
    db_max = np.max(db)
    db_min = np.min(db)
    db_times = librosa.frames_to_time(np.arange(len(db)), sr=sr, hop_length=hop_length)

    rng = np.random.RandomState(seed)
    onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length, aggregate=np.median)
    _, beats = librosa.beat.beat_track(y=y, sr=sr, onset_envelope=onset_env, hop_length=hop_length, units='time')
    times = np.asarray(beats)
    if use_peak:
        # Keyword arguments: librosa >= 0.10 rejects the positional form.
        peaks = librosa.util.peak_pick(
            onset_env, pre_max=1, post_max=1, pre_avg=1, post_avg=1, delta=0.8, wait=5)
        times2 = librosa.frames_to_time(np.arange(len(onset_env)), sr=sr, hop_length=hop_length)[peaks]
        # The original overwrote `times2` with `times` here, which made the
        # merge a no-op and silently disabled `use_peak`.
        times = merge_times(times, np.asarray(times2))

    # Key frames as even frame indices, always starting at frame 0.
    times = np.concatenate([np.asarray([0.]), times], 0)
    times = list(np.unique(np.int64(np.floor(times * fps / 2))) * 2)

    latents = []
    time0 = 0
    latent0 = rng.randn(latent_dim)
    for time1 in times:
        latent1 = rng.randn(latent_dim)
        db_cur_index = np.argmin(np.abs(db_times - time1.astype('float32') / fps))
        db_cur = db[db_cur_index]
        # The quieter the moment, the closer the new latent stays to the
        # previous one (five loudness bands between db_min and db_max).
        if db_cur < db_min + (db_mean - db_min) / 3:
            latent1 = latent0 * 0.8 + latent1 * 0.2
        elif db_cur < db_min + 2 * (db_mean - db_min) / 3:
            latent1 = latent0 * 0.6 + latent1 * 0.4
        elif db_cur < db_mean + (db_max - db_mean) / 3:
            latent1 = latent0 * 0.4 + latent1 * 0.6
        elif db_cur < db_mean + 2 * (db_max - db_mean) / 3:
            latent1 = latent0 * 0.2 + latent1 * 0.8
        if time1 > duration * fps:
            time1 = int(duration * fps)
        t1 = time1 - time0
        alpha = 0.5 * strength
        latent2 = latent0 * alpha + latent1 * (1 - alpha)
        for j in range(t1):
            alpha = j / t1
            latents.append(latent0 * (1 - alpha) + latent2 * alpha)

        time0 = time1
        latent0 = latent1

    if not latents:
        raise ValueError('No frames to render: the audio is too short or no beats were detected.')

    outs = []
    for ix in range(0, len(latents), batch_size):
        outs += generator(latents[ix:ix + batch_size])

    try:
        video_path = next(path_iter)
    except StopIteration:
        path_iter = iter(cache_mp4_path)
        video_path = next(path_iter)

    video = ImageSequenceClip(outs, fps=fps)
    audioclip = AudioFileClip(wave_path)
    try:
        video = video.set_audio(audioclip)
        video.write_videofile(video_path, fps=fps)
    finally:
        # Release the audio reader held open by the clip.
        audioclip.close()

    return video_path
examples/__init__.py ADDED
File without changes
examples/example.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8afffc71afc7b665cf52c5425a85db533bc4b4b0ea878a6812bcb2a99941e5a3
3
+ size 962186
examples/models/__init__.py ADDED
File without changes
examples/models/anime_biggan/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import create_anime_biggan_inference as create
examples/models/anime_biggan/model.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Parameter
5
+ from torch.nn import functional as F
6
+ from huggingface_hub import hf_hub_download
7
+
8
+
9
def l2_normalize(v, dim=None, eps=1e-12):
    """Scale ``v`` to unit L2 norm along ``dim``; ``eps`` avoids division by zero."""
    magnitude = v.norm(dim=dim, keepdim=True)
    return v / (magnitude + eps)
11
+
12
+
13
def unpool(value):
    """Zero-insertion 2x unpooling of a 4-D tensor.

    Unpooling operation from
    https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf
    Taken from: https://github.com/tensorflow/tensorflow/issues/2169

    The input is permuted from [b, ch, d0, d1] to channels-last, each value is
    placed in the top-left corner of a 2x2 cell padded with zeros, and the
    result is permuted back.

    Args:
        value: a Tensor of shape [b, ch, d0, d1]
    Returns:
        A Tensor of shape [b, ch, 2*d0, 2*d1]
    """
    channels_last = value.permute(0, 2, 3, 1)
    shape = list(channels_last.shape)
    spatial = shape[1:-1]
    rank = len(spatial)

    expanded = channels_last.reshape([-1] + shape[-rank:])
    # Append a block of zeros along each trailing axis, innermost first.
    for axis in reversed(range(1, rank + 1)):
        expanded = torch.cat([expanded, torch.zeros_like(expanded)], axis)

    doubled = [-1] + [2 * size for size in spatial] + [shape[-1]]
    return expanded.reshape(doubled).permute(0, 3, 1, 2)
34
+
35
+
36
class BatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` that can use explicitly accumulated statistics.

    Three extra non-trainable parameters (``accumulated_mean``,
    ``accumulated_var``, ``accumulated_counter``) hold running sums of
    per-batch statistics. While ``accumulating`` is on, every forward pass
    adds the current batch statistics to those sums and normalizes with their
    average. On the first forward pass, previously accumulated statistics (if
    any) are copied into ``running_mean`` / ``running_var``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialized = False
        self.accumulating = False
        # self.num_features is set by the parent constructor; unlike args[0]
        # it also works when num_features is passed by keyword.
        self.accumulated_mean = Parameter(torch.zeros(self.num_features), requires_grad=False)
        self.accumulated_var = Parameter(torch.zeros(self.num_features), requires_grad=False)
        # Tiny non-zero start value keeps the division below well defined.
        self.accumulated_counter = Parameter(torch.zeros(1) + 1e-12, requires_grad=False)

    def forward(self, inputs, *args, **kwargs):
        """Normalize ``inputs``; in accumulating mode also update the sums."""
        if not self.initialized:
            self.check_accumulation()
            self.set_initialized(True)
        if self.accumulating:
            self.eval()
            with torch.no_grad():
                # Reduce over batch and spatial axes, keeping the channel axis.
                axes = [0] + ([] if len(inputs.shape) == 2 else list(range(2, len(inputs.shape))))
                batch_mean_kept = torch.mean(inputs, axes, keepdim=True)
                batch_mean = torch.mean(inputs, axes, keepdim=False)
                batch_var = torch.mean((inputs - batch_mean_kept) ** 2, axes)
                self.accumulated_mean.copy_(self.accumulated_mean + batch_mean)
                self.accumulated_var.copy_(self.accumulated_var + batch_var)
                self.accumulated_counter.copy_(self.accumulated_counter + 1)
                # Temporarily swap the averaged statistics into the running
                # buffers, then restore them. The original wrote to the
                # non-existent attributes `_mean` / `_variance`.
                saved_mean = self.running_mean * 1.0
                saved_var = self.running_var * 1.0
                self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
                self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
                out = super().forward(inputs, *args, **kwargs)
                self.running_mean.copy_(saved_mean)
                self.running_var.copy_(saved_var)
                return out
        out = super().forward(inputs, *args, **kwargs)
        return out

    def check_accumulation(self):
        """Load averaged statistics into the running buffers if any exist.

        Returns:
            True if at least one batch had been accumulated, else False.
        """
        # .item() gives a Python float, so the comparison does not depend on
        # NumPy's float32-vs-Python-float promotion rules.
        if self.accumulated_counter.item() > 1 - 1e-12:
            self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
            self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
            return True
        return False

    def clear_accumulated(self):
        """Reset the accumulated sums and counter to their initial values."""
        self.accumulated_mean.zero_()
        self.accumulated_var.zero_()
        # Same start value as in __init__ (the original used 1e-2 here, which
        # biased the averages computed after a reset).
        self.accumulated_counter.fill_(1e-12)

    def set_accumulating(self, status=True):
        """Turn accumulation of batch statistics on or off."""
        self.accumulating = bool(status)

    def set_initialized(self, status=False):
        """Mark whether the first-forward statistics check has already run."""
        self.initialized = bool(status)
93
+
94
+
95
class SpectralNorm(nn.Module):
    """Wrap a module and divide its weight by an estimated spectral norm.

    The wrapped module's ``name`` parameter is moved onto this wrapper as
    ``weight``, together with a non-trainable power-iteration vector
    ``weight_u``. Each forward pass runs ``power_iterations`` steps, sets the
    normalized weight as a plain attribute on the wrapped module and then
    calls that module's ``forward``.
    """

    def __init__(self, module, name='weight', power_iterations=2):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _as_matrix(self, w):
        """Flatten a weight tensor into the 2-D matrix used for iteration."""
        if len(w.shape) == 4:
            mat = w.permute(2, 3, 1, 0)
        elif isinstance(self.module, (nn.Linear, nn.Embedding)):
            mat = w.permute(1, 0)
        else:
            mat = w
        return mat.reshape(-1, mat.shape[-1])

    def _update_u(self):
        """Run the power iteration and install the normalized weight."""
        w = self.weight
        u = self.weight_u
        mat = self._as_matrix(w)

        # "Left" form (column vector u) when the matrix is not taller than
        # wide; "right" form (row vector u) otherwise.
        use_left = mat.shape[0] <= mat.shape[1]
        norm_dim = 0 if use_left else 1
        for _ in range(self.power_iterations):
            if use_left:
                v = l2_normalize(torch.matmul(mat.t(), u), dim=norm_dim)
                u = l2_normalize(torch.matmul(mat, v), dim=norm_dim)
            else:
                v = l2_normalize(torch.matmul(u, mat.t()), dim=norm_dim)
                u = l2_normalize(torch.matmul(v, mat), dim=norm_dim)

        if use_left:
            sigma = torch.matmul(torch.matmul(u.t(), mat), v)
        else:
            sigma = torch.matmul(torch.matmul(v, mat), u.t())

        setattr(self.module, self.name, w / sigma.detach())
        self.weight_u.copy_(u.detach())

    def _made_params(self):
        """Return True once ``weight`` and ``weight_u`` exist on the wrapper."""
        return hasattr(self, 'weight') and hasattr(self, 'weight_u')

    def _make_params(self):
        """Take ownership of the module's weight and create ``weight_u``."""
        w = getattr(self.module, self.name)
        mat = self._as_matrix(w)

        use_left = mat.shape[0] <= mat.shape[1]
        norm_dim = 0 if use_left else 1
        u_shape = (mat.shape[0], 1) if use_left else (1, mat.shape[-1])

        u = Parameter(w.data.new(*u_shape).normal_(0, 1), requires_grad=False)
        u.copy_(l2_normalize(u, dim=norm_dim).detach())

        # Unregister the parameter from the wrapped module; _update_u later
        # puts a plain tensor back under the same name.
        del self.module._parameters[self.name]
        self.weight = w
        self.weight_u = u

    def forward(self, *args, **kwargs):
        self._update_u()
        return self.module.forward(*args, **kwargs)
168
+
169
+
170
class SelfAttention(nn.Module):
    """Self-attention over a 4-D feature map with max-pooled keys and values.

    Queries (``theta``) are taken at every position; keys (``phi``) and values
    (``g``) are 2x2 max-pooled first. The attended features are projected back
    to ``in_dim`` channels and added to the input, scaled by the learnable
    scalar ``gamma`` (initialized to zero).
    """

    def __init__(self, in_dim, activation=torch.relu):
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation

        def conv1x1(n_in, n_out):
            return SpectralNorm(nn.Conv2d(n_in, n_out, 1, bias=False))

        self.theta = conv1x1(in_dim, in_dim // 8)
        self.phi = conv1x1(in_dim, in_dim // 8)
        self.pool = nn.MaxPool2d(2, 2)
        self.g = conv1x1(in_dim, in_dim // 2)
        self.o_conv = conv1x1(in_dim // 2, in_dim)
        self.gamma = Parameter(torch.zeros(1))

    def forward(self, x):
        batch, _, width, height = x.shape
        positions = height * width
        pooled_positions = positions // 4

        queries = self.theta(x)
        keys = self.pool(self.phi(x)).reshape(batch, -1, pooled_positions)
        queries = queries.reshape(batch, -1, positions).permute(0, 2, 1)
        # One row of weights over the pooled positions per query position.
        attention = torch.softmax(torch.bmm(queries, keys), -1)

        values = self.pool(self.g(x)).reshape(batch, -1, pooled_positions)
        attended = torch.bmm(values, attention.permute(0, 2, 1))
        attended = attended.reshape(batch, -1, width, height)
        return self.gamma * self.o_conv(attended) + x
199
+
200
+
201
class ConditionalBatchNorm2d(nn.Module):
    """Batch norm whose per-channel scale and shift come from a condition.

    The input is normalized without affine parameters. ``gamma`` and ``beta``
    are then either produced from the condition vector by two
    spectrally-normalized linear layers, or taken directly from a
    ``[gamma, beta]`` list.
    """

    def __init__(self, num_features, num_classes, eps=1e-5, momentum=0.1):
        super().__init__()
        self.bn_in_cond = BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
        self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
        self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))

    def forward(self, x, y):
        normed = self.bn_in_cond(x)

        if isinstance(y, list):
            # Precomputed modulation: bypass the embedding layers.
            gamma, beta = y
        else:
            gamma = self.gamma_embed(y)
            # gamma = gamma + 1
            beta = self.beta_embed(y)

        scale = torch.reshape(gamma, (gamma.shape[0], -1, 1, 1))
        shift = torch.reshape(beta, (beta.shape[0], -1, 1, 1))
        return scale * normed + shift
221
+
222
+
223
class ResBlock(nn.Module):
    """Residual block: (cond. norm) -> act -> conv, twice, plus a 1x1 skip.

    With ``upsample`` the features (and the skip path) are unpooled by 2
    before the first convolution; with ``downsample`` both are average-pooled
    by 2 at the end. The projected skip connection exists only when the block
    resamples and ``skip_proj`` is not ``True``; without it no residual is
    added. ``n_class`` is accepted but not used by this class.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=[3, 3],
        padding=1,
        stride=1,
        n_class=None,
        conditional=True,
        activation=torch.relu,
        upsample=True,
        downsample=False,
        z_dim=128,
        use_attention=False,
        skip_proj=None
    ):
        super().__init__()

        # Module creation order is kept (norm1, conv0, norm2, conv1, ...).
        if conditional:
            self.cond_norm1 = ConditionalBatchNorm2d(in_channel, z_dim)
        self.conv0 = SpectralNorm(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)
        )
        if conditional:
            self.cond_norm2 = ConditionalBatchNorm2d(out_channel, z_dim)
        self.conv1 = SpectralNorm(
            nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding)
        )

        resamples = upsample or downsample
        self.skip_proj = skip_proj is not True and bool(resamples)
        if self.skip_proj:
            self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))

        if use_attention:
            self.attention = SelfAttention(out_channel)

        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.conditional = conditional
        self.use_attention = use_attention

    def forward(self, input, condition=None):
        # A list condition carries one entry per conditional norm layer.
        if isinstance(condition, list):
            cond1, cond2 = condition[0], condition[1]
        else:
            cond1 = cond2 = condition

        out = input
        if self.conditional:
            out = self.cond_norm1(out, cond1)
        out = self.activation(out)
        if self.upsample:
            out = unpool(out)  # out = F.interpolate(out, scale_factor=2)
        out = self.conv0(out)
        if self.conditional:
            out = self.cond_norm2(out, cond2)
        out = self.activation(out)
        out = self.conv1(out)

        if self.downsample:
            out = F.avg_pool2d(out, 2, 2)

        if self.skip_proj:
            shortcut = input
            if self.upsample:
                shortcut = unpool(shortcut)  # shortcut = F.interpolate(shortcut, scale_factor=2)
            shortcut = self.conv_sc(shortcut)
            if self.downsample:
                shortcut = F.avg_pool2d(shortcut, 2, 2)
            out = out + shortcut

        if self.use_attention:
            out = self.attention(out)

        return out
302
+
303
+
304
class Generator(nn.Module):
    """BigGAN-style generator built from upsampling residual blocks.

    The latent code is split into ``num_split`` chunks. The first chunk seeds
    a 4x4 feature map; each remaining chunk, concatenated with the class
    embedding, conditions one residual block. Outputs are mapped to ``[0, 1]``
    via ``(tanh + 1) / 2``.
    """

    # Channel multipliers (relative to ``chn``) for each supported resolution.
    _CHANNEL_MULTIPLIERS = {
        1024: [16, 16, 8, 8, 4, 2, 1, 1, 1],
        512: [16, 16, 8, 8, 4, 2, 1, 1],
        256: [16, 16, 8, 8, 4, 2, 1],
        128: [16, 16, 8, 4, 2, 1],
        64: [16, 16, 8, 4, 2],
        32: [4, 4, 4, 4],
    }

    def __init__(self, code_dim=128, n_class=1000, chn=96, blocks_with_attention="B4", resolution=512):
        super().__init__()

        self.embed_y = nn.Linear(n_class, 128, bias=False)

        self.chn = chn
        self.resolution = resolution
        # Comma-separated block tags such as "B4,B5"; empty tags are dropped.
        self.blocks_with_attention = set(blocks_with_attention.split(","))
        self.blocks_with_attention.discard('')

        in_channels, out_channels = self.get_in_out_channels()
        self.num_split = len(in_channels) + 1

        chunk_dim = code_dim // self.num_split
        z_dim = chunk_dim + 128
        self.noise_fc = SpectralNorm(nn.Linear(chunk_dim, 4 * 4 * in_channels[0]))

        self.sa_ids = [int(tag.split('B')[-1]) for tag in self.blocks_with_attention]

        # Attention tags are 1-based block indices.
        self.blocks = nn.ModuleList([
            ResBlock(nc_in, nc_out, n_class=n_class, z_dim=z_dim,
                     use_attention=(index + 1) in self.sa_ids)
            for index, (nc_in, nc_out) in enumerate(zip(in_channels, out_channels))
        ])

        self.output_layer_bn = BatchNorm2d(1 * chn, eps=1e-5)
        self.output_layer_conv = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))

        self.z_dim = code_dim
        self.c_dim = n_class
        self.n_level = self.num_split

    def get_in_out_channels(self):
        """Return the per-block ``(in_channels, out_channels)`` lists.

        Raises:
            ValueError: If ``self.resolution`` is not a supported size.
        """
        resolution = self.resolution
        if resolution not in self._CHANNEL_MULTIPLIERS:
            raise ValueError("Unsupported resolution: {}".format(resolution))
        multipliers = self._CHANNEL_MULTIPLIERS[resolution]
        widths = [self.chn * m for m in multipliers]
        return widths[:-1], widths[1:]

    def _stem(self, code):
        """Project a latent chunk to the initial 4x4 feature map."""
        out = self.noise_fc(code)
        out = torch.reshape(out, (out.shape[0], 4, 4, -1))
        return out.permute(0, 3, 1, 2)

    def _head(self, out):
        """Final norm -> ReLU -> conv, mapped to [0, 1]."""
        out = self.output_layer_bn(out)
        out = torch.relu(out)
        out = self.output_layer_conv(out)
        return (torch.tanh(out) + 1) / 2

    def forward(self, input, class_id):
        """Generate images from a latent batch and a class vector batch."""
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.embed_y(class_id)  # 128
        out = self._stem(codes[0])
        for code, gblock in zip(codes[1:], self.blocks):
            out = gblock(out, torch.cat([code, class_emb], 1))
        return self._head(out)

    def forward_w(self, ws):
        """Generate from ready-made conditions: ``ws[0]`` seeds, the rest condition."""
        out = self._stem(ws[0])
        for w, gblock in zip(ws[1:], self.blocks):
            out = gblock(out, w)
        return self._head(out)

    def forward_wp(self, z0, gammas, betas):
        """Generate from explicit per-block ``(gamma, beta)`` pairs."""
        out = self._stem(z0)
        for gamma, beta, gblock in zip(gammas, betas, self.blocks):
            out = gblock(out, [[gamma[0], beta[0]], [gamma[1], beta[1]]])
        return self._head(out)
396
+
397
+
398
+
399
def create_anime_biggan_inference():
    """Build the Anime BigGAN demo entry.

    Downloads ``pytorch_model.bin`` from the ``HighCWu/anime-biggan-pytorch``
    hub repository, loads it into a 256-resolution ``Generator`` and returns
    the attribute dict consumed by the app (name, generator callable, latent
    size and UI defaults).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    anime_biggan = Generator(
        code_dim=140, n_class=1000, chn=96,
        blocks_with_attention="B5", resolution=256
    )
    checkpoint_path = hf_hub_download('HighCWu/anime-biggan-pytorch', 'pytorch_model.bin')
    # NOTE(review): torch.load unpickles the downloaded file; only safe while
    # the hub repository is trusted.
    state = torch.load(checkpoint_path, map_location='cpu')
    anime_biggan.load_state_dict(state)
    anime_biggan.to(device)
    anime_biggan.eval()

    @torch.inference_mode()
    def anime_biggan_generator(latents):
        """Turn a list of NumPy latent vectors into a list of uint8 HWC frames."""
        batch = torch.stack(
            [torch.from_numpy(latent).float().to(device) for latent in latents]
        )
        # One-hot class vector selecting class 0 for every sample.
        label = torch.zeros([batch.shape[0], anime_biggan.c_dim], device=device)
        label[:, 0] = 1
        images = anime_biggan(batch, label)
        return [
            np.uint8((image.permute(1, 2, 0) * 255).clamp(0, 255).cpu().numpy())
            for image in images
        ]

    return {
        'name': 'Anime Biggan',
        'generator': anime_biggan_generator,
        'latent_dim': anime_biggan.z_dim,
        'fps': 5,
        'batch_size': 1,
        'strength': 0.45,
        'max_duration': 15,
        'use_peak': True
    }
examples/models/celeba256/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import create_celeba256_inference as create
examples/models/celeba256/model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def create_celeba256_inference():
    """Build the CelebA-HQ 256 demo entry from the pytorch_GAN_zoo PGAN model.

    Loads the pretrained ``celebAHQ-256`` PGAN through ``torch.hub`` and
    returns the attribute dict consumed by the app. The latent size is read
    from a sample noise batch produced by the model.
    """
    cuda_available = torch.cuda.is_available()
    device = 'cuda' if cuda_available else 'cpu'
    celeba256 = torch.hub.load(
        'facebookresearch/pytorch_GAN_zoo:hub',
        'PGAN',
        model_name='celebAHQ-256',
        pretrained=True,
        useGPU=cuda_available
    )
    celeba256_noise, _ = celeba256.buildNoiseData(1)

    @torch.inference_mode()
    def celeba256_generator(latents):
        """Turn a list of NumPy latent vectors into a list of uint8 HWC frames."""
        batch = torch.stack(
            [torch.from_numpy(latent).float().to(device) for latent in latents]
        )
        images = celeba256.test(batch)
        # Model output is rescaled with (x + 1) * 127.5 and clamped to [0, 255].
        return [
            np.uint8(((image.permute(1, 2, 0) + 1) * 127.5).clamp(0, 255).cpu().numpy())
            for image in images
        ]

    return {
        'name': 'Celeba256',
        'generator': celeba256_generator,
        'latent_dim': celeba256_noise.shape[1],
        'fps': 5,
        'batch_size': 1,
        'strength': 0.6,
        'max_duration': 20,
        'use_peak': True
    }
examples/models/fashion/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import create_fashion_inference as create
examples/models/fashion/model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def create_fashion_inference():
    """Build the demo entry for the pretrained fashion DCGAN.

    The model is fetched with ``torch.hub.load`` from
    ``facebookresearch/pytorch_GAN_zoo:hub`` (entry point ``DCGAN``) and
    wrapped in a callable that turns NumPy latent vectors into ``uint8``
    image arrays.

    Returns:
        dict: entry with the keys ``name``, ``generator``, ``latent_dim``,
        ``fps``, ``batch_size``, ``strength``, ``max_duration`` and
        ``use_peak``. How the numeric settings are interpreted is up to
        the caller, which is not visible in this module.
    """
    has_cuda = torch.cuda.is_available()
    device = 'cuda' if has_cuda else 'cpu'

    fashion = torch.hub.load(
        'facebookresearch/pytorch_GAN_zoo:hub',
        'DCGAN',
        pretrained=True,
        useGPU=has_cuda,
    )

    # A single noise sample is drawn only so its second axis can be
    # reported as the latent width.
    probe_noise, _ = fashion.buildNoiseData(1)

    def _as_uint8_image(sample):
        # Move axis 0 to the end (presumably channel-first -> channel-last;
        # confirm against the hub model), then map [-1, 1] onto [0, 255].
        # Values outside that range are clamped before the uint8 cast.
        scaled = (sample.permute(1, 2, 0) + 1) * 127.5
        return np.uint8(scaled.clamp(0, 255).cpu().numpy())

    @torch.inference_mode()
    def fashion_generator(latents):
        """Generate one ``uint8`` image array per latent in ``latents``.

        Args:
            latents: iterable of NumPy arrays, each a single latent vector.

        Returns:
            list: one NumPy ``uint8`` array per generated sample.
        """
        batch = torch.stack(
            [torch.from_numpy(vec).float().to(device) for vec in latents]
        )
        return [_as_uint8_image(sample) for sample in fashion.test(batch)]

    return {
        'name': 'Fashion',
        'generator': fashion_generator,
        'latent_dim': probe_noise.shape[1],
        'fps': 15,
        'batch_size': 8,
        'strength': 0.6,
        'max_duration': 30,
        'use_peak': True
    }
examples/models/mnist/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import create_mnist_inference as create
examples/models/mnist/mnist_generator.pretrained ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f6628c922425612cf21f48ed3325310c51441b279a86296fd0fa7041451296b
3
+ size 2268434
examples/models/mnist/model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class Generator(nn.Module):
    """Fully connected generator: 128 -> 256 -> 512 -> 784, tanh output.

    Architecture follows https://github.com/safwankdb/Vanilla-GAN.
    The submodule names ``fc0``, ``fc1`` and ``fc2`` determine the
    ``state_dict`` keys, so they must stay as they are for saved weights
    to load.
    """

    def __init__(self):
        super().__init__()
        self.n_features = 128
        self.n_out = 784
        leaky_slope = 0.2
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_features, 256), nn.LeakyReLU(leaky_slope)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(256, 512), nn.LeakyReLU(leaky_slope)
        )
        # Tanh keeps every output value inside [-1, 1].
        self.fc2 = nn.Sequential(nn.Linear(512, self.n_out), nn.Tanh())

    def forward(self, x):
        """Run ``x`` through the three stages and reshape to 1x28x28 maps.

        Args:
            x: tensor whose last dimension is ``n_features`` (128).

        Returns:
            Tensor viewed as ``(-1, 1, 28, 28)``.
        """
        hidden = x
        for stage in (self.fc0, self.fc1, self.fc2):
            hidden = stage(hidden)
        return hidden.view(-1, 1, 28, 28)
31
+
32
+
33
def create_mnist_inference():
    """Build the demo entry for the bundled MNIST generator.

    Instantiates ``Generator``, loads the ``mnist_generator.pretrained``
    state dict that sits next to this module (deserialized onto the CPU
    first), then moves the network to CUDA when available and switches it
    to eval mode.

    Returns:
        dict: entry with the keys ``name``, ``generator``, ``latent_dim``,
        ``fps``, ``batch_size``, ``strength``, ``max_duration`` and
        ``use_peak``. How the numeric settings are interpreted is up to
        the caller, which is not visible in this module.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    weights_path = os.path.join(
        os.path.dirname(__file__), 'mnist_generator.pretrained'
    )

    mnist = Generator()
    mnist.load_state_dict(torch.load(weights_path, map_location='cpu'))
    mnist.to(device)
    mnist.eval()

    def _as_rgb_image(sample):
        # Take the first (only) channel, map [-1, 1] onto [0, 255] with
        # clamping, then repeat the grey plane three times on a new last
        # axis.
        grey = ((sample[0] + 1) * 127.5).clamp(0, 255).cpu().numpy()
        grey = np.uint8(grey)
        return np.stack([grey] * 3, -1)

    @torch.inference_mode()
    def mnist_generator(latents):
        """Generate one ``uint8`` three-channel image per latent.

        Args:
            latents: iterable of NumPy arrays, each a single latent vector.

        Returns:
            list: one NumPy ``uint8`` array per generated sample.
        """
        batch = torch.stack(
            [torch.from_numpy(vec).float().to(device) for vec in latents]
        )
        return [_as_rgb_image(sample) for sample in mnist(batch)]

    return {
        'name': 'MNIST',
        'generator': mnist_generator,
        'latent_dim': 128,
        'fps': 20,
        'batch_size': 8,
        'strength': 0.75,
        'max_duration': 30,
        'use_peak': True
    }
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ liblzma-dev
2
+ libsndfile1
3
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio==3.0.4
2
+ huggingface-hub==0.6.0
3
+ moviepy==1.0.3
4
+ Pillow==9.0.1
5
+ torch==1.11.0
6
+ torchvision==0.12.0
7
+ librosa
8
+ soundfile