Update modeling_llava_qwen2.py
Browse files- modeling_llava_qwen2.py +3 -3
modeling_llava_qwen2.py
CHANGED
@@ -662,14 +662,14 @@ class LlavaMetaForCausalLM(ABC):
|
|
662 |
return self.get_model().get_vision_tower()
|
663 |
|
664 |
def encode_images(self, images):
|
665 |
-
image_features = self.get_model().get_vision_tower()(images)
|
666 |
-
image_features = self.get_model().mm_projector(image_features
|
667 |
return image_features
|
668 |
|
669 |
def prepare_inputs_labels_for_multimodal(
|
670 |
self, input_ids, position_ids, attention_mask, past_key_values, labels, images
|
671 |
):
|
672 |
-
vision_tower = self.get_vision_tower()
|
673 |
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
674 |
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
|
675 |
1] == 1:
|
|
|
662 |
return self.get_model().get_vision_tower()
|
663 |
|
664 |
def encode_images(self, images):
|
665 |
+
image_features = self.get_model().get_vision_tower().cuda()(images)
|
666 |
+
image_features = self.get_model().mm_projector(image_features)
|
667 |
return image_features
|
668 |
|
669 |
def prepare_inputs_labels_for_multimodal(
|
670 |
self, input_ids, position_ids, attention_mask, past_key_values, labels, images
|
671 |
):
|
672 |
+
vision_tower = self.get_vision_tower().cuda()
|
673 |
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
674 |
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
|
675 |
1] == 1:
|