Fix modeling code (typos/bugs)
#6
by
Xenova
HF staff
- opened
- modeling_florence2.py +9 -3
modeling_florence2.py
CHANGED
@@ -2240,6 +2240,10 @@ class Florence2Seq2SeqLMOutput(ModelOutput):
|
|
2240 |
decoding.
|
2241 |
|
2242 |
Args:
|
|
|
|
|
|
|
|
|
2243 |
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
2244 |
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
2245 |
|
@@ -2288,7 +2292,8 @@ class Florence2Seq2SeqLMOutput(ModelOutput):
|
|
2288 |
|
2289 |
image_hidden_states of the model produced by the vision encoder
|
2290 |
"""
|
2291 |
-
|
|
|
2292 |
last_hidden_state: torch.FloatTensor = None
|
2293 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
2294 |
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
@@ -2297,6 +2302,7 @@ class Florence2Seq2SeqLMOutput(ModelOutput):
|
|
2297 |
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
2298 |
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
2299 |
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
2300 |
|
2301 |
|
2302 |
FLORENCE2_START_DOCSTRING = r"""
|
@@ -2527,7 +2533,6 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2527 |
def __init__(self, config: Florence2Config):
|
2528 |
super().__init__(config)
|
2529 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
2530 |
-
del config.vision_config.model_type
|
2531 |
self.vision_tower = DaViT.from_config(config=config.vision_config)
|
2532 |
# remove unused layers
|
2533 |
del self.vision_tower.head
|
@@ -2731,7 +2736,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2731 |
image_features = self._encode_image(pixel_values)
|
2732 |
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
|
2733 |
|
2734 |
-
|
|
|
2735 |
outputs = self.language_model(
|
2736 |
attention_mask=attention_mask,
|
2737 |
labels=labels,
|
|
|
2240 |
decoding.
|
2241 |
|
2242 |
Args:
|
2243 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
2244 |
+
Language modeling loss.
|
2245 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
2246 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
2247 |
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
2248 |
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
2249 |
|
|
|
2292 |
|
2293 |
image_hidden_states of the model produced by the vision encoder
|
2294 |
"""
|
2295 |
+
loss: Optional[torch.FloatTensor] = None
|
2296 |
+
logits: torch.FloatTensor = None
|
2297 |
last_hidden_state: torch.FloatTensor = None
|
2298 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
2299 |
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
2302 |
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
2303 |
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
2304 |
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
2305 |
+
image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
2306 |
|
2307 |
|
2308 |
FLORENCE2_START_DOCSTRING = r"""
|
|
|
2533 |
def __init__(self, config: Florence2Config):
|
2534 |
super().__init__(config)
|
2535 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
|
|
2536 |
self.vision_tower = DaViT.from_config(config=config.vision_config)
|
2537 |
# remove unused layers
|
2538 |
del self.vision_tower.head
|
|
|
2736 |
image_features = self._encode_image(pixel_values)
|
2737 |
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
|
2738 |
|
2739 |
+
if inputs_embeds is not None:
|
2740 |
+
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
2741 |
outputs = self.language_model(
|
2742 |
attention_mask=attention_mask,
|
2743 |
labels=labels,
|