Update modeling_cogvlm.py

#19
by nielsr (HF staff) - opened
Files changed (2)
  1. modeling_cogvlm.py +170 -2
  2. visual.py +43 -5
modeling_cogvlm.py CHANGED

@@ -117,7 +117,8 @@ def attention_fn(
         attention_mask: "torch.tensor(B, H, L, HD)",
         *,
         scaling_attention_score: bool = True,
-        attention_dropout: nn.Module = None
+        attention_dropout: nn.Module = None,
+        print_values: bool = False,
 ):
     attention_mask_bool = (attention_mask == 0)
     is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
@@ -126,6 +127,10 @@ def attention_fn(
         warnings.warn("It's recommended to use torch2.0 or higher.")
     if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
         dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
+
+        if print_values:
+            print("Is_causal:", not is_full)
+
         return torch.nn.functional.scaled_dot_product_attention(
             query_layer, key_layer, value_layer,
             attn_mask=None,
@@ -225,6 +230,7 @@ class VisionExpertAttention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        print_values: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
         vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
@@ -240,6 +246,36 @@ class VisionExpertAttention(nn.Module):
         key_states = self._transpose_for_scores(key_states)  # B, H, L, HD
         value_states = self._transpose_for_scores(value_states)  # B, H, L, HD

+        # if print_values:
+
+        #     torch.save(query_states, "query_states.pt")
+        #     torch.save(key_states, "key_states.pt")
+        #     torch.save(value_states, "value_states.pt")
+
+        #     from huggingface_hub import HfApi
+
+        #     api = HfApi()
+        #     api.upload_file(
+        #         path_or_fileobj="query_states.pt",
+        #         path_in_repo="query_states.pt",
+        #         repo_id="nielsr/test-cogvlm",
+        #         repo_type="dataset",
+        #     )
+        #     api = HfApi()
+        #     api.upload_file(
+        #         path_or_fileobj="key_states.pt",
+        #         path_in_repo="key_states.pt",
+        #         repo_id="nielsr/test-cogvlm",
+        #         repo_type="dataset",
+        #     )
+        #     api = HfApi()
+        #     api.upload_file(
+        #         path_or_fileobj="value_states.pt",
+        #         path_in_repo="value_states.pt",
+        #         repo_id="nielsr/test-cogvlm",
+        #         repo_type="dataset",
+        #     )
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
@@ -252,9 +288,31 @@ class VisionExpertAttention(nn.Module):

         past_key_value = (key_states, value_states) if use_cache else None

+        if print_values:
+            print("Shape of query_states:", query_states.shape)
+            print("Last values of query_states:", query_states[0,0,-3:,-3:])
+            print("Mean of query_states:", query_states.mean())
+
+            print("Shape of key_states:", key_states.shape)
+            print("Last values of key_states:", key_states[0,0,-3:,-3:])
+            print("Mean of key_states:", key_states.mean())
+
+            print("Shape of value_states:", value_states.shape)
+            print("First values of value_states:", value_states[0,0,-3:,-3:])
+            print("Mean of value_states:", value_states.mean())
+
+            print("Shape of the attention_mask:", attention_mask.shape)
+            print("Mean of the attention_mask:", attention_mask.float().mean())
+            print("Is_full:", (attention_mask > 0).all())
+
         context_layer = attention_fn(
             query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
-            scaling_attention_score=True, attention_dropout=None)
+            scaling_attention_score=True, attention_dropout=None, print_values=print_values)
+
+        if print_values:
+            print("Shape of context_layer:", context_layer.shape)
+            print("First values of context_layer:", context_layer[0,0,:3,:3])
+
         if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
@@ -290,11 +348,18 @@ class CogVLMDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        print_values = False,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

+        # if print_values:
+        #     print("Hidden states before RMS norm:", hidden_states[0, :3, :3])
+
         hidden_states = self.input_layernorm(hidden_states)

+        # if print_values:
+        #     print("Hidden states after RMS norm, before self attention:", hidden_states[0,:3,:3])
+
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
@@ -304,7 +369,12 @@ class CogVLMDecoderLayer(nn.Module):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            print_values=print_values,
         )
+
+        # if print_values:
+        #     print("Hidden states after self attention:", hidden_states[0,:3,:3])
+
         hidden_states = residual + hidden_states

         # Fully Connected
@@ -413,6 +483,7 @@ class CogVLMModel(CogVLMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: int = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""

@@ -425,10 +496,44 @@ class CogVLMModel(CogVLMPreTrainedModel):
             assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
             assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
             inputs_embeds = self.embed_tokens(input_ids)
+
             images_features = self.encode_images(images)
             images_features = rearrange(images_features, 'b n d -> (b n) d')
             images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+
+            # from huggingface_hub import HfApi
+
+            # torch.save(images_features, "images_features.pt")
+            # torch.save(inputs_embeds, "inputs_embeds.pt")
+            # torch.save(token_type_ids, "token_type_ids.pt")
+
+            # api = HfApi()
+            # api.upload_file(
+            #     path_or_fileobj="images_features.pt",
+            #     path_in_repo="images_features.pt",
+            #     repo_id="nielsr/test-cogvlm",
+            #     repo_type="dataset",
+            # )
+            # api.upload_file(
+            #     path_or_fileobj="inputs_embeds.pt",
+            #     path_in_repo="inputs_embeds.pt",
+            #     repo_id="nielsr/test-cogvlm",
+            #     repo_type="dataset",
+            # )
+            # api.upload_file(
+            #     path_or_fileobj="token_type_ids.pt",
+            #     path_in_repo="token_type_ids.pt",
+            #     repo_id="nielsr/test-cogvlm",
+            #     repo_type="dataset",
+            # )
+
+            # print("First values of text embeddings:", inputs_embeds[0, :3, :3])
+            # print("First values of images_features:", images_features[0, :3])
+
             inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
+
+            # print("First values of inputs_embeds after index_put:", inputs_embeds[0, :3, :3])
+
         else:  # single-modality
             if token_type_ids is None:
                 token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
@@ -450,6 +555,7 @@ class CogVLMModel(CogVLMPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            step=step,
         )

     def llm_forward(
@@ -464,6 +570,7 @@ class CogVLMModel(CogVLMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: int = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """largely copy from llama forward and adapt for cogvlm with `token_type_ids`"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -513,6 +620,48 @@ class CogVLMModel(CogVLMPreTrainedModel):

         hidden_states = inputs_embeds

+        if step == 1:
+            torch.save(hidden_states, "hidden_states_step_1.pt")
+            torch.save(attention_mask, "attention_mask_step_1.pt")
+            torch.save(token_type_ids, "token_type_ids_step_1.pt")
+            torch.save(position_ids, "position_ids_step_1.pt")
+            torch.save(past_key_values, "past_key_value_step_1.pt")
+
+            from huggingface_hub import HfApi
+
+            api = HfApi()
+
+            api.upload_file(
+                path_or_fileobj="hidden_states_step_1.pt",
+                path_in_repo="hidden_states_step_1.pt",
+                repo_id="nielsr/test-cogvlm",
+                repo_type="dataset",
+            )
+            api.upload_file(
+                path_or_fileobj="attention_mask_step_1.pt",
+                path_in_repo="attention_mask_step_1.pt",
+                repo_id="nielsr/test-cogvlm",
+                repo_type="dataset",
+            )
+            api.upload_file(
+                path_or_fileobj="token_type_ids_step_1.pt",
+                path_in_repo="token_type_ids_step_1.pt",
+                repo_id="nielsr/test-cogvlm",
+                repo_type="dataset",
+            )
+            api.upload_file(
+                path_or_fileobj="position_ids_step_1.pt",
+                path_in_repo="position_ids_step_1.pt",
+                repo_id="nielsr/test-cogvlm",
+                repo_type="dataset",
+            )
+            api.upload_file(
+                path_or_fileobj="past_key_value_step_1.pt",
+                path_in_repo="past_key_value_step_1.pt",
+                repo_id="nielsr/test-cogvlm",
+                repo_type="dataset",
+            )
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -531,9 +680,21 @@ class CogVLMModel(CogVLMPreTrainedModel):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                print_values=idx==0 and step==1,
             )
            hidden_states = layer_outputs[0]

+            # if idx == 0:
+            #     torch.save(hidden_states, "hidden_states_after_layer_0.pt")
+
+            #     api = HfApi()
+            #     api.upload_file(
+            #         path_or_fileobj="hidden_states_after_layer_0.pt",
+            #         path_in_repo="hidden_states_after_layer_0.pt",
+            #         repo_id="nielsr/test-cogvlm",
+            #         repo_type="dataset",
+            #     )
+
             if use_cache:
                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

@@ -542,6 +703,10 @@ class CogVLMModel(CogVLMPreTrainedModel):

         hidden_states = self.norm(hidden_states)

+        if step == 1:
+            print("Shape of hidden states:", hidden_states.shape)
+            print("First values of hidden states:", hidden_states[0,:3,:3])
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -649,6 +814,7 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
+        step: int = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -669,6 +835,7 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            step=step,
         )

         hidden_states = outputs[0]
@@ -745,6 +912,7 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
+        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
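
The `step == 1` branch added to `llm_forward` dumps the decoder inputs to local `.pt` files and pushes them to the `nielsr/test-cogvlm` dataset repo, presumably so the activations can be compared against another run of the model. A minimal sketch of how those dumps could be pulled back down and diffed against a reference run is shown below; the `*_reference.pt` filenames are illustrative assumptions, not part of this PR.

```python
# Sketch: download the step-1 dumps from the dataset repo and compare them
# against locally saved reference tensors (reference filenames are assumed).
import torch
from huggingface_hub import hf_hub_download

names = [
    "hidden_states_step_1",
    "attention_mask_step_1",
    "token_type_ids_step_1",
    "position_ids_step_1",
]

for name in names:
    path = hf_hub_download(
        repo_id="nielsr/test-cogvlm",
        filename=f"{name}.pt",
        repo_type="dataset",
    )
    ported = torch.load(path, map_location="cpu")
    reference = torch.load(f"{name}_reference.pt", map_location="cpu")  # assumed local file
    max_diff = (ported.float() - reference.float()).abs().max().item()
    print(name, "max abs diff:", max_diff)
```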
visual.py CHANGED

@@ -31,7 +31,7 @@ class Attention(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.output_dropout = torch.nn.Dropout(config.dropout_prob)

-    def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
+    def forward(self, x: "tensor(B, L, D)", print_values=False) -> "tensor(B, L, D)":
         B, L, _ = x.shape
         qkv = self.query_key_value(x)
         qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)  # 3, B, L, H, D
@@ -40,6 +40,7 @@ class Attention(nn.Module):
         out = xops.memory_efficient_attention(
             q, k, v, scale=self.scale,
         )
+
         output = self.dense(out.view(B, L, -1))
         output = self.output_dropout(output)
         return output
@@ -74,9 +75,18 @@ class TransformerLayer(nn.Module):
         self.mlp = MLP(config)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, print_values=False):
         attention_input = hidden_states
-        attention_output = self.input_layernorm(self.attention(attention_input))
+
+        if print_values:
+            print("Hidden states before attention:", attention_input[0, :3, :3])
+
+        attention_output = self.attention(attention_input, print_values=print_values)
+
+        if print_values:
+            print("Hidden states after attention:", attention_output[0, :3, :3])
+
+        attention_output = self.input_layernorm(attention_output)
         hidden_states = attention_input + attention_output
         mlp_input = hidden_states
         mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
@@ -90,8 +100,36 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states):
-        for layer_module in self.layers:
-            hidden_states = layer_module(hidden_states)
+
+        print("Shape of hidden states before CLIP:", hidden_states.shape)
+        # torch.save(hidden_states, "hidden_states_before_clip.pt")
+
+        # from huggingface_hub import HfApi
+
+        # api = HfApi()
+        # api.upload_file(
+        #     path_or_fileobj="hidden_states_before_clip.pt",
+        #     path_in_repo="hidden_states_before_clip.pt",
+        #     repo_id="nielsr/test-cogvlm",
+        #     repo_type="dataset",
+        # )
+
+        for idx, layer_module in enumerate(self.layers):
+            hidden_states = layer_module(hidden_states, print_values=idx==0)
+
+        print("Shape of hidden states after CLIP:", hidden_states.shape)
+        # torch.save(hidden_states, "hidden_states_after_clip.pt")
+
+        # from huggingface_hub import HfApi
+
+        # api = HfApi()
+        # api.upload_file(
+        #     path_or_fileobj="hidden_states_after_clip.pt",
+        #     path_in_repo="hidden_states_after_clip.pt",
+        #     repo_id="nielsr/test-cogvlm",
+        #     repo_type="dataset",
+        # )
+
         return hidden_states

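
With these hooks in place, a single instrumented forward pass emits the vision-tower prints (layer 0 only, via `print_values=idx==0`) and, when `step=1` is passed through `CogVLMForCausalLM.forward`, triggers the tensor dump added in `llm_forward`. A rough sketch of how such a debug run might be driven is given below; it assumes the modified files ship with a CogVLM checkpoint loaded via `trust_remote_code=True`, uses the preprocessing helper from the published CogVLM usage example (`build_conversation_input_ids`, assumed unchanged by this PR), and the repo id, prompt, and image path are placeholders.

```python
# Rough sketch of a debug run; repo id, prompt, and image path are placeholders,
# and the checkpoint is assumed to carry the modified modeling_cogvlm.py / visual.py.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogvlm-chat-hf",      # assumed base checkpoint for this PR
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda").eval()

image = Image.open("example.jpg").convert("RGB")
features = model.build_conversation_input_ids(
    tokenizer, query="Describe this image", history=[], images=[image]
)
inputs = {
    "input_ids": features["input_ids"].unsqueeze(0).to("cuda"),
    "token_type_ids": features["token_type_ids"].unsqueeze(0).to("cuda"),
    "attention_mask": features["attention_mask"].unsqueeze(0).to("cuda"),
    "images": [[features["images"][0].to("cuda").to(torch.bfloat16)]],
}

with torch.no_grad():
    # step=1 activates the torch.save / upload_file branch added in llm_forward;
    # uploading requires being logged in with write access to the target dataset repo.
    outputs = model(**inputs, step=1)
```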