Monke64 committed
Commit 728ab38
1 Parent(s): 162019a

Added code

LoRA dataset/Training script/.ipynb_checkpoints/training_script-checkpoint.ipynb ADDED
@@ -0,0 +1,33 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "3b9bbec2",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.7.16"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
LoRA dataset/Training script/training_script.ipynb ADDED
@@ -0,0 +1,33 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "3b9bbec2",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.7.16"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
LoRA dataset/Weights/.gitattributes ADDED
@@ -0,0 +1 @@
+ pytorch_lora_weights.safetensors filter=lfs diff=lfs merge=lfs -text
LoRA dataset/Weights/pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b1b19610541f9a2c6f235a1bac2690d04b98535f9f9f7790e9ad4d0fe8ac89b0
+ size 3226184
MusicCaps/.gitattributes ADDED
@@ -0,0 +1 @@
+ transfer.pth filter=lfs diff=lfs merge=lfs -text
MusicCaps/__pycache__/audio_utils.cpython-310.pyc ADDED
Binary file (7.7 kB)
MusicCaps/__pycache__/bart.cpython-310.pyc ADDED
Binary file (4.57 kB)
MusicCaps/__pycache__/modules.cpython-310.pyc ADDED
Binary file (3.27 kB)
MusicCaps/audio_utils.py ADDED
@@ -0,0 +1,245 @@
+ STR_CLIP_ID = 'clip_id'
+ STR_AUDIO_SIGNAL = 'audio_signal'
+ STR_TARGET_VECTOR = 'target_vector'
+
+
+ STR_CH_FIRST = 'channels_first'
+ STR_CH_LAST = 'channels_last'
+
+ import io
+ import os
+ import tqdm
+ import logging
+ import subprocess
+ from typing import Tuple
+ from pathlib import Path
+
+ import librosa
+ import numpy as np
+ import soundfile as sf
+
+ import itertools
+ from numpy.fft import irfft
+
+ def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]:
+     """
+     Decoding, downmixing, and downsampling by ffmpeg.
+     Returns a channel-first audio signal.
+
+     Args:
+         path:
+         sample_rate:
+         downmix_to_mono:
+
+     Returns:
+         (audio signal, sample rate)
+     """
+
+     def _decode_resample_by_ffmpeg(filename, sr):
+         """Decode, downmix, and resample an audio file with ffmpeg."""
+         channel_cmd = '-ac 1 ' if downmix_to_mono else ''  # downmixing option
+         resampling_cmd = f'-ar {str(sr)}' if sr else ''  # downsampling option
+         cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -"
+         p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+         out, err = p.communicate()
+         return out
+
+     src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate)))
+     return src.T, sr
+
+
+ def _resample_load_librosa(path, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]:
+     """
+     Decoding, downmixing, and downsampling by librosa.
+     Returns a channel-first audio signal.
+     """
+     src, sr = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs)
+     return src, sr
+
+
+ def load_audio(
+     path: str or Path,
+     ch_format: str,
+     sample_rate: int = None,
+     downmix_to_mono: bool = False,
+     resample_by: str = 'librosa',
+     **kwargs,
+ ) -> Tuple[np.ndarray, int]:
+     """A wrapper of librosa.load that:
+     - forces the returned audio to be 2-dim,
+     - defaults to sr=None, and
+     - defaults to downmix_to_mono=False.
+
+     The audio decoding is done by the `audioread` or `soundfile` package and, ultimately, often by ffmpeg.
+     The resampling is done by librosa's child package `resampy`.
+
+     Args:
+         path: audio file path
+         ch_format: one of 'channels_first' or 'channels_last'
+         sample_rate: target sampling rate. if None, use the rate of the audio file
+         downmix_to_mono:
+         resample_by (str): 'librosa' or 'ffmpeg'. decides the backend for audio decoding and resampling.
+         **kwargs: keyword args for librosa.load - offset, duration, dtype, res_type.
+
+     Returns:
+         (audio, sr) tuple
+     """
+     if ch_format not in (STR_CH_FIRST, STR_CH_LAST):
+         raise ValueError(f'ch_format is wrong here -> {ch_format}')
+
+     if resample_by == 'librosa':
+         src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs)
+     elif resample_by == 'ffmpeg':
+         src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono)
+     else:
+         raise NotImplementedError(f'resample_by: "{resample_by}" is not supported yet')
+
+     return src, sr
+
+     # if src.ndim == 1:
+     #     src = np.expand_dims(src, axis=0)
+     # # now always 2d and channels_first
+
+     # if ch_format == STR_CH_FIRST:
+     #     return src, sr
+     # else:
+     #     return src.T, sr
+
+ def ms(x):
+     """Mean value of signal `x` squared.
+     :param x: Dynamic quantity.
+     :returns: Mean squared of `x`.
+     """
+     return (np.abs(x)**2.0).mean()
+
+ def normalize(y, x=None):
+     """Normalize power in y to a (standard normal) white noise signal.
+     Optionally normalize to power in signal `x`.
+     # The mean power of a Gaussian with :math:`\\mu=0` and :math:`\\sigma=1` is 1.
+     """
+     if x is not None:
+         x = ms(x)
+     else:
+         x = 1.0
+     return y * np.sqrt(x / ms(y))
+
+ def noise(N, color='white', state=None):
+     """Noise generator.
+     :param N: Amount of samples.
+     :param color: Color of noise.
+     :param state: State of PRNG.
+     :type state: :class:`np.random.RandomState`
+     """
+     try:
+         return _noise_generators[color](N, state)
+     except KeyError:
+         raise ValueError("Incorrect color.")
+
+ def white(N, state=None):
+     """
+     White noise.
+     :param N: Amount of samples.
+     :param state: State of PRNG.
+     :type state: :class:`np.random.RandomState`
+     White noise has a constant power density. Its narrowband spectrum is therefore flat.
+     The power in white noise will increase by a factor of two for each octave band,
+     and therefore increases with 3 dB per octave.
+     """
+     state = np.random.RandomState() if state is None else state
+     return state.randn(N)
+
+ def pink(N, state=None):
+     """
+     Pink noise.
+     :param N: Amount of samples.
+     :param state: State of PRNG.
+     :type state: :class:`np.random.RandomState`
+     Pink noise has equal power in bands that are proportionally wide.
+     Power density decreases with 3 dB per octave.
+     """
+     state = np.random.RandomState() if state is None else state
+     uneven = N % 2
+     X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+     S = np.sqrt(np.arange(len(X)) + 1.)  # +1 to avoid divide by zero
+     y = (irfft(X / S)).real
+     if uneven:
+         y = y[:-1]
+     return normalize(y)
+
+ def blue(N, state=None):
+     """
+     Blue noise.
+     :param N: Amount of samples.
+     :param state: State of PRNG.
+     :type state: :class:`np.random.RandomState`
+     Power increases with 6 dB per octave.
+     Power density increases with 3 dB per octave.
+     """
+     state = np.random.RandomState() if state is None else state
+     uneven = N % 2
+     X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+     S = np.sqrt(np.arange(len(X)))  # Filter
+     y = (irfft(X * S)).real
+     if uneven:
+         y = y[:-1]
+     return normalize(y)
+
+ def brown(N, state=None):
+     """
+     Brown noise.
+     :param N: Amount of samples.
+     :param state: State of PRNG.
+     :type state: :class:`np.random.RandomState`
+     Power decreases with -3 dB per octave.
+     Power density decreases with 6 dB per octave.
+     """
+     state = np.random.RandomState() if state is None else state
+     uneven = N % 2
+     X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+     S = (np.arange(len(X)) + 1)  # Filter
+     y = (irfft(X / S)).real
+     if uneven:
+         y = y[:-1]
+     return normalize(y)
+
+ def violet(N, state=None):
+     """
+     Violet noise.
+     :param N: Amount of samples.
+     :param state: State of PRNG.
+     :type state: :class:`np.random.RandomState`
+     Power increases with +9 dB per octave.
+     Power density increases with +6 dB per octave.
+     """
+     state = np.random.RandomState() if state is None else state
+     uneven = N % 2
+     X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+     S = (np.arange(len(X)))  # Filter
+     y = (irfft(X * S)).real
+     if uneven:
+         y = y[:-1]
+     return normalize(y)
+
+ _noise_generators = {
+     'white': white,
+     'pink': pink,
+     'blue': blue,
+     'brown': brown,
+     'violet': violet,
+ }
+
+ def noise_generator(N=44100, color='white', state=None):
+     """Noise generator.
+     :param N: Amount of unique samples to generate.
+     :param color: Color of noise.
+     Generate `N` amount of unique samples and cycle over these samples.
+     """
+     # yield from itertools.cycle(noise(N, color))  # Python 3.3
+     for sample in itertools.cycle(noise(N, color, state)):
+         yield sample
+
+ def heaviside(N):
+     """Heaviside.
+     Returns the value 0 for `x < 0`, 1 for `x > 0`, and 1/2 for `x = 0`.
+     """
+     return 0.5 * (np.sign(N) + 1)
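
A minimal usage sketch for load_audio and the noise helpers above (assumptions: the repository root is on the Python path, and "sample.wav" is a hypothetical placeholder for any local audio file that librosa can decode):

import numpy as np
from MusicCaps.audio_utils import load_audio, noise, STR_CH_FIRST

# Hypothetical input file; any wav/mp3 readable by librosa works here.
audio, sr = load_audio("sample.wav", ch_format=STR_CH_FIRST,
                       sample_rate=16000, downmix_to_mono=True)
print(audio.shape, sr)

# Ten seconds of pink noise at 16 kHz; normalize() scales it to unit mean-square power.
pink_noise = noise(10 * 16000, color='pink')
print(round(float((pink_noise ** 2).mean()), 3))  # approximately 1.0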
MusicCaps/bart.py ADDED
@@ -0,0 +1,151 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from MusicCaps.modules import AudioEncoder
+ from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
+
+ class BartCaptionModel(nn.Module):
+     def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
+         super(BartCaptionModel, self).__init__()
+         # non-finetuning case
+         bart_config = BartConfig.from_pretrained(bart_type)
+         self.tokenizer = BartTokenizer.from_pretrained(bart_type)
+         self.bart = BartForConditionalGeneration(bart_config)
+
+         self.n_sample = sr * duration
+         self.hop_length = int(0.01 * sr)  # hard-coded hop size
+         self.n_frames = int(self.n_sample // self.hop_length)
+         self.num_of_stride_conv = num_of_conv - 1
+         self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
+         self.audio_encoder = AudioEncoder(
+             n_mels=n_mels,  # hard-coded n_mels
+             n_ctx=self.n_ctx,
+             audio_dim=audio_dim,
+             text_dim=self.bart.config.hidden_size,
+             num_of_stride_conv=self.num_of_stride_conv
+         )
+
+         self.max_length = max_length
+         self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)
+
+     @property
+     def device(self):
+         return list(self.parameters())[0].device
+
+     def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+         """
+         Shift input ids one token to the right.
+         """
+         shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+         shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+         shifted_input_ids[:, 0] = decoder_start_token_id
+
+         if pad_token_id is None:
+             raise ValueError("self.model.config.pad_token_id has to be defined.")
+         # replace possible -100 values in labels by `pad_token_id`
+         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+         return shifted_input_ids
+
+     def forward_encoder(self, audio):
+         audio_embs = self.audio_encoder(audio)
+         encoder_outputs = self.bart.model.encoder(
+             input_ids=None,
+             inputs_embeds=audio_embs,
+             return_dict=True
+         )["last_hidden_state"]
+         return encoder_outputs, audio_embs
+
+     def forward_decoder(self, text, encoder_outputs):
+         text = self.tokenizer(text,
+                               padding='longest',
+                               truncation=True,
+                               max_length=self.max_length,
+                               return_tensors="pt")
+         input_ids = text["input_ids"].to(self.device)
+         attention_mask = text["attention_mask"].to(self.device)
+
+         decoder_targets = input_ids.masked_fill(
+             input_ids == self.tokenizer.pad_token_id, -100
+         )
+
+         decoder_input_ids = self.shift_tokens_right(
+             decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
+         )
+
+         decoder_outputs = self.bart(
+             input_ids=None,
+             attention_mask=None,
+             decoder_input_ids=decoder_input_ids,
+             decoder_attention_mask=attention_mask,
+             inputs_embeds=None,
+             labels=None,
+             encoder_outputs=(encoder_outputs,),
+             return_dict=True
+         )
+         lm_logits = decoder_outputs["logits"]
+         loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
+         return loss
+
+     def forward(self, audio, text):
+         encoder_outputs, _ = self.forward_encoder(audio)
+         loss = self.forward_decoder(text, encoder_outputs)
+         return loss
+
+     def generate(self,
+                  samples,
+                  use_nucleus_sampling=False,
+                  num_beams=5,
+                  max_length=128,
+                  min_length=2,
+                  top_p=0.9,
+                  repetition_penalty=1.0,
+                  ):
+
+         # self.bart.force_bos_token_to_be_generated = True
+         audio_embs = self.audio_encoder(samples)
+         encoder_outputs = self.bart.model.encoder(
+             input_ids=None,
+             attention_mask=None,
+             head_mask=None,
+             inputs_embeds=audio_embs,
+             output_attentions=None,
+             output_hidden_states=None,
+             return_dict=True)
+
+         input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
+         input_ids[:, 0] = self.bart.config.decoder_start_token_id
+         decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
+         if use_nucleus_sampling:
+             outputs = self.bart.generate(
+                 input_ids=None,
+                 attention_mask=None,
+                 decoder_input_ids=input_ids,
+                 decoder_attention_mask=decoder_attention_mask,
+                 encoder_outputs=encoder_outputs,
+                 max_length=max_length,
+                 min_length=min_length,
+                 do_sample=True,
+                 top_p=top_p,
+                 num_return_sequences=1,
+                 repetition_penalty=1.1)
+         else:
+             outputs = self.bart.generate(input_ids=None,
+                                          attention_mask=None,
+                                          decoder_input_ids=input_ids,
+                                          decoder_attention_mask=decoder_attention_mask,
+                                          encoder_outputs=encoder_outputs,
+                                          head_mask=None,
+                                          decoder_head_mask=None,
+                                          inputs_embeds=None,
+                                          decoder_inputs_embeds=None,
+                                          use_cache=None,
+                                          output_attentions=None,
+                                          output_hidden_states=None,
+                                          max_length=max_length,
+                                          min_length=min_length,
+                                          num_beams=num_beams,
+                                          repetition_penalty=repetition_penalty)
+
+         captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+         return captions
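
A short smoke-test sketch for BartCaptionModel.generate (assumptions: torch and transformers installed per requirements.txt, internet access to fetch the facebook/bart-base config and tokenizer, and random noise standing in for a real 10-second, 16 kHz waveform, so the decoded captions are meaningless):

import torch
from MusicCaps.bart import BartCaptionModel

model = BartCaptionModel(max_length=128)  # randomly initialised BART plus the audio encoder
model.eval()

# Two dummy 10-second clips at 16 kHz: shape (batch, n_samples).
dummy_audio = torch.randn(2, 16000 * 10)
with torch.no_grad():
    captions = model.generate(samples=dummy_audio, num_beams=2, max_length=32)
print(captions)  # two caption strings (nonsense for random input and untrained weights)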
MusicCaps/modules.py ADDED
@@ -0,0 +1,95 @@
+ ### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
+
+ import os
+ import torch
+ import torchaudio
+ import numpy as np
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+ from typing import Dict, Iterable, Optional
+
+ # hard-coded audio hyperparameters
+ SAMPLE_RATE = 16000
+ N_FFT = 1024
+ N_MELS = 128
+ HOP_LENGTH = int(0.01 * SAMPLE_RATE)
+ DURATION = 10
+ N_SAMPLES = int(DURATION * SAMPLE_RATE)
+ N_FRAMES = N_SAMPLES // HOP_LENGTH + 1
+
+ def sinusoids(length, channels, max_timescale=10000):
+     """Returns sinusoids for positional embedding"""
+     log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+     inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+ class MelEncoder(nn.Module):
+     """
+     time-frequency representation
+     """
+     def __init__(self,
+                  sample_rate=16000,
+                  f_min=0,
+                  f_max=8000,
+                  n_fft=1024,
+                  win_length=1024,
+                  hop_length=int(0.01 * 16000),
+                  n_mels=128,
+                  power=None,
+                  pad=0,
+                  normalized=False,
+                  center=True,
+                  pad_mode="reflect"
+                  ):
+         super(MelEncoder, self).__init__()
+         self.window = torch.hann_window(win_length)
+         self.spec_fn = torchaudio.transforms.Spectrogram(
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             power=power
+         )
+         self.mel_scale = torchaudio.transforms.MelScale(
+             n_mels,
+             sample_rate,
+             f_min,
+             f_max,
+             n_fft // 2 + 1)
+
+         self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+     def forward(self, wav):
+         spec = self.spec_fn(wav)
+         power_spec = spec.real.abs().pow(2)
+         mel_spec = self.mel_scale(power_spec)
+         mel_spec = self.amplitude_to_db(mel_spec)  # log10(max(reference value, amin))
+         return mel_spec
+
+ class AudioEncoder(nn.Module):
+     def __init__(
+         self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
+     ):
+         super().__init__()
+         self.mel_encoder = MelEncoder(n_mels=n_mels)
+         self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
+         self.conv_stack = nn.ModuleList([])
+         for _ in range(num_of_stride_conv):
+             self.conv_stack.append(
+                 nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
+             )
+         # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
+         self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))
+
+     def forward(self, x: Tensor):
+         """
+         x : torch.Tensor, shape = (batch_size, waveform)
+         single-channel waveform
+         """
+         x = self.mel_encoder(x)  # (batch_size, n_mels, n_frames)
+         x = F.gelu(self.conv1(x))
+         for conv in self.conv_stack:
+             x = F.gelu(conv(x))
+         x = x.permute(0, 2, 1)
+         x = (x + self.positional_embedding).to(x.dtype)
+         return x
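
A quick shape check for AudioEncoder under the hard-coded 16 kHz / 10 s settings above (assumption: audio_dim and text_dim are both 768, matching BartCaptionModel's defaults, which gives n_ctx = 32 after the five stride-2 convolutions):

import torch
from MusicCaps.modules import AudioEncoder, N_MELS

encoder = AudioEncoder(n_mels=N_MELS, n_ctx=32, audio_dim=768, text_dim=768, num_of_stride_conv=5)
wav = torch.randn(2, 16000 * 10)   # (batch, samples), single-channel waveforms
feats = encoder(wav)
print(feats.shape)                 # expected: torch.Size([2, 32, 768])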
MusicCaps/train_model.py ADDED
@@ -0,0 +1,32 @@
+ from bart import BartCaptionModel
+ from audio_utils import load_audio, STR_CH_FIRST
+ import torch
+
+ # Convert the pretrained captioning checkpoint into a single pickled model file.
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ # Build the captioning model with the same max_length used at inference time.
+ model = BartCaptionModel(max_length=128)
+
+ # Load the pretrained checkpoint (the Git LFS file committed alongside this script).
+ pretrained_object = torch.load('transfer.pth', map_location='cpu')
+ state_dict = pretrained_object['state_dict']
+ model.load_state_dict(state_dict)
+
+ # Save the full model object for later reuse.
+ torch.save(model, "model.pth")
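
Because torch.save(model, "model.pth") pickles the whole module object rather than just a state dict, reloading it later requires the defining modules (bart, audio_utils) to be importable. A hedged reload sketch, assuming model.pth was produced by the script above:

import torch

# Run with the same module layout used when saving, so the pickled class resolves.
model = torch.load("model.pth", map_location="cpu")
model.eval()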
MusicCaps/transfer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9d04e457e045a09c7c5037222eaed3ffe35f8689b3753a2ce6094c5d5792f9bc
+ size 1783650705
app.py ADDED
@@ -0,0 +1,125 @@
+ import streamlit as st
+ from timeit import default_timer as timer
+ import torch
+ import numpy as np
+ import pandas as pd
+ from huggingface_hub import hf_hub_download
+ from MusicCaps.bart import BartCaptionModel
+ from MusicCaps.audio_utils import load_audio, STR_CH_FIRST
+ from diffusers import StableDiffusionPipeline, I2VGenXLPipeline
+ from diffusers.utils import export_to_video, load_image
+ import tensorflow as tf
+
+ physical_devices = tf.config.experimental.list_physical_devices('GPU')
+ if len(physical_devices) > 0:
+     tf.config.experimental.set_memory_growth(physical_devices[0], True)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ @st.cache_resource
+ def load_text_model():
+     model = BartCaptionModel(max_length=128)
+     pretrained_object = torch.load('MusicCaps/transfer.pth', map_location='cpu')
+     state_dict = pretrained_object['state_dict']
+     model.load_state_dict(state_dict)
+     if torch.cuda.is_available():
+         torch.cuda.set_device(device)
+     model.eval()
+     return model
+
+ def get_audio(audio_path, duration=10, target_sr=16000):
+     n_samples = int(duration * target_sr)
+     audio, sr = load_audio(
+         path=audio_path,
+         ch_format=STR_CH_FIRST,
+         sample_rate=target_sr,
+         downmix_to_mono=True,
+     )
+     if len(audio.shape) == 2:
+         audio = audio.mean(axis=0)  # to mono
+     input_size = int(n_samples)
+     if audio.shape[-1] < input_size:  # pad sequence
+         pad = np.zeros(input_size)
+         pad[: audio.shape[-1]] = audio
+         audio = pad
+     ceil = int(audio.shape[-1] // n_samples)
+     audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
+     return audio
+
+ def captioning(model, audio_path):
+     audio_tensor = get_audio(audio_path=audio_path)
+     if device is not None:
+         audio_tensor = audio_tensor.to(device)
+     with torch.no_grad():
+         output = model.generate(
+             samples=audio_tensor,
+             num_beams=5,
+         )
+     inference = []
+     number_of_chunks = range(audio_tensor.shape[0])
+     for chunk, text in zip(number_of_chunks, output):
+         output = ""
+         time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
+         output += f"{time}\n{text} \n \n"
+         inference.append(output)
+     return inference
+
+ @st.cache_resource
+ def load_image_model():
+     pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+     pipeline.load_lora_weights("LoRA dataset/Weights", weight_name="pytorch_lora_weights.safetensors")
+     return pipeline
+
+ @st.cache_resource
+ def load_video_model():
+     # fp16 weights, so run on GPU like the image pipeline
+     pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16").to("cuda")
+     return pipeline
+
+ A2C_model = load_text_model()
+ image_service = load_image_model()
+ video_model = load_video_model()
+
+ if "audio_input" not in st.session_state:
+     st.session_state.audio_input = None
+
+ if "captions" not in st.session_state:
+     st.session_state.captions = None
+
+ if "image" not in st.session_state:
+     st.session_state.image = None
+
+ if "video" not in st.session_state:
+     st.session_state.video = None
+
+ st.title("Testing MusicCaps")
+ st.session_state.audio_input = st.file_uploader("Insert Your Audio Clips Here", type=["wav", "mp3"], key="Audio input")
+ if st.session_state.audio_input:
+     audio_input = st.session_state.audio_input
+     st.audio(audio_input)
+     if st.button("Generate text prompt"):
+         st.session_state.captions = captioning(A2C_model, audio_input)[0]
+         st.text(st.session_state.captions)
+     if st.session_state.captions:
+         if st.button("Generate Image and video from text prompt"):
+             captions = st.session_state.captions
+             st.session_state.image = image_service(captions).images[0]
+             image = st.session_state.image
+             video = video_model(
+                 prompt=captions,
+                 image=image,
+                 num_inference_steps=50
+             ).frames[0]
+             st.session_state.video = video
+             export_to_video(video, "generated.mp4", fps=7)
+             c1, c2 = st.columns([1, 1])
+             with c1:
+                 st.image(image)
+             with c2:
+                 st.video("generated.mp4")
requirements.txt ADDED
Binary file (3.71 kB)