Commit: Add print statements
Browse files — modeling_cogvlm.py (+6 −2)
modeling_cogvlm.py
CHANGED
@@ -290,12 +290,14 @@ class CogVLMDecoderLayer(nn.Module):
|
|
290 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
291 |
output_attentions: Optional[bool] = False,
|
292 |
use_cache: Optional[bool] = False,
|
|
|
293 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
294 |
residual = hidden_states
|
295 |
|
296 |
hidden_states = self.input_layernorm(hidden_states)
|
297 |
|
298 |
-
|
|
|
299 |
|
300 |
# Self Attention
|
301 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
@@ -308,7 +310,8 @@ class CogVLMDecoderLayer(nn.Module):
|
|
308 |
use_cache=use_cache,
|
309 |
)
|
310 |
|
311 |
-
|
|
|
312 |
|
313 |
hidden_states = residual + hidden_states
|
314 |
|
@@ -539,6 +542,7 @@ class CogVLMModel(CogVLMPreTrainedModel):
|
|
539 |
past_key_value=past_key_value,
|
540 |
output_attentions=output_attentions,
|
541 |
use_cache=use_cache,
|
|
|
542 |
)
|
543 |
hidden_states = layer_outputs[0]
|
544 |
|
|
|
290 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
291 |
output_attentions: Optional[bool] = False,
|
292 |
use_cache: Optional[bool] = False,
|
293  +        print_values = False,
294 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
295 |
residual = hidden_states
|
296 |
|
297 |
hidden_states = self.input_layernorm(hidden_states)
|
298 |
|
299  +        if print_values:
300  +            print("Hidden states before self attention:", hidden_states[0,:3,:3])
301 |
|
302 |
# Self Attention
|
303 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
|
310 |
use_cache=use_cache,
|
311 |
)
|
312 |
|
313  +        if print_values:
314  +            print("Hidden states after self attention:", hidden_states[0,:3,:3])
315 |
|
316 |
hidden_states = residual + hidden_states
|
317 |
|
|
|
542 |
past_key_value=past_key_value,
|
543 |
output_attentions=output_attentions,
|
544 |
use_cache=use_cache,
|
545  +                print_values=idx in [0, 1, 2],
546 |
)
|
547 |
hidden_states = layer_outputs[0]
|
548 |
|