Update modeling_diva.py
modeling_diva.py (+3 -3)
@@ -179,7 +179,7 @@ class DiVAModel(PreTrainedModel):
         return outputs
 
     def generate(
-        self, audio,
+        self, audio, text_prompt, do_sample=False, logits_processor=None, max_new_tokens=128
     ):
         inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
         input_features = inputs.input_features.to(self.speech_encoder_device)
@@ -191,9 +191,9 @@ class DiVAModel(PreTrainedModel):
             output_device=self.llama_decoder.model.embed_tokens.weight.device,
         ).squeeze()
 
-        if
+        if text_prompt != None and text_prompt != "":
             user_prompt_text = torch.tensor(
-                self.tokenizer(
+                self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
                 device=self.pre_user_suffix.device,
             )
             prefix = torch.cat(
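For context, the updated signature can be exercised roughly as sketched below. This is a minimal sketch, not code from the repository: the checkpoint id, the trust_remote_code loading path, and the audio file name are assumptions; only the generate() arguments and the 16 kHz sampling rate come from the diff itself.

# Minimal usage sketch for the new generate() signature.
# Assumptions (not from the diff): checkpoint id, AutoModel/trust_remote_code loading, file name.
import librosa
from transformers import AutoModel

# Hypothetical checkpoint id; the diff does not name the repository.
model = AutoModel.from_pretrained("your-org/diva-checkpoint", trust_remote_code=True)

# generate() passes the audio to self.processor(..., sampling_rate=16_000),
# so the waveform should be a 16 kHz mono array.
speech, _ = librosa.load("question.wav", sr=16_000)

# New in this commit: a text_prompt argument alongside the decoding controls
# (do_sample, logits_processor, max_new_tokens).
response = model.generate(
    speech,
    text_prompt="Answer the spoken question in one sentence.",
    max_new_tokens=128,
)
print(response)

Given the new guard, passing text_prompt=None or an empty string should skip the user-prompt branch entirely, so audio-only calls keep working.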