Update modeling_bunny_minicpm.py
Browse files
modeling_bunny_minicpm.py
CHANGED
@@ -210,11 +210,17 @@ class BunnyMetaForCausalLM(ABC):
|
|
210 |
if labels is None:
|
211 |
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
212 |
|
|
|
|
|
213 |
# remove the padding using attention_mask -- TODO: double check
|
214 |
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
|
215 |
zip(input_ids, attention_mask)]
|
216 |
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
217 |
|
|
|
|
|
|
|
|
|
218 |
new_input_embeds = []
|
219 |
new_labels = []
|
220 |
cur_image_idx = 0
|
|
|
210 |
if labels is None:
|
211 |
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
212 |
|
213 |
+
input_ids_temp = input_ids # points to the actual input_ids tensor
|
214 |
+
|
215 |
# remove the padding using attention_mask -- TODO: double check
|
216 |
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
|
217 |
zip(input_ids, attention_mask)]
|
218 |
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
219 |
|
220 |
+
# -- TODO: better implementation?
|
221 |
+
# replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
|
222 |
+
input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
|
223 |
+
|
224 |
new_input_embeds = []
|
225 |
new_labels = []
|
226 |
cur_image_idx = 0
|