Dionyssos committed on
Commit 86b9ce4 · 1 Parent(s): df63ff0

fx sounds batch inference

README.md CHANGED
@@ -62,7 +62,7 @@ pip install -r requirements.txt
 Flask `tmux-session`
 
 ```
-CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=./hf_home CUDA_VISIBLE_DEVICES=0 python api.py
+CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=0 python api.py
 ```
 
 Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
@@ -127,5 +127,5 @@ Create audiobook from `.docx`. Listen to it - YouTube [male voice](https://www.y
 
 ```python
 # audiobook will be saved in ./tts_audiobooks
-python audiobook.py
+python audiobook.py
 ```
api.py CHANGED
@@ -9,12 +9,11 @@ import re
 import srt
 import subprocess
 import cv2
-import markdown
 from pathlib import Path
 from types import SimpleNamespace
 from flask import Flask, request, send_from_directory
-from flask_cors import CORS
-from moviepy.editor import *
+from moviepy.video.io.VideoFileClip import VideoFileClip
+from moviepy.video.VideoClip import ImageClip
 from audiocraft.builders import AudioGen
 CACHE_DIR = 'flask_cache/'
 NUM_SOUND_GENERATIONS = 3  # batch size to generate the same text (same soundscape for a long video)
@@ -79,10 +78,10 @@ def overlay(x, soundscape=None):
         background = sound_generator.generate(
             [soundscape] * NUM_SOUND_GENERATIONS
         ).reshape(-1).detach().cpu().numpy()  # bs, 11400 @ .74s
-        # sound_generator._flush()  # already done in lm.generate(); the EnCodec does not seem to have a transformer, thus no KV cache to clean up from the previous soundscape
-        # upsample 16 kHz AudioGen to 24 kHz StyleTTS
 
-        print('Resampling')
+        # upsample 16 kHz AudioGen to the 24 kHz of VITS/StyleTTS2
 
+        print('Resampling')  # each generation in the batch differs from the others, thus clone/shift every element of the batch, finally concat w/o shift
 
 
         background = audresample.resample(
@@ -178,14 +177,6 @@ def tts_multi_sentence(precomputed_style_vector=None,
 # global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
 
 app = Flask(__name__)
-cors = CORS(app)
-
-
-@app.route("/")
-def index():
-    with open('README.md', 'r') as f:
-        return markdown.markdown(f.read())
-
 
 @app.route("/", methods=['GET', 'POST', 'PUT'])
 def serve_wav():
@@ -460,7 +451,8 @@ def serve_wav():
 
     # SILENT CLIP
 
-    clip_silent = ImageClip(STATIC_FRAME).set_duration(5)  # as long as the audio - TTS first
+    clip_silent = ImageClip(img=STATIC_FRAME,
+                            duration=5)  # ffmpeg continues this silent video for the duration of the TTS
    clip_silent.write_videofile(SILENT_VIDEO, fps=24)
 
    x = tts_multi_sentence(text=text,
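
The `overlay()` change above is the core of this commit: one soundscape prompt is replicated `NUM_SOUND_GENERATIONS` times, generated as a batch, flattened into a single long background track, and then upsampled from AudioGen's 16 kHz to the 24 kHz of the TTS. A minimal sketch of that call pattern, assuming `sound_generator` is the `AudioGen` instance that `api.py` builds at startup (the helper name `make_background` is hypothetical):

```python
# Sketch only - mirrors the batched generation in overlay(); not part of the commit.
import audresample

NUM_SOUND_GENERATIONS = 3  # repeat the same prompt so a long video gets enough background audio

def make_background(sound_generator, soundscape='dogs barking'):
    # one prompt replicated into a batch; the batch is flattened into a single 16 kHz track
    x = sound_generator.generate(
        [soundscape] * NUM_SOUND_GENERATIONS
    ).reshape(-1).detach().cpu().numpy()
    # AudioGen emits 16 kHz audio; the TTS output is 24 kHz, so upsample before mixing
    return audresample.resample(x, 16000, 24000)
```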
audiocraft/builders.py CHANGED
@@ -10,11 +10,7 @@ from .encodec import EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
 from .codebooks_patterns import DelayedPatternProvider
-from .conditioners import (
-    ConditioningProvider,
-    T5Conditioner,
-    ConditioningAttributes
-)
+from .conditioners import T5Conditioner
 from .vq import ResidualVectorQuantizer
 
 
@@ -73,10 +69,8 @@ class AudioGen(nn.Module):
     def generate(self,
                  descriptions):
         with torch.no_grad():
-            attributes = [
-                ConditioningAttributes(text={'description': d}) for d in descriptions]
             gen_tokens = self.lm.generate(
-                conditions=attributes,
+                descriptions=descriptions,
                 max_gen_len=int(self.duration * self.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
             x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
             # print('______________\nAudioGen Tokens', gen_tokens)
@@ -144,10 +138,8 @@ class AudioGen(nn.Module):
         codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
         attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
         cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
-        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
 
-        condition_provider = self.get_conditioner_provider(kwargs["dim"], cfg
-                                                           ).to(self.device)
 
 
         # if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
@@ -163,7 +155,7 @@ class AudioGen(nn.Module):
         pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
         return LMModel(
             pattern_provider=pattern_provider,
-            condition_provider=condition_provider,
+            condition_provider=T5Conditioner(name='t5-large', output_dim=kwargs["dim"], device=self.device),
             cfg_dropout=cfg_prob,
             cfg_coef=cfg_coef,
             attribute_dropout=attribute_dropout,
@@ -173,34 +165,8 @@ class AudioGen(nn.Module):
         ).to(cfg.device)
         else:
             raise KeyError(f"Unexpected LM model {cfg.lm_model}")
-
-
-    def get_conditioner_provider(self, output_dim,
-                                 cfg):
-        """Instantiate T5 text"""
-        cfg = getattr(cfg, 'conditioners')
-        dict_cfg = {} if cfg is None else dict_from_config(cfg)
-        conditioners = {}
-        condition_provider_args = dict_cfg.pop('args', {})
-        condition_provider_args.pop('merge_text_conditions_p', None)
-        condition_provider_args.pop('drop_desc_p', None)
-
-        for cond, cond_cfg in dict_cfg.items():
-            model_type = cond_cfg['model']
-            model_args = cond_cfg[model_type]
-            if model_type == 't5':
-                conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
-                                                        device=self.device,
-                                                        **model_args)
-            else:
-                raise ValueError(f"Unrecognized conditioning model: {model_type}")
-
-        # print(f'{condition_provider_args=}')
-        return ConditioningProvider(conditioners)
-
-
-
-
+
+
     def get_codebooks_pattern_provider(self, n_q, cfg):
         pattern_providers = {
             'delay': DelayedPatternProvider,  # THIS
@@ -249,6 +215,10 @@ class AudioGen(nn.Module):
         _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
         _delete_param(cfg, 'conditioners.args.drop_desc_p')
         model = self.get_lm_model(cfg)
+
+        _best = pkg['best_state']
+        _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
+        _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
         model.load_state_dict(pkg['best_state'])
         model.cfg = cfg
         # return model
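
The last hunk above keeps old checkpoints loadable: because the `ConditioningProvider` wrapper is gone, the saved `state_dict` keys for the description conditioner's projection are renamed to the new, flatter module path before `load_state_dict`. A generic sketch of that key-remapping idea (the helper `rename_state_dict_keys` is hypothetical, not part of the repo):

```python
# Sketch: remap checkpoint keys when a wrapper module is flattened away.
def rename_state_dict_keys(state_dict, mapping):
    """Return state_dict with keys renamed according to mapping (old -> new)."""
    for old_key, new_key in mapping.items():
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict

# Usage mirroring the change in builders.py:
# pkg['best_state'] = rename_state_dict_keys(pkg['best_state'], {
#     'condition_provider.conditioners.description.output_proj.weight':
#         'condition_provider.output_proj.weight',
#     'condition_provider.conditioners.description.output_proj.bias':
#         'condition_provider.output_proj.bias',
# })
```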
audiocraft/conditioners.py CHANGED
@@ -1,82 +1,9 @@
-from collections import defaultdict
-from dataclasses import dataclass, field
-import logging
-import random
-import typing as tp
 import warnings
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-logger = logging.getLogger(__name__)
-TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
-ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
 
 
-
-
-class JointEmbedCondition(tp.NamedTuple):
-    wav: torch.Tensor
-    text: tp.List[tp.Optional[str]]
-    length: torch.Tensor
-    sample_rate: tp.List[int]
-    path: tp.List[tp.Optional[str]] = []
-    seek_time: tp.List[tp.Optional[float]] = []
-
-
-@dataclass
-class ConditioningAttributes:
-    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    wav: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
-
-    def __getitem__(self, item):
-        return getattr(self, item)
-
-    @property
-    def text_attributes(self):
-        return self.text.keys()
-
-    @property
-    def wav_attributes(self):
-        return self.wav.keys()
-
-    @property
-    def joint_embed_attributes(self):
-        return self.joint_embed.keys()
-
-    @property
-    def attributes(self):
-        return {
-            "text": self.text_attributes,
-            "wav": self.wav_attributes,
-            "joint_embed": self.joint_embed_attributes,
-        }
-
-    def to_flat_dict(self):
-        return {
-            **{f"text.{k}": v for k, v in self.text.items()},
-            **{f"wav.{k}": v for k, v in self.wav.items()},
-            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
-        }
-
-    @classmethod
-    def from_flat_dict(cls, x):
-        out = cls()
-        for k, v in x.items():
-            kind, att = k.split(".")
-            out[kind][att] = v
-        return out
-
-
-class Tokenizer:
-    """Base tokenizer implementation
-    (in case we want to introduce more advanced tokenizers in the future).
-    """
-    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-
-
 class T5Conditioner(nn.Module):
 
     MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
@@ -95,12 +22,10 @@ class T5Conditioner(nn.Module):
         "google/flan-t5-11b": 1024,
     }
 
-    def __init__(self,
-                 name: str,
-                 output_dim: int,
-                 device: str,
-                 word_dropout: float = 0.,
-                 normalize_text: bool = False,
+    def __init__(self,
+                 name,
+                 output_dim,
+                 device,
                  finetune=False):
         print(f'{finetune=}')
         assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
@@ -110,19 +35,9 @@ class T5Conditioner(nn.Module):
         self.output_proj = nn.Linear(self.dim, output_dim)
         self.device = device
         self.name = name
-        self.word_dropout = word_dropout
-
-        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
-        # thanks https://gist.github.com/simon-weber/7853144
-        previous_level = logging.root.manager.disable
-        logging.disable(logging.ERROR)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            try:
-                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
-                t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
-            finally:
-                logging.disable(previous_level)
+
+        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+        t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
         if finetune:
             self.t5 = t5
         else:
@@ -130,116 +45,27 @@ class T5Conditioner(nn.Module):
             # of the saved checkpoint
             self.__dict__['t5'] = t5.to(device)
 
-        self.normalize_text = normalize_text
-        if normalize_text:
-            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
 
-    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
-        # if current sample doesn't have a certain attribute, replace with empty string
-        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
-        if self.normalize_text:
-            _, _, entries = self.text_normalizer(entries, return_text=True)
-        if self.word_dropout > 0. and self.training:
-            new_entries = []
-            for entry in entries:
-                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
-                new_entries.append(" ".join(words))
-            entries = new_entries
+    def tokenize(self, x):
 
-        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+        entries = [xi if xi is not None else "" for xi in x]
 
-        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
-        mask = inputs['attention_mask']
-        mask[empty_idx, :] = 0  # zero-out index where the input is non-existent
-        return inputs
+        inputs = self.t5_tokenizer(entries,
                                    return_tensors='pt',
                                    padding=True).to(self.device)
+
+        return inputs  # 'input_ids', 'attention_mask'
 
-    def forward(self, inputs):
-        mask = inputs['attention_mask']
+    def forward(self, descriptions):
+
+        d = self.tokenize(descriptions)
+
         with torch.no_grad():
-            embeds = self.t5(**inputs).last_hidden_state
+            embeds = self.t5(input_ids=d['input_ids'],
                              attention_mask=d['attention_mask']
                              ).last_hidden_state  # no KV cache for txt conditioning
         embeds = self.output_proj(embeds.to(self.output_proj.weight))
-        embeds = (embeds * mask.unsqueeze(-1))
-
-        # T5 torch.Size([2, 4, 1536]) dict_keys(['input_ids', 'attention_mask'])
-        print(f'{embeds.dtype=}')  # inputs["input_ids"].shape=torch.Size([2, 4])
-        return embeds, mask
-
-
-
-
-
-
-
-
-class ConditioningProvider(nn.Module):
-
-    def __init__(self,
-                 conditioners):
-        super().__init__()
-        self.conditioners = nn.ModuleDict(conditioners)
-
-    @property
-    def text_conditions(self):
-        return [k for k, v in self.conditioners.items() if isinstance(v, T5Conditioner)]
-
-
-
-    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
-        output = {}
-        text = self._collate_text(inputs)
-        # wavs = self._collate_wavs(inputs)
-        # joint_embeds = self._collate_joint_embeds(inputs)
-
-        # assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-        #     f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
-        #     f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
-        # )
-        for attribute, batch in text.items():  # , joint_embeds.items()):
-            output[attribute] = self.conditioners[attribute].tokenize(batch)
-        print(f'COndProvToknz {output=}\n==')
-        return output
-
-    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
-        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
-        The output is for example:
-        {
-            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
-            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
-            ...
-        }
-
-        Args:
-            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
-        """
-        output = {}
-        for attribute, inputs in tokenized.items():
-            condition, mask = self.conditioners[attribute](inputs)
-            output[attribute] = (condition, mask)
-        return output
-
-    def _collate_text(self, samples):
-        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
-        are the attributes and the values are the aggregated input per attribute.
-        For example:
-        Input:
-        [
-            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
-            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
-        ]
-        Output:
-        {
-            "genre": ["Rock", "Hip-hop"],
-            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
-        }
-
-        Args:
-            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
-        Returns:
-            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
-        """
-        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
-        texts = [x.text for x in samples]
-        for text in texts:
-            for condition in self.text_conditions:
-                out[condition].append(text[condition])
-        return out
+        embeds = (embeds * d['attention_mask'].unsqueeze(-1))
+
+        return embeds  # , d['attention_mask']
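
After this rewrite the conditioner is a plain callable from description strings to projected embeddings; tokenization happens inside `forward`. A minimal usage sketch, assuming the repo layout keeps the module at `audiocraft/conditioners.py` and taking `output_dim=1536` from the shape comments removed above (any other value must match the LM width):

```python
# Sketch only: exercising the slimmed-down T5Conditioner.
import torch
from audiocraft.conditioners import T5Conditioner

cond = T5Conditioner(name='t5-large', output_dim=1536, device='cpu')  # output_dim assumed
with torch.no_grad():
    embeds = cond(['dogs barking', 'rain storm'])  # forward() tokenizes internally
print(embeds.shape)  # [2, num_tokens, 1536]; padded positions are zeroed via the attention mask
```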
 
 
 
 
 
 
 
 
 
 
audiocraft/lm.py CHANGED
@@ -23,17 +23,6 @@ def _shift(x):
 
 
 
-# ============================================== From LM.py
-
-
-logger = logging.getLogger(__name__)
-TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
-ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
-
-ConditionTensors = tp.Dict[str, ConditionType]
-CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
-
-
 def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
     """LM layer initialization.
     Inspired from xlformers: https://github.com/fairinternal/xlformers
@@ -280,19 +269,14 @@ class LMModel(nn.Module):
         return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2)  # [bs=3not6, self.n_draw, 4]
 
     @torch.no_grad()
-    def generate(self, conditions = [],
-                 max_gen_len=256):
-
-
-        tokenized = self.condition_provider.tokenize(conditions)
-
-
-        cfg_conditions = self.condition_provider(tokenized)
-
-
+    def generate(self,
+                 descriptions = ['windy day', 'rain storm'],
+                 max_gen_len = 256):
 
+        text_condition = self.condition_provider(descriptions)
+
         # NULL CONDITION
-        text_condition = cfg_conditions['description'][0]
+        # text_condition = cfg_conditions['description'][0]
         bs, _, _ = text_condition.shape
         text_condition = torch.cat(
             [
@@ -330,7 +314,7 @@ class LMModel(nn.Module):
 
         # forward duplicates the query to nullcond - then cfg & returns deduplicate token
         next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
-                                  condition_tensors=text_condition,
+                                  condition_tensors=text_condition,  # utilisation of the attention mask of txt condition ?
                                   token_count=offset-1)  # [bs, 4, 1, 2048]
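
The `NULL CONDITION` block that `generate()` keeps building (the `torch.cat([...])` over the batch dimension) is the usual classifier-free-guidance trick: run the LM once on the text condition and once on an all-zero condition, then blend the two logit sets. An illustrative sketch of that pattern under assumed shapes; `lm_forward` and its signature are hypothetical, not the repo's exact `forward`:

```python
# Illustration of classifier-free guidance as used around the NULL CONDITION block.
import torch

def cfg_logits(lm_forward, tokens, text_condition, cfg_coef=3.0):
    bs = text_condition.shape[0]
    # stack [conditioned, null-conditioned] along the batch dimension
    cond = torch.cat([text_condition, torch.zeros_like(text_condition)], dim=0)
    logits = lm_forward(tokens.repeat(2, 1, 1), condition_tensors=cond)
    cond_logits, uncond_logits = logits[:bs], logits[bs:]
    # guide the conditioned prediction away from the unconditioned one
    return uncond_logits + cfg_coef * (cond_logits - uncond_logits)
```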
 
audiocraft/transformer.py CHANGED
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from torch.utils.checkpoint import checkpoint as torch_checkpoint
-from xformers import ops
 
 
 _efficient_attention_backend: str = 'torch'
@@ -12,7 +11,6 @@ _efficient_attention_backend: str = 'torch'
 
 
 
-
 def _get_attention_time_dimension(memory_efficient: bool) -> int:
     if _efficient_attention_backend == 'torch' and memory_efficient:
         return 2
@@ -190,7 +188,7 @@ class StreamingMultiheadAttention(nn.Module):
         # else:
         #     bound_layout = "b t p h d"
         packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
-        q, k, v = ops.unbind(packed, dim=2)
+        q, k, v = packed.unbind(dim=2)
 
 
         if self.k_history is not None:
@@ -222,7 +220,6 @@ class StreamingMultiheadAttention(nn.Module):
 
         p = self.dropout if self.training else 0
         if _efficient_attention_backend == 'torch':
-            # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSSopen')
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, is_causal=False, dropout_p=p
             )
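
Dropping the `xformers` import works because `Tensor.unbind` does exactly what `xformers.ops.unbind` did here: split the packed q/k/v projection along the qkv axis. A tiny self-contained check with assumed shapes:

```python
# Sketch: plain-torch replacement for xformers.ops.unbind plus SDPA.
import torch

packed = torch.randn(2, 7, 3, 4, 16)   # (batch, time, qkv=3, heads, head_dim)
q, k, v = packed.unbind(dim=2)          # three (2, 7, 4, 16) views, no copy
out = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False
)
print(out.shape)                        # torch.Size([2, 4, 7, 16])
```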
models.py CHANGED
@@ -511,7 +511,11 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
 
 def _load_model(model_config, model_path):
     model = ASRCNN(**model_config)
-    params = torch.load(model_path, map_location='cpu')['model']
+    params = torch.load(
+        model_path,
+        map_location='cpu',
+        weights_only=False
+    )['model']
     model.load_state_dict(params)
     return model
 
requirements.txt ADDED
@@ -0,0 +1,19 @@
+torch
+torchaudio
+numpy
+audiofile
+audresample
+cached_path
+einops
+flask
+librosa
+moviepy
+sentencepiece
+omegaconf
+opencv-python
+soundfile
+transformers
+munch
+srt
+nltk
+phonemizer