Replace the inplace operation
modeling_minicpmo.py  (+23 -7)
@@ -377,6 +377,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
 
+        new_vllm_embedding = vllm_embedding.clone()
+
         vision_hidden_states = [
             i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
         ]
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)
 
-                    cur_vllm_emb.scatter_(
+                    new_vllm_embedding[i] = cur_vllm_emb.scatter(
                         0,
                         image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                         cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                     )
+
                 elif self.training:
-                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+                    new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
 
-        return vllm_embedding, vision_hidden_states
+        return new_vllm_embedding, vision_hidden_states
 
     def get_audio_embedding_streaming(self, data):
         r"""
@@ -463,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             return []
 
-    def get_audio_embedding(self, data, chunk_length=-1):
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
         r"""
         Extract full audio embeddings with optional chunk-based attention.
 
@@ -481,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
+        dtype = self.apm.embed_positions.weight.dtype
+        device = self.apm.embed_positions.weight.device
 
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
@@ -541,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                 idx += 1
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
+        elif self.training and dummy:
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+
         else:
             return []
 
@@ -595,7 +611,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         elif self.training:
             for i in range(bs):
                 # dummy audio_embeddings
-                input_embeddings += audio_embeddings[0].mean() * 0
+                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
 
         return input_embeddings
 
@@ -751,7 +767,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         input_ids=None,
         pixel_values=None,
         tgt_sizes=None,
-        audio_features=None,
+        audio_features=[],
         audio_feature_lens=None,
         image_bound=None,
         audio_bounds=None,
@@ -2982,7 +2998,7 @@ class ConditionalChatTTS(PreTrainedModel):
             inputs_embeds = torch.stack(code_emb, 3).sum(3)
 
             position_ids = torch.tensor(
-                [past_key_values[0][0].shape[2]
+                [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
             ).unsqueeze(0)
 
             cache_position = position_ids.clone()
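
For context on the new_vllm_embedding change: a minimal, self-contained sketch of the failure mode the out-of-place rewrite avoids. The shapes are toy values and exp() merely stands in for any op that saves its output for backward; none of this is the model's real data. Editing a saved tensor with in-place scatter_ trips autograd's version check at .backward(), while scattering out-of-place and writing the result into a clone, as the patched code does, keeps the graph intact.

import torch

token_emb = torch.randn(2, 4, 8, requires_grad=True)
vllm_embedding = token_emb.exp()         # exp() saves its output for backward
image_feats = torch.randn(3, 8, requires_grad=True)
image_indices = torch.tensor([1, 2, 3])  # rows to overwrite with image features

# Old pattern (in-place): editing a tensor autograd still needs fails at backward time.
# vllm_embedding[0].scatter_(0, image_indices.view(-1, 1).repeat(1, 8), image_feats)
# vllm_embedding.sum().backward()  # RuntimeError: ... modified by an inplace operation

# New pattern from this commit: scatter out-of-place and write into a clone.
new_vllm_embedding = vllm_embedding.clone()
new_vllm_embedding[0] = vllm_embedding[0].scatter(
    0, image_indices.view(-1, 1).repeat(1, 8), image_feats
)
new_vllm_embedding.sum().backward()      # gradients reach both token_emb and image_feats

The input_embeddings hunk follows the same rule: `x = x + dummy` builds a new tensor instead of mutating one that may already be saved in the graph.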
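
The dummy=True branch feeds a zero waveform through the audio stack so that, together with the mean() * 0 term, the audio parameters stay connected to the loss on audio-free training batches; otherwise setups that require every parameter to receive a gradient (e.g. DistributedDataParallel without find_unused_parameters=True) can fail. A minimal sketch of the trick, with a toy nn.Linear standing in for the real apm / projection / pooling stack (the module and shapes here are illustrative, not MiniCPM-o's):

import torch
import torch.nn as nn

audio_tower = nn.Linear(80, 16)        # toy stand-in for the audio encoder + projection
text_embeddings = torch.randn(1, 5, 16, requires_grad=True)

# Dummy forward: zeros shaped (batch, frames, mel_bins) for this toy module.
dummy_wavs = torch.zeros(1, 100, 80)
dummy_audio = audio_tower(dummy_wavs)  # audio parameters now appear in the graph

# Adding mean() * 0 contributes exactly zero to the value...
out = text_embeddings + dummy_audio.mean() * 0
assert torch.equal(out, text_embeddings)

# ...but still yields defined (all-zero) gradients for the audio tower.
out.sum().backward()
print(audio_tower.weight.grad.abs().sum())   # tensor(0.)

Since the new branch is gated on `self.training and dummy`, inference with no audio still returns [] as before.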