supermomo668 committed on
Commit
800934b
1 Parent(s): a7b15b1
Files changed (1) hide show
  1. handler.py +35 -17
handler.py CHANGED
@@ -9,7 +9,7 @@ from audiocraft.models import MusicGen
9
 
10
  import yaml
11
  import math
12
- # import torchaudio
13
  import torch
14
 
15
  def get_bip_bip(
@@ -46,13 +46,15 @@ class generator:
46
  duration=self.conf['duration']
47
  )
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
- self.model.to(device)
50
- self.sampling_rate = self.model.config.audio_encoder.sampling_rate
51
 
52
- def preprocess(self, text, audio):
53
- audio = audio[: int(len(audio) // self.conf['nth_slice_prompt'])]
 
 
54
 
55
- def generate(self, text:list, audio: np.array, **kwargs):
56
  """
57
  text: ["modern melodic electronic dance music", "80s blues track with groovy saxophone"]
58
  audio (np.array)
@@ -64,14 +66,27 @@ class generator:
64
  # padding=True,
65
  # return_tensors="pt",
66
  # )
67
- output = self.model.generate_with_chroma(
68
- descriptions=[
69
- text
70
- ],
71
- melody_wavs=audio,
72
- melody_sample_rate=self.conf['sampling_rate'],
73
- progress=True
74
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  return output
76
 
77
 
@@ -80,7 +95,7 @@ class EndpointHandler:
80
  # load model and processor from path
81
  # self.model = MusicGen.from_pretrained(
82
  # path, torch_dtype=torch.float16).to("cuda")
83
- self.generator = generator('.conf/generation_conf.yaml')
84
 
85
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
86
  """
@@ -88,12 +103,15 @@ class EndpointHandler:
88
  data (:dict:):
89
  The payload with the text prompt and generation parameters.
90
  """
91
- prompt_duration = 2
92
  # process input
93
  text = data.pop("text", data)
94
  audio = data.pop("audio", data)
95
  parameters = data.pop("parameters", None)
96
- audio, sr = sf.read(io.BytesIO(audio))
 
 
 
97
  output = self.generate(text, audio, sr)
98
 
99
  # # pass inputs with all kwargs in data
 
9
 
10
  import yaml
11
  import math
12
+ import torchaudio
13
  import torch
14
 
15
  def get_bip_bip(
 
46
  duration=self.conf['duration']
47
  )
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
+ # self.model.to(device)
50
+ self.sampling_rate = self.model.sample_rate # config.audio_encoder.sampling_rate
51
 
52
+ def preprocess(self, text, audio=None):
53
+ if audio is not None:
54
+ audio = audio[: int(len(audio) // self.conf['nth_slice_prompt'])]
55
+ return text, audio
56
 
57
+ def generate(self, text:list, audio: np.array=None, **kwargs):
58
  """
59
  text: ["modern melodic electronic dance music", "80s blues track with groovy saxophone"]
60
  audio (np.array)
 
66
  # padding=True,
67
  # return_tensors="pt",
68
  # )
69
+ if kwargs.get('sr'):
70
+ sr = kwargs.get('sr')
71
+ else:
72
+ sr = self.conf['sampling_rate']
73
+ print(f"Generating from: Text:{text is not None} | audio:{audio is not None}")
74
+ text, audio = self.preprocess(text, audio)
75
+ if self.conf['model'] == 'melody' and audio is not None:
76
+ output = self.model.generate_with_chroma(
77
+ descriptions=[
78
+ text
79
+ ],
80
+ melody_wavs=audio,
81
+ melody_sample_rate=sr,
82
+ # progress=True
83
+ )
84
+ else:
85
+ output = self.model.generate_continuation(
86
+ get_bip_bip(0.125), # .expand(2, -1, -1),
87
+ 32000, text,
88
+ # progress_bar=True
89
+ )
90
  return output
91
 
92
 
 
95
  # load model and processor from path
96
  # self.model = MusicGen.from_pretrained(
97
  # path, torch_dtype=torch.float16).to("cuda")
98
+ self.generator = generator(os.path.join(path, '.conf/generation_conf.yaml'))
99
 
100
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
101
  """
 
103
  data (:dict:):
104
  The payload with the text prompt and generation parameters.
105
  """
106
+ # prompt_duration = 2
107
  # process input
108
  text = data.pop("text", data)
109
  audio = data.pop("audio", data)
110
  parameters = data.pop("parameters", None)
111
+
112
+ audio, sr = torchaudio.load(audio)
113
+ audio = audio.unsqueeze(0)
114
+ # audio, sr = sf.read(io.BytesIO(audio))
115
  output = self.generate(text, audio, sr)
116
 
117
  # # pass inputs with all kwargs in data