Update musiclm_pytorch.py
Browse files- musiclm_pytorch.py +10 -4
musiclm_pytorch.py
CHANGED
@@ -5,6 +5,7 @@ from torch import nn, einsum
|
|
5 |
from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
|
6 |
|
7 |
from audiolm_pytorch import AudioLM
|
|
|
8 |
|
9 |
from x_clip.tokenizer import tokenizer
|
10 |
from vector_quantize_pytorch import ResidualVQ
|
@@ -448,7 +449,7 @@ class MuLaN(nn.Module):
|
|
448 |
# music lm
|
449 |
|
450 |
@beartype
|
451 |
-
class MuLaNEmbedQuantizer(
|
452 |
def __init__(
|
453 |
self,
|
454 |
mulan: MuLaN,
|
@@ -494,6 +495,9 @@ class MuLaNEmbedQuantizer(nn.Module):
|
|
494 |
|
495 |
self.set_default_namespace(namespaces[0])
|
496 |
|
|
|
|
|
|
|
497 |
def set_default_namespace(self, namespace):
|
498 |
self._default_namespace = namespace
|
499 |
|
@@ -537,6 +541,8 @@ class MusicLM(nn.Module):
|
|
537 |
mulan_embed_quantizer: MuLaNEmbedQuantizer
|
538 |
):
|
539 |
super().__init__()
|
|
|
|
|
540 |
self.mulan_embed_quantizer = mulan_embed_quantizer
|
541 |
self.audio_lm = audio_lm
|
542 |
|
@@ -549,7 +555,7 @@ class MusicLM(nn.Module):
|
|
549 |
self.eval()
|
550 |
|
551 |
texts = tokenizer.tokenize(raw_texts)
|
552 |
-
cond_tokens = self.mulan_embed_quantizer(texts = texts)
|
553 |
|
554 |
-
|
555 |
-
|
|
|
|
5 |
from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
|
6 |
|
7 |
from audiolm_pytorch import AudioLM
|
8 |
+
from audiolm_pytorch.utils import AudioConditionerBase
|
9 |
|
10 |
from x_clip.tokenizer import tokenizer
|
11 |
from vector_quantize_pytorch import ResidualVQ
|
|
|
449 |
# music lm
|
450 |
|
451 |
@beartype
|
452 |
+
class MuLaNEmbedQuantizer(AudioConditionerBase):
|
453 |
def __init__(
|
454 |
self,
|
455 |
mulan: MuLaN,
|
|
|
495 |
|
496 |
self.set_default_namespace(namespaces[0])
|
497 |
|
498 |
+
def parameters(self):
|
499 |
+
return self.cond_embeddings.parameters()
|
500 |
+
|
501 |
def set_default_namespace(self, namespace):
|
502 |
self._default_namespace = namespace
|
503 |
|
|
|
541 |
mulan_embed_quantizer: MuLaNEmbedQuantizer
|
542 |
):
|
543 |
super().__init__()
|
544 |
+
assert not exists(audio_lm.audio_conditioner), 'mulan must not have been passed into AudioLM. it will be managed externally now, embedding the text into the joint embedding space for text-to-audio synthesis'
|
545 |
+
|
546 |
self.mulan_embed_quantizer = mulan_embed_quantizer
|
547 |
self.audio_lm = audio_lm
|
548 |
|
|
|
555 |
self.eval()
|
556 |
|
557 |
texts = tokenizer.tokenize(raw_texts)
|
|
|
558 |
|
559 |
+
text_embeds = self.mulan_embed_quantizer(texts = texts)
|
560 |
+
|
561 |
+
return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs)
|