Added code
Browse files- LoRA dataset/Training script/.ipynb_checkpoints/training_script-checkpoint.ipynb +33 -0
- LoRA dataset/Training script/training_script.ipynb +33 -0
- LoRA dataset/Weights/.gitattributes +1 -0
- LoRA dataset/Weights/pytorch_lora_weights.safetensors +3 -0
- MusicCaps/.gitattributes +1 -0
- MusicCaps/__pycache__/audio_utils.cpython-310.pyc +0 -0
- MusicCaps/__pycache__/bart.cpython-310.pyc +0 -0
- MusicCaps/__pycache__/modules.cpython-310.pyc +0 -0
- MusicCaps/audio_utils.py +245 -0
- MusicCaps/bart.py +151 -0
- MusicCaps/modules.py +95 -0
- MusicCaps/train_model.py +32 -0
- MusicCaps/transfer.pth +3 -0
- app.py +125 -0
- requirements.txt +0 -0
LoRA dataset/Training script/.ipynb_checkpoints/training_script-checkpoint.ipynb
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "3b9bbec2",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": []
|
10 |
+
}
|
11 |
+
],
|
12 |
+
"metadata": {
|
13 |
+
"kernelspec": {
|
14 |
+
"display_name": "Python 3 (ipykernel)",
|
15 |
+
"language": "python",
|
16 |
+
"name": "python3"
|
17 |
+
},
|
18 |
+
"language_info": {
|
19 |
+
"codemirror_mode": {
|
20 |
+
"name": "ipython",
|
21 |
+
"version": 3
|
22 |
+
},
|
23 |
+
"file_extension": ".py",
|
24 |
+
"mimetype": "text/x-python",
|
25 |
+
"name": "python",
|
26 |
+
"nbconvert_exporter": "python",
|
27 |
+
"pygments_lexer": "ipython3",
|
28 |
+
"version": "3.7.16"
|
29 |
+
}
|
30 |
+
},
|
31 |
+
"nbformat": 4,
|
32 |
+
"nbformat_minor": 5
|
33 |
+
}
|
LoRA dataset/Training script/training_script.ipynb
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "3b9bbec2",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": []
|
10 |
+
}
|
11 |
+
],
|
12 |
+
"metadata": {
|
13 |
+
"kernelspec": {
|
14 |
+
"display_name": "Python 3 (ipykernel)",
|
15 |
+
"language": "python",
|
16 |
+
"name": "python3"
|
17 |
+
},
|
18 |
+
"language_info": {
|
19 |
+
"codemirror_mode": {
|
20 |
+
"name": "ipython",
|
21 |
+
"version": 3
|
22 |
+
},
|
23 |
+
"file_extension": ".py",
|
24 |
+
"mimetype": "text/x-python",
|
25 |
+
"name": "python",
|
26 |
+
"nbconvert_exporter": "python",
|
27 |
+
"pygments_lexer": "ipython3",
|
28 |
+
"version": "3.7.16"
|
29 |
+
}
|
30 |
+
},
|
31 |
+
"nbformat": 4,
|
32 |
+
"nbformat_minor": 5
|
33 |
+
}
|
LoRA dataset/Weights/.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
pytorch_lora_weights.safetensors filter=lfs diff=lfs merge=lfs -text
|
LoRA dataset/Weights/pytorch_lora_weights.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b1b19610541f9a2c6f235a1bac2690d04b98535f9f9f7790e9ad4d0fe8ac89b0
|
3 |
+
size 3226184
|
MusicCaps/.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
transfer.pth filter=lfs diff=lfs merge=lfs -text
|
MusicCaps/__pycache__/audio_utils.cpython-310.pyc
ADDED
Binary file (7.7 kB). View file
|
|
MusicCaps/__pycache__/bart.cpython-310.pyc
ADDED
Binary file (4.57 kB). View file
|
|
MusicCaps/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (3.27 kB). View file
|
|
MusicCaps/audio_utils.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
STR_CLIP_ID = 'clip_id'
|
2 |
+
STR_AUDIO_SIGNAL = 'audio_signal'
|
3 |
+
STR_TARGET_VECTOR = 'target_vector'
|
4 |
+
|
5 |
+
|
6 |
+
STR_CH_FIRST = 'channels_first'
|
7 |
+
STR_CH_LAST = 'channels_last'
|
8 |
+
|
9 |
+
import io
|
10 |
+
import os
|
11 |
+
import tqdm
|
12 |
+
import logging
|
13 |
+
import subprocess
|
14 |
+
from typing import Tuple
|
15 |
+
from pathlib import Path
|
16 |
+
|
17 |
+
import librosa
|
18 |
+
import numpy as np
|
19 |
+
import soundfile as sf
|
20 |
+
|
21 |
+
import itertools
|
22 |
+
from numpy.fft import irfft
|
23 |
+
|
24 |
+
def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]:
    """
    Decode, optionally downmix, and optionally resample an audio file with ffmpeg.

    ffmpeg is run as a subprocess that writes a WAV stream to stdout; that
    stream is then parsed with `soundfile`.

    Args:
        path: audio file path handed to ``ffmpeg -i``.
        sample_rate: target sampling rate. If falsy (None or 0), no ``-ar``
            option is passed to ffmpeg.
        downmix_to_mono: if True, ``-ac 1`` is passed to ffmpeg.

    Returns:
        (audio signal, sample rate). The signal is the transpose of the array
        returned by ``soundfile.read``.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    # Build an argument list instead of a shell string so that quotes, spaces
    # or shell metacharacters in `path` cannot alter the command.
    cmd = ['ffmpeg', '-i', str(path)]
    if downmix_to_mono:
        cmd += ['-ac', '1']  # downmixing option
    if sample_rate:
        cmd += ['-ar', str(sample_rate)]  # resampling option
    cmd += ['-f', 'wav', '-']

    proc = subprocess.run(
        cmd,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if proc.returncode != 0:
        # Surface ffmpeg's own diagnostics instead of failing later on an
        # empty or truncated WAV stream.
        details = proc.stderr.decode(errors='replace').strip()
        raise RuntimeError(f'ffmpeg failed with exit code {proc.returncode} for "{path}": {details}')

    src, sr = sf.read(io.BytesIO(proc.stdout))
    return src.T, sr
|
49 |
+
|
50 |
+
|
51 |
+
def _resample_load_librosa(path, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]:
    """
    Load an audio file through ``librosa.load``.

    ``sample_rate`` is forwarded as ``sr`` and ``downmix_to_mono`` as ``mono``;
    any extra keyword arguments are passed through unchanged.

    Returns:
        The (signal, sampling rate) pair exactly as produced by ``librosa.load``.
    """
    loaded = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs)
    signal, rate = loaded
    return signal, rate
|
58 |
+
|
59 |
+
|
60 |
+
def load_audio(
    path: "str | Path",
    ch_format: str,
    sample_rate: int = None,
    downmix_to_mono: bool = False,
    resample_by: str = 'librosa',
    **kwargs,
) -> Tuple[np.ndarray, int]:
    """Load an audio file with either the librosa or the ffmpeg backend.

    Defaults differ from ``librosa.load``: ``sample_rate`` defaults to None and
    ``downmix_to_mono`` defaults to False.

    NOTE: ``ch_format`` is validated but not applied. The array is returned
    exactly as the selected backend produced it; it is not expanded to 2-D and
    not transposed. Callers that need a fixed layout must reshape the result
    themselves.

    Args:
        path: audio file path.
        ch_format: one of 'channels_first' or 'channels_last' (validated only).
        sample_rate: target sampling rate, forwarded to the backend.
        downmix_to_mono: forwarded to the backend.
        resample_by (str): 'librosa' or 'ffmpeg'. Selects the backend used for
            decoding and resampling.
        **kwargs: extra keyword args forwarded to ``librosa.load`` (used by the
            'librosa' backend only).

    Returns:
        (audio, sr) tuple

    Raises:
        ValueError: if ``ch_format`` is not one of the two supported strings.
        NotImplementedError: if ``resample_by`` names an unknown backend.
    """
    if ch_format not in (STR_CH_FIRST, STR_CH_LAST):
        raise ValueError(f'ch_format is wrong here -> {ch_format}')

    if resample_by == 'librosa':
        src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs)
    elif resample_by == 'ffmpeg':
        src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono)
    else:
        raise NotImplementedError(f'resample_by: "{resample_by}" is not supported yet')

    return src, sr
|
107 |
+
|
108 |
+
def ms(x):
    """Return the mean of the squared magnitude of signal `x`.

    :param x: Dynamic quantity.
    :returns: Mean squared of `x`.
    """
    squared_magnitude = np.abs(x) ** 2.0
    return squared_magnitude.mean()
|
114 |
+
|
115 |
+
def normalize(y, x=None):
    """Rescale `y` so that its mean-square power equals a target power.

    The target is 1.0 when `x` is None; otherwise it is the mean-square
    power of `x`.
    """
    target_power = 1.0 if x is None else ms(x)
    gain = np.sqrt(target_power / ms(y))
    return y * gain
|
125 |
+
|
126 |
+
def noise(N, color='white', state=None):
    """Generate `N` noise samples of the requested color.

    :param N: Amount of samples.
    :param color: Color of noise; must be a key of `_noise_generators`.
    :param state: State of PRNG.
    :type state: :class:`np.random.RandomState`
    :raises ValueError: if a KeyError occurs while looking up or running the generator.
    """
    try:
        samples = _noise_generators[color](N, state)
    except KeyError:
        raise ValueError("Incorrect color.")
    return samples
|
137 |
+
|
138 |
+
def white(N, state=None):
    """
    White noise: `N` independent standard-normal samples.

    :param N: Amount of samples.
    :param state: State of PRNG; a fresh `RandomState` is created when None.
    :type state: :class:`np.random.RandomState`

    White noise has a constant power density, so its narrowband spectrum is flat.
    The power per octave band doubles with each octave (+3 dB per octave).
    """
    if state is None:
        state = np.random.RandomState()
    return state.randn(N)
|
150 |
+
|
151 |
+
def pink(N, state=None):
    """
    Pink noise.

    :param N: Amount of samples.
    :param state: State of PRNG; a fresh `RandomState` is created when None.
    :type state: :class:`np.random.RandomState`

    A random complex spectrum is divided by sqrt(k + 1) (k = bin index),
    transformed back with `irfft`, and normalized to unit power.
    Pink noise has equal power in bands that are proportionally wide;
    power density decreases with 3 dB per octave.
    """
    rng = state if state is not None else np.random.RandomState()
    is_odd = N % 2
    n_bins = N // 2 + 1 + is_odd
    real_part = rng.randn(n_bins)
    imag_part = rng.randn(n_bins)
    spectrum = real_part + 1j * imag_part
    shaping = np.sqrt(np.arange(len(spectrum)) + 1.)  # +1 to avoid divide by zero
    samples = (irfft(spectrum / shaping)).real
    if is_odd:
        # the inverse transform produced one sample too many for odd N
        samples = samples[:-1]
    return normalize(samples)
|
168 |
+
|
169 |
+
def blue(N, state=None):
    """
    Blue noise.

    :param N: Amount of samples.
    :param state: State of PRNG; a fresh `RandomState` is created when None.
    :type state: :class:`np.random.RandomState`

    A random complex spectrum is multiplied by sqrt(k) (k = bin index),
    transformed back with `irfft`, and normalized to unit power.
    Power increases with 6 dB per octave;
    power density increases with 3 dB per octave.
    """
    rng = state if state is not None else np.random.RandomState()
    is_odd = N % 2
    n_bins = N // 2 + 1 + is_odd
    real_part = rng.randn(n_bins)
    imag_part = rng.randn(n_bins)
    spectrum = real_part + 1j * imag_part
    shaping = np.sqrt(np.arange(len(spectrum)))  # Filter
    samples = (irfft(spectrum * shaping)).real
    if is_odd:
        # the inverse transform produced one sample too many for odd N
        samples = samples[:-1]
    return normalize(samples)
|
186 |
+
|
187 |
+
def brown(N, state=None):
    """
    Brown noise.

    :param N: Amount of samples.
    :param state: State of PRNG; a fresh `RandomState` is created when None.
    :type state: :class:`np.random.RandomState`

    A random complex spectrum is divided by (k + 1) (k = bin index),
    transformed back with `irfft`, and normalized to unit power.
    Power decreases with 3 dB per octave;
    power density decreases with 6 dB per octave.
    """
    rng = state if state is not None else np.random.RandomState()
    is_odd = N % 2
    n_bins = N // 2 + 1 + is_odd
    real_part = rng.randn(n_bins)
    imag_part = rng.randn(n_bins)
    spectrum = real_part + 1j * imag_part
    shaping = (np.arange(len(spectrum)) + 1)  # Filter
    samples = (irfft(spectrum / shaping)).real
    if is_odd:
        # the inverse transform produced one sample too many for odd N
        samples = samples[:-1]
    return normalize(samples)
|
204 |
+
|
205 |
+
def violet(N, state=None):
    """
    Violet noise.

    :param N: Amount of samples.
    :param state: State of PRNG; a fresh `RandomState` is created when None.
    :type state: :class:`np.random.RandomState`

    A random complex spectrum is multiplied by k (k = bin index),
    transformed back with `irfft`, and normalized to unit power.
    Power increases with +9 dB per octave;
    power density increases with +6 dB per octave.
    """
    rng = state if state is not None else np.random.RandomState()
    is_odd = N % 2
    n_bins = N // 2 + 1 + is_odd
    real_part = rng.randn(n_bins)
    imag_part = rng.randn(n_bins)
    spectrum = real_part + 1j * imag_part
    shaping = (np.arange(len(spectrum)))  # Filter
    samples = (irfft(spectrum * shaping)).real
    if is_odd:
        # the inverse transform produced one sample too many for odd N
        samples = samples[:-1]
    return normalize(samples)
|
222 |
+
|
223 |
+
# Lookup table used by `noise`: color name -> generator function.
# Each key is the function's own name.
_noise_generators = {
    generator.__name__: generator
    for generator in (white, pink, blue, brown, violet)
}
|
230 |
+
|
231 |
+
def noise_generator(N=44100, color='white', state=None):
    """Endless noise generator.

    :param N: Amount of unique samples to generate.
    :param color: Color of noise.
    :param state: State of PRNG, forwarded to `noise`.

    `N` samples are generated once and then yielded cyclically.
    """
    pool = noise(N, color, state)
    endless = itertools.cycle(pool)
    for value in endless:
        yield value
|
240 |
+
|
241 |
+
def heaviside(N):
    """Heaviside step function, computed from the sign of `N`.

    Returns the value 0 for `N < 0`, 1 for `N > 0`, and 1/2 for `N = 0`.
    """
    shifted_sign = np.sign(N) + 1
    return 0.5 * shifted_sign
|
MusicCaps/bart.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from MusicCaps.modules import AudioEncoder
|
6 |
+
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
|
7 |
+
|
8 |
+
class BartCaptionModel(nn.Module):
    """Audio captioning model: an `AudioEncoder` front end feeding a BART encoder-decoder.

    The audio encoder output is passed to the BART encoder as `inputs_embeds`;
    the BART decoder is trained on tokenized captions and used for generation.
    """

    def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
        """
        Args:
            n_mels: number of mel bins passed to the audio encoder.
            num_of_conv: total number of conv layers; all but the first are strided.
            sr: audio sample rate used to derive frame counts.
            duration: clip duration in seconds used to derive frame counts.
            max_length: maximum tokenized caption length during training.
            label_smoothing: label smoothing for the cross-entropy loss.
            bart_type: name passed to `from_pretrained` for the BART config and tokenizer.
            audio_dim: channel width of the audio encoder convolutions.
        """
        super().__init__()
        # non-finetuning case: only the config and tokenizer are fetched via
        # `from_pretrained`; the BART network itself is built from the config.
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hard-coded hop size: 1% of the sample rate
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        # frame count divided by 2 once per strided conv layer, plus one
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv,
        )

        self.max_length = max_length
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self):
        """Device of the first model parameter."""
        # next() avoids building a list of every parameter just to read one.
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.

        The first column becomes `decoder_start_token_id`, and any -100 value
        left after the shift is replaced by `pad_token_id`.

        Raises:
            ValueError: if `pad_token_id` is None.
        """
        # Validate before allocating and filling the shifted tensor.
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audio):
        """Encode audio; returns (BART encoder last hidden state, audio embeddings)."""
        audio_embs = self.audio_encoder(audio)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True
        )["last_hidden_state"]
        return encoder_outputs, audio_embs

    def forward_decoder(self, text, encoder_outputs):
        """Tokenize `text`, run the decoder on `encoder_outputs`, and return the loss."""
        text = self.tokenizer(text,
                              padding='longest',
                              truncation=True,
                              max_length=self.max_length,
                              return_tensors="pt")
        input_ids = text["input_ids"].to(self.device)
        attention_mask = text["attention_mask"].to(self.device)

        # padding positions are set to -100, the loss's ignore_index
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True
        )
        lm_logits = decoder_outputs["logits"]
        loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
        return loss

    def forward(self, audio, text):
        """Return the captioning loss for a batch of audio and caption text."""
        encoder_outputs, _ = self.forward_encoder(audio)
        loss = self.forward_decoder(text, encoder_outputs)
        return loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=5,
                 max_length=128,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio `samples`.

        Args:
            samples: audio batch passed to the audio encoder.
            use_nucleus_sampling: if True, sample with `top_p`; otherwise use beam search.
            num_beams: beam count (beam-search branch only).
            max_length: maximum generated length.
            min_length: minimum generated length.
            top_p: nucleus probability mass (sampling branch only).
            repetition_penalty: used by the beam-search branch only.

        Returns:
            Decoded captions from `tokenizer.batch_decode`, special tokens skipped.
        """
        # self.bart.force_bos_token_to_be_generated = True
        audio_embs = self.audio_encoder(samples)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True)

        batch_size = encoder_outputs['last_hidden_state'].size(0)
        # every sequence starts from the decoder start token
        input_ids = torch.zeros((batch_size, 1)).long().to(self.device)
        input_ids[:, 0] = self.bart.config.decoder_start_token_id
        decoder_attention_mask = torch.ones((batch_size, 1)).long().to(self.device)
        if use_nucleus_sampling:
            # NOTE(review): this branch hard-codes repetition_penalty=1.1 and ignores
            # the `repetition_penalty` argument; kept as-is to preserve outputs.
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1)
        else:
            outputs = self.bart.generate(input_ids=None,
                                         attention_mask=None,
                                         decoder_input_ids=input_ids,
                                         decoder_attention_mask=decoder_attention_mask,
                                         encoder_outputs=encoder_outputs,
                                         head_mask=None,
                                         decoder_head_mask=None,
                                         inputs_embeds=None,
                                         decoder_inputs_embeds=None,
                                         use_cache=None,
                                         output_attentions=None,
                                         output_hidden_states=None,
                                         max_length=max_length,
                                         min_length=min_length,
                                         num_beams=num_beams,
                                         repetition_penalty=repetition_penalty)

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
|
MusicCaps/modules.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import numpy as np
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import Tensor, nn
|
9 |
+
from typing import Dict, Iterable, Optional
|
10 |
+
|
11 |
+
# hard-coded audio hyperparameters
|
12 |
+
SAMPLE_RATE = 16000
|
13 |
+
N_FFT = 1024
|
14 |
+
N_MELS = 128
|
15 |
+
HOP_LENGTH = int(0.01 * SAMPLE_RATE)
|
16 |
+
DURATION = 10
|
17 |
+
N_SAMPLES = int(DURATION * SAMPLE_RATE)
|
18 |
+
N_FRAMES = N_SAMPLES // HOP_LENGTH + 1
|
19 |
+
|
20 |
+
def sinusoids(length, channels, max_timescale=10000):
    """Build a (length, channels) sinusoidal positional-embedding table.

    The first half of the columns holds sines and the second half cosines of
    position times a geometrically decreasing set of inverse timescales.
    """
    half = channels // 2
    increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-increment * torch.arange(half))
    positions = torch.arange(length)[:, np.newaxis]
    angles = positions * inv_timescales[np.newaxis, :]
    sines = torch.sin(angles)
    cosines = torch.cos(angles)
    return torch.cat([sines, cosines], dim=1)
|
26 |
+
|
27 |
+
class MelEncoder(nn.Module):
    """
    Time-frequency representation: waveform -> spectrogram -> mel scale -> dB.
    """

    def __init__(
        self,
        sample_rate=16000,
        f_min=0,
        f_max=8000,
        n_fft=1024,
        win_length=1024,
        hop_length=int(0.01 * 16000),
        n_mels=128,
        power=None,
        pad=0,
        normalized=False,
        center=True,
        pad_mode="reflect"
    ):
        super().__init__()
        # NOTE: `pad`, `normalized`, `center` and `pad_mode` are accepted but
        # not forwarded to any transform in this class.
        # `window` is stored as a plain attribute; forward() does not read it.
        self.window = torch.hann_window(win_length)
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            power=power,
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            n_stft=n_fft // 2 + 1,
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        """Return the dB-scaled mel spectrogram of `wav`."""
        spectrum = self.spec_fn(wav)
        # NOTE(review): only the real part of the spectrum is squared here;
        # confirm this is intended before changing it, since any pretrained
        # weights presumably depend on this exact feature.
        energy = spectrum.real.abs().pow(2)
        mel = self.mel_scale(energy)
        return self.amplitude_to_db(mel)
|
68 |
+
|
69 |
+
class AudioEncoder(nn.Module):
    """Mel front end followed by a stack of 1-D convolutions and a positional embedding."""

    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        # stride-2 convolutions
        strided_layers = [
            nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
            for _ in range(num_of_stride_conv)
        ]
        self.conv_stack = nn.ModuleList(strided_layers)
        # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
        # NOTE(review): the table is `text_dim` wide while the conv output is
        # `audio_dim` wide and no projection is applied; presumably the two
        # are expected to be equal. Confirm against callers.
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, waveform)
            single channel waveform
        """
        feats = self.mel_encoder(x)  # (batch_size, n_mels, n_ctx)
        feats = F.gelu(self.conv1(feats))
        for layer in self.conv_stack:
            feats = F.gelu(layer(feats))
        # move the channel axis last so the positional table can be added
        feats = feats.permute(0, 2, 1)
        with_positions = feats + self.positional_embedding
        return with_positions.to(feats.dtype)
