Gertie01 committed on
Commit
e77217f
1 Parent(s): 13ddd49

Update musiclm_pytorch.py

Browse files
Files changed (1) hide show
  1. musiclm_pytorch.py +10 -4
musiclm_pytorch.py CHANGED
@@ -5,6 +5,7 @@ from torch import nn, einsum
5
  from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
6
 
7
  from audiolm_pytorch import AudioLM
 
8
 
9
  from x_clip.tokenizer import tokenizer
10
  from vector_quantize_pytorch import ResidualVQ
@@ -448,7 +449,7 @@ class MuLaN(nn.Module):
448
  # music lm
449
 
450
  @beartype
451
- class MuLaNEmbedQuantizer(nn.Module):
452
  def __init__(
453
  self,
454
  mulan: MuLaN,
@@ -494,6 +495,9 @@ class MuLaNEmbedQuantizer(nn.Module):
494
 
495
  self.set_default_namespace(namespaces[0])
496
 
 
 
 
497
  def set_default_namespace(self, namespace):
498
  self._default_namespace = namespace
499
 
@@ -537,6 +541,8 @@ class MusicLM(nn.Module):
537
  mulan_embed_quantizer: MuLaNEmbedQuantizer
538
  ):
539
  super().__init__()
 
 
540
  self.mulan_embed_quantizer = mulan_embed_quantizer
541
  self.audio_lm = audio_lm
542
 
@@ -549,7 +555,7 @@ class MusicLM(nn.Module):
549
  self.eval()
550
 
551
  texts = tokenizer.tokenize(raw_texts)
552
- cond_tokens = self.mulan_embed_quantizer(texts = texts)
553
 
554
- wavs = self.audio_lm.generate(cond_tokens = cond_tokens, **audio_lm_kwargs)
555
- return wavs
 
 
5
  from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
6
 
7
  from audiolm_pytorch import AudioLM
8
+ from audiolm_pytorch.utils import AudioConditionerBase
9
 
10
  from x_clip.tokenizer import tokenizer
11
  from vector_quantize_pytorch import ResidualVQ
 
449
  # music lm
450
 
451
  @beartype
452
+ class MuLaNEmbedQuantizer(AudioConditionerBase):
453
  def __init__(
454
  self,
455
  mulan: MuLaN,
 
495
 
496
  self.set_default_namespace(namespaces[0])
497
 
498
+ def parameters(self):
499
+ return self.cond_embeddings.parameters()
500
+
501
  def set_default_namespace(self, namespace):
502
  self._default_namespace = namespace
503
 
 
541
  mulan_embed_quantizer: MuLaNEmbedQuantizer
542
  ):
543
  super().__init__()
544
+ assert not exists(audio_lm.audio_conditioner), 'mulan must not have been passed into AudioLM. it will be managed externally now, embedding the text into the joint embedding space for text-to-audio synthesis'
545
+
546
  self.mulan_embed_quantizer = mulan_embed_quantizer
547
  self.audio_lm = audio_lm
548
 
 
555
  self.eval()
556
 
557
  texts = tokenizer.tokenize(raw_texts)
 
558
 
559
+ text_embeds = self.mulan_embed_quantizer(texts = texts)
560
+
561
+ return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs)