VictorSanh committed
Commit: 7515eca
Parent(s): c8028be

big update

Files changed (1): modeling_img2html.py (+10, -12)

modeling_img2html.py CHANGED
@@ -162,7 +162,7 @@ def expand_inputs_for_generation(
     input_ids = input_ids.index_select(0, expanded_return_idx)
     model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
     model_kwargs["image_hidden_states"] = model_kwargs.get("image_hidden_states", None)
-    model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
+    # model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
 
     if "token_type_ids" in model_kwargs:
         token_type_ids = model_kwargs["token_type_ids"]
@@ -180,9 +180,7 @@ def expand_inputs_for_generation(
         model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
 
     elif model_kwargs["image_hidden_states"] is not None:
-        model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(
-            0, expanded_return_idx
-        )
+        model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(0, expanded_return_idx)
 
     return input_ids, model_kwargs
 
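For context (not part of the commit): `expanded_return_idx` is the index tensor the generation utilities build to replicate each batch row `num_return_sequences` (or `num_beams`) times, and applying the same `index_select(0, ...)` to `pixel_values` / `image_hidden_states` keeps the image tensors row-aligned with the expanded `input_ids`. A minimal sketch with made-up shapes:

import torch

# Hypothetical batch of 2 prompts, each with precomputed image hidden states.
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
image_hidden_states = torch.randn(2, 64, 1024)  # (batch, image_seq_len, hidden)

num_return_sequences = 3
expanded_return_idx = (
    torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_return_sequences).view(-1)
)  # tensor([0, 0, 0, 1, 1, 1])

# Expanding both tensors with the same index keeps example i's text and image together.
input_ids = input_ids.index_select(0, expanded_return_idx)                      # shape (6, 3)
image_hidden_states = image_hidden_states.index_select(0, expanded_return_idx)  # shape (6, 64, 1024)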
@@ -205,10 +203,10 @@ def update_model_kwargs_for_generation(outputs, model_kwargs):
         model_kwargs["attention_mask"] = torch.cat(
             [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
         )
-    if "image_attention_mask" in model_kwargs:
-        image_attention_mask = model_kwargs["image_attention_mask"]
-        last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-        model_kwargs["image_attention_mask"] = last_mask
+    # if "image_attention_mask" in model_kwargs:
+    #     image_attention_mask = model_kwargs["image_attention_mask"]
+    #     last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+    #     model_kwargs["image_attention_mask"] = last_mask
 
     # Get the precomputed image_hidden_states
     model_kwargs["image_hidden_states"] = outputs.image_hidden_states
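For context (not part of the commit): during incremental decoding one token is appended per step, so the 2-D `attention_mask` gains a column of ones each step; the commented-out block did the analogous per-step bookkeeping for the now-unused `image_attention_mask`. A small self-contained illustration of the mask update:

import torch

# Hypothetical state after a prompt of length 4 for a batch of 2 sequences.
attention_mask = torch.ones(2, 4, dtype=torch.long)

# One generation step adds one token per sequence, so the mask grows by one column.
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
print(attention_mask.shape)  # torch.Size([2, 5])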
@@ -236,7 +234,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
 
     pixel_values = kwargs.get("pixel_values", None)
     image_hidden_states = kwargs.get("image_hidden_states", None)
-    image_attention_mask = kwargs.get("image_attention_mask", None)
+    # image_attention_mask = kwargs.get("image_attention_mask", None)
 
     return {
         "input_ids": input_ids,
@@ -247,7 +245,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
         "token_type_ids": token_type_ids,
         "pixel_values": pixel_values,
         "image_hidden_states": image_hidden_states,
-        "image_attention_mask": image_attention_mask,
+        # "image_attention_mask": image_attention_mask,
     }
 
 
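For context (not part of the commit): in the usual transformers generation loop, `prepare_inputs_for_generation` is called once per step to assemble the keyword arguments of the next forward pass, so removing a key from the returned dict is enough to stop that tensor from reaching the model. A simplified, self-contained sketch of the pattern (hypothetical function, not the file's exact implementation):

import torch

def prepare_inputs_for_generation_sketch(input_ids, past_key_values=None, **kwargs):
    # With a KV cache, only the most recent token needs to be fed back in.
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "pixel_values": kwargs.get("pixel_values", None),
        "image_hidden_states": kwargs.get("image_hidden_states", None),
        # "image_attention_mask" intentionally left out, mirroring this commit.
    }

inputs = prepare_inputs_for_generation_sketch(
    torch.tensor([[1, 2, 3]]),
    pixel_values=torch.randn(1, 3, 224, 224),
)
print(sorted(inputs.keys()))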
@@ -1373,7 +1371,6 @@ class VMistralModel(VMistralPreTrainedModel):
         input_ids: torch.LongTensor = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         image_hidden_states: Optional[torch.Tensor] = None,
-        num_images: Optional[int] = None,
     ):
         """
         This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
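For context (not part of the commit): the docstring describes the common pattern of substituting the embeddings at image-token positions with (projected) image hidden states, so text and image features form a single input sequence for the LM. A minimal sketch of that idea, not the model's actual `inputs_merger`, assuming a hypothetical `image_token_id` and one feature vector per image:

import torch

batch, seq_len, hidden = 1, 6, 8
image_token_id = 32000  # hypothetical placeholder token id

input_ids = torch.tensor([[5, 9, image_token_id, 7, image_token_id, 3]])
inputs_embeds = torch.randn(batch, seq_len, hidden)  # token embeddings
image_hidden_states = torch.randn(2, hidden)         # one projected vector per image

# Overwrite the embeddings at image-token positions with the image features,
# leaving the text embeddings everywhere else untouched.
image_positions = input_ids == image_token_id        # boolean mask, shape (batch, seq_len)
inputs_embeds[image_positions] = image_hidden_states
print(torch.equal(inputs_embeds[0, 2], image_hidden_states[0]))  # True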
@@ -1496,6 +1493,8 @@ class VMistralModel(VMistralPreTrainedModel):
 
         if self.config.use_resampler:
             image_hidden_states = self.perceiver_resampler(image_hidden_states)
+        elif image_hidden_states is not None:
+            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
 
         if past_key_values is None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
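For context (not part of the commit): when no perceiver resampler is configured, `image_hidden_states` may be passed in precomputed and can arrive in a different dtype or on a different device than the language model (e.g. fp32 features while the LM runs in bf16), so the new branch casts them before they are merged. A small stand-alone illustration of the cast, with stand-ins for `self.dtype` and `input_ids.device`:

import torch

# Hypothetical precomputed vision features in float32.
image_hidden_states = torch.randn(1, 64, 1024, dtype=torch.float32)

model_dtype = torch.bfloat16        # stand-in for self.dtype
model_device = torch.device("cpu")  # stand-in for input_ids.device

image_hidden_states = image_hidden_states.to(dtype=model_dtype, device=model_device)
print(image_hidden_states.dtype, image_hidden_states.device)  # torch.bfloat16 cpu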
@@ -1504,7 +1503,6 @@ class VMistralModel(VMistralPreTrainedModel):
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 image_hidden_states=image_hidden_states,
-                num_images=num_images,
             )
             inputs_embeds = new_inp["inputs_embeds"]
 