|
MusicCaps/train_model.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from bart import BartCaptionModel
|
2 |
+
from audio_utils import load_audio, STR_CH_FIRST
|
3 |
+
import torch
|
4 |
+
|
5 |
+
def _run_step(marker, action):
    """Run one conversion step; on failure print the step marker with the error and re-raise.

    Later steps depend on the results of earlier ones, so a failed step stops
    the script instead of letting the next step fail on an undefined name.
    """
    try:
        return action()
    except Exception as exc:
        print(f"{marker}: {exc!r}")
        raise


# Step 1: pick the device string (not used by the later steps of this script).
device = _run_step("1", lambda: "cuda:0" if torch.cuda.is_available() else "cpu")

# Step 2: build the model architecture.
model = _run_step("2", lambda: BartCaptionModel(max_length=128))

# Step 3: read the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
pretrained_object = _run_step("3", lambda: torch.load('transfer.pth', map_location='cpu'))

# Step 4: extract the weights from the checkpoint.
state_dict = _run_step("4", lambda: pretrained_object['state_dict'])

# Step 5: load the weights into the model.
_run_step("5", lambda: model.load_state_dict(state_dict))

# Step 6: save the whole model object.
_run_step("6", lambda: torch.save(model, "model.pth"))
|
MusicCaps/transfer.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9d04e457e045a09c7c5037222eaed3ffe35f8689b3753a2ce6094c5d5792f9bc
|
3 |
+
size 1783650705
|
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from timeit import default_timer as timer
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from MusicCaps.bart import BartCaptionModel
|
8 |
+
from MusicCaps.audio_utils import load_audio, STR_CH_FIRST
|
9 |
+
from diffusers import StableDiffusionPipeline, I2VGenXLPipeline
|
10 |
+
from diffusers.utils import export_to_video, load_image
|
11 |
+
import tensorflow as tf
|
12 |
+
import torch
|
13 |
+
|
14 |
+
# Enable TensorFlow memory growth on the first GPU it lists, if there is one.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    first_gpu = physical_devices[0]
    tf.config.experimental.set_memory_growth(first_gpu, True)

