supermomo668
commited on
Commit
•
800934b
1
Parent(s):
a7b15b1
stable
Browse files- handler.py +35 -17
handler.py
CHANGED
@@ -9,7 +9,7 @@ from audiocraft.models import MusicGen
|
|
9 |
|
10 |
import yaml
|
11 |
import math
|
12 |
-
|
13 |
import torch
|
14 |
|
15 |
def get_bip_bip(
|
@@ -46,13 +46,15 @@ class generator:
|
|
46 |
duration=self.conf['duration']
|
47 |
)
|
48 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
49 |
-
self.model.to(device)
|
50 |
-
self.sampling_rate = self.model.config.audio_encoder.sampling_rate
|
51 |
|
52 |
-
def preprocess(self, text, audio):
|
53 |
-
|
|
|
|
|
54 |
|
55 |
-
def generate(self, text:list, audio: np.array, **kwargs):
|
56 |
"""
|
57 |
text: ["modern melodic electronic dance music", "80s blues track with groovy saxophone"]
|
58 |
audio (np.array)
|
@@ -64,14 +66,27 @@ class generator:
|
|
64 |
# padding=True,
|
65 |
# return_tensors="pt",
|
66 |
# )
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
]
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
return output
|
76 |
|
77 |
|
@@ -80,7 +95,7 @@ class EndpointHandler:
|
|
80 |
# load model and processor from path
|
81 |
# self.model = MusicGen.from_pretrained(
|
82 |
# path, torch_dtype=torch.float16).to("cuda")
|
83 |
-
self.generator = generator('.conf/generation_conf.yaml')
|
84 |
|
85 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
|
86 |
"""
|
@@ -88,12 +103,15 @@ class EndpointHandler:
|
|
88 |
data (:dict:):
|
89 |
The payload with the text prompt and generation parameters.
|
90 |
"""
|
91 |
-
prompt_duration = 2
|
92 |
# process input
|
93 |
text = data.pop("text", data)
|
94 |
audio = data.pop("audio", data)
|
95 |
parameters = data.pop("parameters", None)
|
96 |
-
|
|
|
|
|
|
|
97 |
output = self.generate(text, audio, sr)
|
98 |
|
99 |
# # pass inputs with all kwargs in data
|
|
|
9 |
|
10 |
import yaml
|
11 |
import math
|
12 |
+
import torchaudio
|
13 |
import torch
|
14 |
|
15 |
def get_bip_bip(
|
|
|
46 |
duration=self.conf['duration']
|
47 |
)
|
48 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
49 |
+
# self.model.to(device)
|
50 |
+
self.sampling_rate = self.model.sample_rate # config.audio_encoder.sampling_rate
|
51 |
|
52 |
+
def preprocess(self, text, audio=None):
|
53 |
+
if audio is not None:
|
54 |
+
audio = audio[: int(len(audio) // self.conf['nth_slice_prompt'])]
|
55 |
+
return text, audio
|
56 |
|
57 |
+
def generate(self, text:list, audio: np.array=None, **kwargs):
|
58 |
"""
|
59 |
text: ["modern melodic electronic dance music", "80s blues track with groovy saxophone"]
|
60 |
audio (np.array)
|
|
|
66 |
# padding=True,
|
67 |
# return_tensors="pt",
|
68 |
# )
|
69 |
+
if kwargs.get('sr'):
|
70 |
+
sr = kwargs.get('sr')
|
71 |
+
else:
|
72 |
+
sr = self.conf['sampling_rate']
|
73 |
+
print(f"Generating from: Text:{text is not None} | audio:{audio is not None}")
|
74 |
+
text, audio = self.preprocess(text, audio)
|
75 |
+
if self.conf['model'] == 'melody' and audio is not None:
|
76 |
+
output = self.model.generate_with_chroma(
|
77 |
+
descriptions=[
|
78 |
+
text
|
79 |
+
],
|
80 |
+
melody_wavs=audio,
|
81 |
+
melody_sample_rate=sr,
|
82 |
+
# progress=True
|
83 |
+
)
|
84 |
+
else:
|
85 |
+
output = self.model.generate_continuation(
|
86 |
+
get_bip_bip(0.125), # .expand(2, -1, -1),
|
87 |
+
32000, text,
|
88 |
+
# progress_bar=True
|
89 |
+
)
|
90 |
return output
|
91 |
|
92 |
|
|
|
95 |
# load model and processor from path
|
96 |
# self.model = MusicGen.from_pretrained(
|
97 |
# path, torch_dtype=torch.float16).to("cuda")
|
98 |
+
self.generator = generator(os.path.join(path, '.conf/generation_conf.yaml'))
|
99 |
|
100 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
|
101 |
"""
|
|
|
103 |
data (:dict:):
|
104 |
The payload with the text prompt and generation parameters.
|
105 |
"""
|
106 |
+
# prompt_duration = 2
|
107 |
# process input
|
108 |
text = data.pop("text", data)
|
109 |
audio = data.pop("audio", data)
|
110 |
parameters = data.pop("parameters", None)
|
111 |
+
|
112 |
+
audio, sr = torchaudio.load(audio)
|
113 |
+
audio = audio.unsqueeze(0)
|
114 |
+
# audio, sr = sf.read(io.BytesIO(audio))
|
115 |
output = self.generate(text, audio, sr)
|
116 |
|
117 |
# # pass inputs with all kwargs in data
|