finalf0 committed
Commit ee9d0bc · 1 Parent(s): fe724e9

Replace the inplace operation

Files changed (1)
  1. modeling_minicpmo.py +23 -7
modeling_minicpmo.py CHANGED
@@ -377,6 +377,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
 
+        new_vllm_embedding = vllm_embedding.clone()
+
         vision_hidden_states = [
             i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
         ]
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)
 
-                    cur_vllm_emb.scatter_(
+                    new_vllm_embedding[i] = cur_vllm_emb.scatter(
                         0,
                         image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                         cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                     )
+
                 elif self.training:
-                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+                    new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
 
-        return vllm_embedding, vision_hidden_states
+        return new_vllm_embedding, vision_hidden_states
 
     def get_audio_embedding_streaming(self, data):
         r"""
@@ -463,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             return []
 
-    def get_audio_embedding(self, data, chunk_length=-1):
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
         r"""
         Extract full audio embeddings with optional chunk-based attention.
 
@@ -481,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
+        dtype = self.apm.embed_positions.weight.dtype
+        device = self.apm.embed_positions.weight.device
 
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
@@ -541,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                     idx += 1
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
+        elif self.training and dummy:
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+
         else:
             return []
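
The `get_audio_embedding` change above adds a `dummy` path: when the model is training and the batch contains no audio, a zero-filled mel spectrogram of shape `(1, 80, 100)` (with dtype and device read off an encoder parameter) is pushed through the audio encoder, projection layer, and average pooler. This is a common trick for keeping otherwise-unused submodules in the autograd graph, for example so that DistributedDataParallel does not error out on parameters that never receive gradients. A rough, self-contained sketch of the same idea; the class, names, and shapes below are illustrative stand-ins, not the real MiniCPM-o modules:

```python
import torch
import torch.nn as nn
from typing import Optional

class ToyAudioBranch(nn.Module):
    """Toy stand-in for an audio tower that may see batches without audio."""

    def __init__(self, n_mels: int = 80, hidden: int = 32):
        super().__init__()
        self.encoder = nn.Conv1d(n_mels, hidden, kernel_size=3, padding=1)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, mels: Optional[torch.Tensor] = None):
        if mels is not None:
            feats = self.encoder(mels)                 # (B, hidden, T)
            return self.proj(feats.transpose(1, 2))    # (B, T, hidden)
        if self.training:
            # Dummy forward on zeros: keeps encoder/proj parameters in the graph
            # even though this batch carries no audio.
            p = self.encoder.weight
            dummy = torch.zeros((1, 80, 100), device=p.device, dtype=p.dtype)
            feats = self.encoder(dummy)
            return self.proj(feats.transpose(1, 2))
        return None
```

The caller (next hunk) then folds the dummy embedding in as `mean() * 0`, so the actual input embeddings are unchanged.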
 
@@ -595,7 +611,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         elif self.training:
             for i in range(bs):
                 # dummy audio_embeddings
-                input_embeddings += audio_embeddings[0].mean() * 0
+                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
 
         return input_embeddings
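
This hunk keeps the existing "dummy gradient" term `audio_embeddings[0].mean() * 0`: it is numerically zero, so `input_embeddings` keeps its value, but it connects the audio branch's output to the loss so those parameters receive (zero) gradients instead of none. The only change is that the in-place `+=` becomes an out-of-place addition, consistent with the rest of the commit. A tiny illustration of the trick with made-up tensors:

```python
import torch

w = torch.randn(3, requires_grad=True)     # parameters of an otherwise-unused branch
dummy_out = (w * 2).sum()                  # output of a dummy forward pass

x = torch.randn(4, requires_grad=True)
x = x + dummy_out.mean() * 0               # value of x is unchanged...
x.sum().backward()

print(w.grad)                              # ...but w.grad is now zeros instead of None
```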
 
@@ -751,7 +767,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         input_ids=None,
         pixel_values=None,
         tgt_sizes=None,
-        audio_features=None,
+        audio_features=[],
         audio_feature_lens=None,
         image_bound=None,
         audio_bounds=None,
@@ -2982,7 +2998,7 @@ class ConditionalChatTTS(PreTrainedModel):
         inputs_embeds = torch.stack(code_emb, 3).sum(3)
 
         position_ids = torch.tensor(
-            [past_key_values[0][0].shape[2] + 1], dtype=torch.long, device=self.device
+            [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
         ).unsqueeze(0)
 
         cache_position = position_ids.clone()
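
In the `ConditionalChatTTS` decoding path, `past_key_values[0][0].shape[2]` is the number of positions already in the key/value cache. Position ids are 0-based, so if the cache holds positions `0 .. L-1`, the token being generated next sits at position `L`; the previous `+ 1` pointed one slot past it. A small sanity check of the indexing, assuming the usual `(batch, heads, seq_len, head_dim)` cache layout:

```python
import torch

past_k = torch.zeros(1, 8, 5, 64)      # hypothetical cached keys for 5 processed positions

cached_len = past_k.shape[2]           # 5 -> cached positions are 0, 1, 2, 3, 4
position_ids = torch.tensor([cached_len], dtype=torch.long).unsqueeze(0)

assert position_ids.item() == 5        # the old `+ 1` would have produced 6
```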
 