# Torch device string used by the functions below.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
18 |
+
|
19 |
+
@st.cache_resource
def load_text_model():
    """Build the audio-captioning model, load its checkpoint, and put it in eval mode on `device`.

    The result is cached by `st.cache_resource`.
    """
    model = BartCaptionModel(max_length=128)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    pretrained_object = torch.load('MusicCaps/transfer.pth', map_location='cpu')
    state_dict = pretrained_object['state_dict']
    model.load_state_dict(state_dict)
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    # `captioning` moves the audio tensor to `device`, so the weights must be
    # on the same device; set_device alone does not move them.
    model.to(device)
    model.eval()
    return model
|
29 |
+
|
30 |
+
def get_audio(audio_path, duration=10, target_sr=16000):
    """Load audio as mono at `target_sr` and cut it into `duration`-second chunks.

    Audio shorter than one chunk is zero-padded to exactly one chunk; trailing
    samples that do not fill a whole chunk are dropped.

    Returns:
        float32 torch tensor with one row per chunk.
    """
    chunk_len = int(duration * target_sr)
    waveform, _ = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )
    if len(waveform.shape) == 2:
        waveform = waveform.mean(0, False)  # to mono

    available = waveform.shape[-1]
    if available < chunk_len:  # pad sequence
        padded = np.zeros(chunk_len)
        padded[:available] = waveform
        waveform = padded

    n_chunks = int(waveform.shape[-1] // chunk_len)
    usable = waveform[:n_chunks * chunk_len]
    chunks = np.stack(np.split(usable, n_chunks))
    return torch.from_numpy(chunks.astype('float32'))
|
48 |
+
|
49 |
+
def captioning(model, audio_path):
    """Caption every chunk of the audio at `audio_path`.

    Returns:
        A list with one string per chunk: a time-range header line followed by
        the generated caption.
    """
    chunks = get_audio(audio_path=audio_path)
    if device is not None:
        chunks = chunks.to(device)
    with torch.no_grad():
        texts = model.generate(samples=chunks, num_beams=5)

    inference = []
    chunk_indices = range(chunks.shape[0])
    for chunk, text in zip(chunk_indices, texts):
        time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
        inference.append(f"{time}\n{text} \n \n")
    return inference
|
66 |
+
|
67 |
+
@st.cache_resource
def load_image_model():
    """Create the Stable Diffusion pipeline in float16 on "cuda" and attach the LoRA weights.

    The result is cached by `st.cache_resource`.
    """
    base = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
    )
    pipeline = base.to("cuda")
    pipeline.load_lora_weights(
        "LoRA dataset/Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline
|
72 |
+
|
73 |
+
@st.cache_resource
def load_video_model():
    """Create the I2VGen-XL image-to-video pipeline in float16 (fp16 variant).

    The result is cached by `st.cache_resource`.
    """
    # NOTE(review): unlike the image pipeline, this one is not moved to a GPU
    # here. Confirm whether that is intended.
    return I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl",
        torch_dtype=torch.float16,
        variant="fp16",
    )
