fx sounds batch inference

Files changed:
- README.md                  +2  -2
- api.py                     +7  -15
- audiocraft/builders.py     +10 -40
- audiocraft/conditioners.py +24 -198
- audiocraft/lm.py           +7  -23
- audiocraft/transformer.py  +1  -4
- models.py                  +5  -1
- requirements.txt           +19 -0
README.md CHANGED

@@ -62,7 +62,7 @@ pip install -r requirements.txt
 Flask `tmux-session`
 
 ```
-CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME
+CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=0 python api.py
 ```
 
 Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.

@@ -127,5 +127,5 @@ Create audiobook from `.docx`. Listen to it - YouTube [male voice](https://www.y
 
 ```python
 # audiobook will be saved in ./tts_audiobooks
-python audiobook.py
+python audiobook.py
 ```
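Once `api.py` is running with the command above, clients reach it over plain HTTP. A minimal client sketch, assuming the endpoint accepts a multipart POST the way `tts.py` does; the field names (`text`, `soundscape`) and the returned wav bytes are illustrative assumptions, not the confirmed API:

```python
# Hypothetical client sketch: field names are assumptions, see tts.py for the real payload.
import requests

API_URL = 'http://0.0.0.0:5000'   # replace with the IP printed when api.py starts

with open('sample.txt', 'rb') as f:
    response = requests.post(API_URL,
                             files={'text': f},                    # assumed field name
                             data={'soundscape': 'dogs barking'})  # assumed field name

with open('out.wav', 'wb') as out:
    out.write(response.content)   # assumes the server replies with the synthesized wav
```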
api.py CHANGED

@@ -9,12 +9,11 @@ import re
 import srt
 import subprocess
 import cv2
-import markdown
 from pathlib import Path
 from types import SimpleNamespace
 from flask import Flask, request, send_from_directory
-from
-from moviepy.
+from moviepy.video.io.VideoFileClip import VideoFileClip
+from moviepy.video.VideoClip import ImageClip
 from audiocraft.builders import AudioGen
 CACHE_DIR = 'flask_cache/'
 NUM_SOUND_GENERATIONS = 3  # batch size to generate same text (same soundscape for long video)

@@ -79,10 +78,10 @@ def overlay(x, soundscape=None):
    background = sound_generator.generate(
        [soundscape] * NUM_SOUND_GENERATIONS
        ).reshape(-1).detach().cpu().numpy()  # bs, 11400 @.74s
-
-    # upsample 16 kHz AudioGen to 24kHZ
+
+    # upsample 16 kHz AudioGen to the 24 kHz of VITS/StyleTTS2
 
-    print('Resampling')
+    print('Resampling')  # soundscape of each generation in the batch differs from the others, thus clone/shift each element in the batch, finally concat w/o shift
 
 
    background = audresample.resample(

@@ -178,14 +177,6 @@ def tts_multi_sentence(precomputed_style_vector=None,
 # global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
 
 app = Flask(__name__)
-cors = CORS(app)
-
-
-@app.route("/")
-def index():
-    with open('README.md', 'r') as f:
-        return markdown.markdown(f.read())
-
 
 @app.route("/", methods=['GET', 'POST', 'PUT'])
 def serve_wav():

@@ -460,7 +451,8 @@ def serve_wav():
 
    # SILENT CLIP
 
-    clip_silent = ImageClip(STATIC_FRAME
+    clip_silent = ImageClip(img=STATIC_FRAME,
+                            duration=5)  # ffmpeg continues this silent video for duration of TTS
    clip_silent.write_videofile(SILENT_VIDEO, fps=24)
 
    x = tts_multi_sentence(text=text,
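For context, the `overlay()` change batches `NUM_SOUND_GENERATIONS` copies of the same soundscape prompt and then brings the 16 kHz AudioGen output up to the 24 kHz rate of the TTS stream. A minimal sketch of that resampling step with `audresample`, assuming a flat NumPy signal; the variable names and shapes here are illustrative stand-ins, not the actual server code:

```python
import numpy as np
import audresample

AUDIOGEN_SR = 16000   # AudioGen native sample rate
TTS_SR = 24000        # VITS / StyleTTS2 output rate

# stand-in for the flattened batch returned by sound_generator.generate(...)
background = np.random.randn(3 * 11400).astype(np.float32)

# resample from 16 kHz to 24 kHz before overlaying on the TTS track
background_24k = audresample.resample(background, AUDIOGEN_SR, TTS_SR)
print(background_24k.shape)  # roughly 1.5x as many samples as the 16 kHz input
```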
audiocraft/builders.py CHANGED

@@ -10,11 +10,7 @@ from .encodec import EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
 from .codebooks_patterns import DelayedPatternProvider
-from .conditioners import (
-    ConditioningProvider,
-    T5Conditioner,
-    ConditioningAttributes
-)
+from .conditioners import T5Conditioner
 from .vq import ResidualVectorQuantizer
 
 

@@ -73,10 +69,8 @@ class AudioGen(nn.Module):
     def generate(self,
                  descriptions):
         with torch.no_grad():
-            attributes = [
-                ConditioningAttributes(text={'description': d}) for d in descriptions]
             gen_tokens = self.lm.generate(
-
+                descriptions=descriptions,
                 max_gen_len=int(self.duration * self.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
             x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
             # print('______________\nAudioGen Tokens', gen_tokens)

@@ -144,10 +138,8 @@ class AudioGen(nn.Module):
         codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
         attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
         cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
-        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
 
-        condition_provider = self.get_conditioner_provider(kwargs["dim"], cfg
-                                                           ).to(self.device)
 
 
         # if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically

@@ -163,7 +155,7 @@ class AudioGen(nn.Module):
         pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
         return LMModel(
             pattern_provider=pattern_provider,
-            condition_provider=
+            condition_provider=T5Conditioner(name='t5-large', output_dim=kwargs["dim"], device=self.device),
             cfg_dropout=cfg_prob,
             cfg_coef=cfg_coef,
             attribute_dropout=attribute_dropout,

@@ -173,34 +165,8 @@ class AudioGen(nn.Module):
             ).to(cfg.device)
         else:
             raise KeyError(f"Unexpected LM model {cfg.lm_model}")
-
-
-    def get_conditioner_provider(self, output_dim,
-                                 cfg):
-        """Instantiate T5 text"""
-        cfg = getattr(cfg, 'conditioners')
-        dict_cfg = {} if cfg is None else dict_from_config(cfg)
-        conditioners={}
-        condition_provider_args = dict_cfg.pop('args', {})
-        condition_provider_args.pop('merge_text_conditions_p', None)
-        condition_provider_args.pop('drop_desc_p', None)
-
-        for cond, cond_cfg in dict_cfg.items():
-            model_type = cond_cfg['model']
-            model_args = cond_cfg[model_type]
-            if model_type == 't5':
-                conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
-                                                        device=self.device,
-                                                        **model_args)
-            else:
-                raise ValueError(f"Unrecognized conditioning model: {model_type}")
-
-        # print(f'{condition_provider_args=}')
-        return ConditioningProvider(conditioners)
-
-
-
-
+
+
     def get_codebooks_pattern_provider(self, n_q, cfg):
         pattern_providers = {
             'delay': DelayedPatternProvider,  # THIS

@@ -249,6 +215,10 @@ class AudioGen(nn.Module):
         _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
         _delete_param(cfg, 'conditioners.args.drop_desc_p')
         model = self.get_lm_model(cfg)
+
+        _best = pkg['best_state']
+        _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
+        _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
         model.load_state_dict(pkg['best_state'])
         model.cfg = cfg
         # return model
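The checkpoint was trained with a `ConditioningProvider` wrapping the T5 conditioner, so its projection weights live under `condition_provider.conditioners.description.*`; after the refactor the `LMModel` holds a bare `T5Conditioner`, so the keys are renamed before `load_state_dict`. A minimal, generic sketch of that key-remapping pattern (hypothetical module names, not the actual AudioGen classes):

```python
import torch
import torch.nn as nn

# hypothetical checkpoint saved from an older, nested module layout
state = {'provider.inner.proj.weight': torch.randn(4, 8),
         'provider.inner.proj.bias': torch.zeros(4)}

# new layout: the projection hangs directly off the model
class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 4)

model = NewModel()

# rename the old keys in place so they match the new module tree, then load as usual
state['proj.weight'] = state.pop('provider.inner.proj.weight')
state['proj.bias'] = state.pop('provider.inner.proj.bias')
model.load_state_dict(state)
```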
audiocraft/conditioners.py CHANGED

@@ -1,82 +1,9 @@
-from collections import defaultdict
-from dataclasses import dataclass, field
-import logging
-import random
-import typing as tp
 import warnings
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-logger = logging.getLogger(__name__)
-TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
-ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
 
 
-
-
-class JointEmbedCondition(tp.NamedTuple):
-    wav: torch.Tensor
-    text: tp.List[tp.Optional[str]]
-    length: torch.Tensor
-    sample_rate: tp.List[int]
-    path: tp.List[tp.Optional[str]] = []
-    seek_time: tp.List[tp.Optional[float]] = []
-
-
-@dataclass
-class ConditioningAttributes:
-    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    wav: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
-
-    def __getitem__(self, item):
-        return getattr(self, item)
-
-    @property
-    def text_attributes(self):
-        return self.text.keys()
-
-    @property
-    def wav_attributes(self):
-        return self.wav.keys()
-
-    @property
-    def joint_embed_attributes(self):
-        return self.joint_embed.keys()
-
-    @property
-    def attributes(self):
-        return {
-            "text": self.text_attributes,
-            "wav": self.wav_attributes,
-            "joint_embed": self.joint_embed_attributes,
-        }
-
-    def to_flat_dict(self):
-        return {
-            **{f"text.{k}": v for k, v in self.text.items()},
-            **{f"wav.{k}": v for k, v in self.wav.items()},
-            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
-        }
-
-    @classmethod
-    def from_flat_dict(cls, x):
-        out = cls()
-        for k, v in x.items():
-            kind, att = k.split(".")
-            out[kind][att] = v
-        return out
-
-
-class Tokenizer:
-    """Base tokenizer implementation
-    (in case we want to introduce more advances tokenizers in the future).
-    """
-    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-
-
 class T5Conditioner(nn.Module):
 
     MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",

@@ -95,12 +22,10 @@ class T5Conditioner(nn.Module):
              "google/flan-t5-11b": 1024,
              }
 
-    def __init__(self,
-                 name
-                 output_dim
-                 device
-                 word_dropout: float = 0.,
-                 normalize_text: bool = False,
+    def __init__(self,
+                 name,
+                 output_dim,
+                 device,
                  finetune=False):
         print(f'{finetune=}')
         assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"

@@ -110,19 +35,9 @@ class T5Conditioner(nn.Module):
         self.output_proj = nn.Linear(self.dim, output_dim)
         self.device = device
         self.name = name
-
-
-
-        # thanks https://gist.github.com/simon-weber/7853144
-        previous_level = logging.root.manager.disable
-        logging.disable(logging.ERROR)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            try:
-                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
-                t5 = T5EncoderModel.from_pretrained(name).eval()  #.train(mode=finetune)
-            finally:
-                logging.disable(previous_level)
+
+        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+        t5 = T5EncoderModel.from_pretrained(name).eval()  #.train(mode=finetune)
         if finetune:
             self.t5 = t5
         else:

@@ -130,116 +45,27 @@ class T5Conditioner(nn.Module):
             # of the saved checkpoint
             self.__dict__['t5'] = t5.to(device)
 
-        self.normalize_text = normalize_text
-        if normalize_text:
-            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
 
-    def tokenize(self, x
-
-        entries
-        if self.normalize_text:
-            _, _, entries = self.text_normalizer(entries, return_text=True)
-        if self.word_dropout > 0. and self.training:
-            new_entries = []
-            for entry in entries:
-                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
-                new_entries.append(" ".join(words))
-            entries = new_entries
+    def tokenize(self, x):
+
+        entries = [xi if xi is not None else "" for xi in x]
 
-        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
 
-        inputs = self.t5_tokenizer(entries,
-
-
-
+        inputs = self.t5_tokenizer(entries,
+                                   return_tensors='pt',
+                                   padding=True).to(self.device)
+
+        return inputs  # 'input_ids', 'attention_mask'
 
-    def forward(self,
-
+    def forward(self, descriptions):
+
+        d = self.tokenize(descriptions)
+
         with torch.no_grad():
-            embeds = self.t5(
+            embeds = self.t5(input_ids=d['input_ids'],
+                             attention_mask=d['attention_mask']
+                             ).last_hidden_state  # no kvcache for txt conditioning
         embeds = self.output_proj(embeds.to(self.output_proj.weight))
-        embeds = (embeds *
-
-        # T5 torch.Size([2, 4, 1536]) dict_keys(['input_ids', 'attention_mask'])
-        print(f'{embeds.dtype=}')  # inputs["input_ids"].shape=torch.Size([2, 4])
-        return embeds, mask
-
-
-
-
-
-
-
-
-class ConditioningProvider(nn.Module):
-
-    def __init__(self,
-                 conditioners):
-        super().__init__()
-        self.conditioners = nn.ModuleDict(conditioners)
-
-    @property
-    def text_conditions(self):
-        return [k for k, v in self.conditioners.items() if isinstance(v, T5Conditioner)]
-
-
-
-    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
-        output = {}
-        text = self._collate_text(inputs)
-        # wavs = self._collate_wavs(inputs)
-        # joint_embeds = self._collate_joint_embeds(inputs)
-
-        # assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-        #     f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
-        #     f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
-        # )
-        for attribute, batch in text.items():  #, joint_embeds.items()):
-            output[attribute] = self.conditioners[attribute].tokenize(batch)
-        print(f'COndProvToknz {output=}\n==')
-        return output
-
-    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
-        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
-        The output is for example:
-        {
-            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
-            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
-            ...
-        }
-
-        Args:
-            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
-        """
-        output = {}
-        for attribute, inputs in tokenized.items():
-            condition, mask = self.conditioners[attribute](inputs)
-            output[attribute] = (condition, mask)
-        return output
-
-    def _collate_text(self, samples):
-        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
-        are the attributes and the values are the aggregated input per attribute.
-        For example:
-        Input:
-        [
-            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
-            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
-        ]
-        Output:
-        {
-            "genre": ["Rock", "Hip-hop"],
-            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
-        }
+        embeds = (embeds * d['attention_mask'].unsqueeze(-1))
 
-        samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
-        Returns:
-            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
-        """
-        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
-        texts = [x.text for x in samples]
-        for text in texts:
-            for condition in self.text_conditions:
-                out[condition].append(text[condition])
-        return out
+        return embeds  # , d['attention_mask']
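The slimmed-down conditioner is just tokenizer, T5 encoder, linear projection, with padded positions zeroed by the attention mask. A standalone sketch of that path with Hugging Face `transformers`; `t5-small` and the 1536 projection width are chosen here only to keep the example light (the commit itself wires up `t5-large`):

```python
import torch
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer

name = 't5-small'                   # the repo configures 't5-large'; small keeps the demo cheap
tokenizer = T5Tokenizer.from_pretrained(name)
encoder = T5EncoderModel.from_pretrained(name).eval()
output_proj = nn.Linear(encoder.config.d_model, 1536)    # 1536 = assumed LM width for this sketch

descriptions = ['dogs barking', 'wind and rain']
batch = tokenizer(descriptions, return_tensors='pt', padding=True)

with torch.no_grad():
    hidden = encoder(input_ids=batch['input_ids'],
                     attention_mask=batch['attention_mask']).last_hidden_state

embeds = output_proj(hidden)                              # [bs, T, 1536]
embeds = embeds * batch['attention_mask'].unsqueeze(-1)   # zero out padded tokens
print(embeds.shape)
```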
audiocraft/lm.py CHANGED

@@ -23,17 +23,6 @@ def _shift(x):
 
 
 
-# ============================================== From LM.py
-
-
-logger = logging.getLogger(__name__)
-TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
-ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
-
-ConditionTensors = tp.Dict[str, ConditionType]
-CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
-
-
 def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
     """LM layer initialization.
     Inspired from xlformers: https://github.com/fairinternal/xlformers

@@ -280,19 +269,14 @@ class LMModel(nn.Module):
        return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2)  # [bs=3not6, self.n_draw, 4]
 
    @torch.no_grad()
-    def generate(self,
-
-
-
-        tokenized = self.condition_provider.tokenize(conditions)
-
-
-        cfg_conditions = self.condition_provider(tokenized)
-
-
+    def generate(self,
+                 descriptions = ['windy day', 'rain storm'],
+                 max_gen_len = 256):
 
+        text_condition = self.condition_provider(descriptions)
+
        # NULL CONDITION
-        text_condition = cfg_conditions['description'][0]
+        # text_condition = cfg_conditions['description'][0]
        bs, _, _ = text_condition.shape
        text_condition = torch.cat(
            [

@@ -330,7 +314,7 @@ class LMModel(nn.Module):
 
            # forward duplicates the query to nullcond - then cfg & returns deduplicate token
            next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
-                                      condition_tensors=text_condition,
+                                      condition_tensors=text_condition,  # utilisation of the attention mask of txt condition ?
                                      token_count=offset-1)  # [bs, 4, 1, 2048]
 
 
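`generate` now stacks a null (empty-text) condition next to the real one so the forward pass can run classifier-free guidance: each step scores the token distribution with and without the text prompt and blends the two. A generic sketch of that blending step, with `cfg_coef` and the two logit tensors as stand-ins for what the LMModel computes internally:

```python
import torch

def cfg_blend(cond_logits: torch.Tensor,
              null_logits: torch.Tensor,
              cfg_coef: float = 3.0) -> torch.Tensor:
    """Classifier-free guidance: push the conditioned logits away from the
    unconditioned ones by a factor of cfg_coef."""
    return null_logits + cfg_coef * (cond_logits - null_logits)

# toy shapes: [batch, codebooks, timestep, vocab]
cond = torch.randn(2, 4, 1, 2048)
null = torch.randn(2, 4, 1, 2048)
guided = cfg_blend(cond, null, cfg_coef=3.0)
next_token = guided.softmax(dim=-1).argmax(dim=-1)   # greedy pick per codebook
print(next_token.shape)                              # torch.Size([2, 4, 1])
```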
audiocraft/transformer.py CHANGED

@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from torch.utils.checkpoint import checkpoint as torch_checkpoint
-from xformers import ops
 
 
 _efficient_attention_backend: str = 'torch'

@@ -12,7 +11,6 @@ _efficient_attention_backend: str = 'torch'
 
 
 
-
 def _get_attention_time_dimension(memory_efficient: bool) -> int:
     if _efficient_attention_backend == 'torch' and memory_efficient:
         return 2

@@ -190,7 +188,7 @@ class StreamingMultiheadAttention(nn.Module):
        # else:
        #     bound_layout = "b t p h d"
        packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
-        q, k, v =
+        q, k, v = packed.unbind(dim=2)
 
 
        if self.k_history is not None:

@@ -222,7 +220,6 @@ class StreamingMultiheadAttention(nn.Module):
 
        p = self.dropout if self.training else 0
        if _efficient_attention_backend == 'torch':
-            # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSSopen')
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, is_causal=False, dropout_p=p
            )
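The restored line unpacks the fused QKV projection: a single linear layer emits one tensor laid out as `b t (p h d)` with `p=3`, `rearrange` splits out that axis, and `unbind` hands back separate `q`, `k`, `v` for `scaled_dot_product_attention`. A small self-contained sketch of the same packing trick, with toy sizes and `bound_layout` fixed to `"b t p h d"` for simplicity:

```python
import torch
from einops import rearrange

bs, T, num_heads, head_dim = 2, 5, 4, 16
embed_dim = num_heads * head_dim

# one in_proj produces q, k, v stacked along the channel axis
in_proj = torch.nn.Linear(embed_dim, 3 * embed_dim)
x = torch.randn(bs, T, embed_dim)
projected = in_proj(x)                                   # [b, t, 3*h*d]

packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=num_heads)
q, k, v = packed.unbind(dim=2)                           # each [b, t, h, d]

# SDPA expects [b, h, t, d], so move heads before time
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
print(out.shape)                                         # torch.Size([2, 4, 5, 16])
```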
models.py CHANGED

@@ -511,7 +511,11 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
 
 def _load_model(model_config, model_path):
     model = ASRCNN(**model_config)
-    params = torch.load(
+    params = torch.load(
+        model_path,
+        map_location='cpu',
+        weights_only=False
+    )['model']
     model.load_state_dict(params)
     return model
 
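Newer PyTorch releases tighten the default of `weights_only` in `torch.load`, which rejects pickled checkpoints that carry whole Python objects; passing `weights_only=False` together with `map_location='cpu'` keeps the old behaviour for a trusted file. A minimal sketch of saving and reloading a checkpoint in that shape, where the `'model'` key mirrors how this ASR checkpoint is organised:

```python
import torch
import torch.nn as nn

net = nn.Linear(10, 2)

# checkpoint stored as a dict with the weights under a 'model' key
torch.save({'model': net.state_dict()}, 'ckpt.pth')

params = torch.load('ckpt.pth',
                    map_location='cpu',            # load to CPU regardless of where it was saved
                    weights_only=False)['model']   # allow full pickle for a trusted file
net.load_state_dict(params)
```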
requirements.txt ADDED

@@ -0,0 +1,19 @@
+torch
+torchaudio
+numpy
+audiofile
+audresample
+cached_path
+einops
+flask
+librosa
+moviepy
+sentencepiece
+omegaconf
+opencv-python
+soundfile
+transformers
+munch
+srt
+nltk
+phonemizer