Crystalcareai commited on
Commit
def2825
·
verified ·
1 Parent(s): b11137f

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +68 -50
modeling_quiet.py CHANGED
@@ -54,21 +54,61 @@ _CONFIG_FOR_DOC = "QuietConfig"
54
 
55
 
56
  def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
 
57
  bsz, tgt_len = input_shape
58
 
 
 
 
 
59
  if attention_mask is not None:
60
- if attention_mask.dim() == 3:
61
- # Expanding from [batch_size, 1, tgt_len] to [batch_size, 1, tgt_len, tgt_len]
62
- attention_mask = attention_mask.expand(bsz, 1, tgt_len, tgt_len)
 
 
 
 
 
 
 
 
63
  elif attention_mask.dim() == 2:
64
- # Expanding from [batch_size, tgt_len] to [batch_size, 1, tgt_len, tgt_len]
65
- attention_mask = attention_mask.unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len)
 
 
 
 
 
 
66
  else:
67
- raise ValueError(f"Unexpected attention mask shape: {attention_mask.shape}, expected 2 or 3 dimensions.")
 
 
 
 
68
 
69
- attention_mask = (1.0 - attention_mask) * -10000.0 # Masking operation for softmax
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- return attention_mask
72
 
73
 
74
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -1056,58 +1096,36 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1056
  # Apply the language model head to get the final logits
1057
  logits = self.lm_head(mixed_hidden_states)
1058
  return logits
1059
-
1060
  @torch.no_grad()
1061
  def generate(
1062
  self,
1063
- input_ids=None,
1064
- attention_mask=None,
1065
- max_new_tokens=None,
1066
- min_length=None,
1067
- do_sample=None,
1068
- early_stopping=None,
1069
- num_beams=None,
1070
- temperature=1.0,
1071
- top_k=None,
1072
- top_p=None,
1073
- repetition_penalty=None,
1074
- bad_words_ids=None,
1075
- bos_token_id=None,
1076
- pad_token_id=None,
1077
- eos_token_id=None,
1078
- length_penalty=None,
1079
- no_repeat_ngram_size=None,
1080
- num_return_sequences=None,
1081
- decoder_start_token_id=None,
1082
- use_cache=None,
1083
- num_beam_groups=None,
1084
- diversity_penalty=None,
1085
- prefix_allowed_tokens_fn=None,
1086
- output_attentions=None,
1087
- output_hidden_states=None,
1088
- output_scores=None,
1089
- return_dict_in_generate=None,
1090
- forced_bos_token_id=None,
1091
- forced_eos_token_id=None,
1092
- remove_invalid_values=None,
1093
- synced_gpus=None,
1094
- **model_kwargs,
1095
  ):
1096
- # Prepare the generation process with customized settings
1097
- model_inputs = self.prepare_inputs_for_generation(
1098
- input_ids, past_key_values=None, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
1099
- )
1100
-
1101
- from .generate import custom_generate
1102
- return custom_generate(
 
 
 
1103
  self,
1104
- input_ids=input_ids,
1105
  attention_mask=attention_mask,
1106
  max_new_tokens=max_new_tokens,
1107
  temperature=temperature,
1108
- **model_kwargs
1109
  )
1110
 
 
 
1111
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1112
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1113
  def forward(
 
54
 
55
 
56
  def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
57
+ # Compute the attention mask correctly
58
  bsz, tgt_len = input_shape
59
 
60
+ # Create a 4D attention mask from a 2D tensor mask.
61
+ # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len)
62
+ # The values are either 0 or 1, where 0 means padding and 1 means non-padding.
63
+ combined_attention_mask = None
64
  if attention_mask is not None:
65
+ # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len, src_len)
66
+ # In this case, we can just use it directly.
67
+ if attention_mask.dim() == 4:
68
+ combined_attention_mask = attention_mask
69
+ # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len)
70
+ # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
71
+ elif attention_mask.dim() == 3:
72
+ expanded_attn_mask = attention_mask[:, None, :, :]
73
+ combined_attention_mask = expanded_attn_mask
74
+ # What if attention_mask is not None and has a shape of (batch_size, tgt_len)
75
+ # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
76
  elif attention_mask.dim() == 2:
77
+ # Provided a padding mask of dimensions [batch_size, seq_length]
78
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
79
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
80
+ if past_key_values_length > 0:
81
+ attention_mask = attention_mask.to(dtype=torch.long)
82
+ attention_mask = attention_mask[:, past_key_values_length:]
83
+ expanded_attn_mask = attention_mask[:, None, None, :]
84
+ combined_attention_mask = expanded_attn_mask
85
  else:
86
+ raise ValueError(
87
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
88
+ input_shape, attention_mask.shape
89
+ )
90
+ )
91
 
92
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
93
+ # masked positions, this operation will create a tensor which is 0.0 for
94
+ # positions we want to attend and -10000.0 for masked positions.
95
+ # Since we are adding it to the raw scores before the softmax, this is
96
+ # effectively the same as removing these entirely.
97
+ if combined_attention_mask is not None:
98
+ # Ensure the attention mask values are within a reasonable range
99
+ combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
100
+
101
+ # Convert the attention mask to bfloat16
102
+ combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
103
+
104
+ # Normalize the attention mask values to be between 0 and 1
105
+ combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
106
+ else:
107
+ combined_attention_mask = torch.zeros(
108
+ (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
109
+ )
110
 
111
+ return combined_attention_mask
112
 
113
 
114
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
1096
  # Apply the language model head to get the final logits
1097
  logits = self.lm_head(mixed_hidden_states)
1098
  return logits
1099
+
1100
  @torch.no_grad()
1101
  def generate(
1102
  self,
1103
+ input_ids: torch.LongTensor = torch.LongTensor(),
1104
+ attention_mask: Optional[torch.Tensor] = None,
1105
+ max_new_tokens: Optional[int] = None,
1106
+ temperature: float = 1.1,
1107
+ **kwargs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1108
  ):
1109
+ if isinstance(input_ids, str):
1110
+ input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
1111
+
1112
+ if attention_mask is None:
1113
+ # Create a default attention mask if not provided
1114
+ attention_mask = torch.ones_like(input_ids)
1115
+
1116
+ from .generate import generate
1117
+
1118
+ output = generate(
1119
  self,
1120
+ input_ids,
1121
  attention_mask=attention_mask,
1122
  max_new_tokens=max_new_tokens,
1123
  temperature=temperature,
1124
+ **kwargs,
1125
  )
1126
 
1127
+ return output.sequences
1128
+
1129
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1130
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1131
  def forward(