|
77 |
+
|
78 |
+
# Load the three models; each loader is wrapped in st.cache_resource.
A2C_model = load_text_model()
image_service = load_image_model()
video_model = load_video_model()

# Make sure every session-state slot used below exists before it is read.
if "audio_input" not in st.session_state:
    st.session_state.audio_input = None

if "captions" not in st.session_state:
    st.session_state.captions = None

if "image" not in st.session_state:
    st.session_state.image = None

if "video" not in st.session_state:
    st.session_state.video = None

st.title("Testing MusicCaps")
st.session_state.audio_input = st.file_uploader("Insert Your Audio Clips Here", type=["wav", "mp3"], key="Audio input")
if st.session_state.audio_input:
    audio_input = st.session_state.audio_input
    st.audio(audio_input)
    if st.button("Generate text prompt"):
        st.session_state.captions = captioning(A2C_model, audio_input)[0]

    # Streamlit re-runs this script on every interaction, and a button is only
    # True on the run triggered by its own click. The caption is therefore read
    # back from session state rather than from a local bound inside the first
    # button's branch, which would be undefined when the second button is clicked.
    captions = st.session_state.captions
    if captions:
        st.text(captions)
        if st.button("Generate Image and video from text prompt"):
            st.session_state.image = image_service(captions).images[0]
            image = st.session_state.image
            video = video_model(
                prompt=captions,
                image=image,
                num_inference_steps=50
            ).frames[0]
            st.session_state.video = video
            export_to_video(video, "generated.mp4", fps=7)
            c1, c2 = st.columns([1, 1])
            with c1:
                st.image(image)
            with c2:
                st.video("generated.mp4")
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
|
requirements.txt
ADDED
Binary file (3.71 kB). View file
|
|