Update modeling_llava_qwen2.py
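This commit pins device and dtype placement inside `SigLipVisionTower` and `LlavaMetaForCausalLM`: inputs to the vision tower are moved to the tower's own `self.device` and `self.dtype` before the forward pass, `dummy_feature` is allocated there directly, and the encoded image features are moved to `self.device`.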
modeling_llava_qwen2.py  +5 -5
@@ -538,13 +538,13 @@ class SigLipVisionTower(nn.Module):
         if type(images) is list:
             image_features = []
             for image in images:
-                image_forward_out = self.vision_tower(image.to(device=
+                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                                       output_hidden_states=True)
                 image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                 assert image_features.shape[-2] == 729
                 image_features.append(image_feature)
         else:
-            image_forward_outs = self.vision_tower(images.to(device=
+            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                                    output_hidden_states=True)
             image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
             assert image_features.shape[-2] == 729
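Two details of this hunk are worth noting. First, each `image` in the list branch is a single unbatched tensor, so `.unsqueeze(0)` adds the batch dimension the tower expects; a standalone illustration (the 3x384x384 shape is an assumed SigLIP input size, not taken from this file):

import torch

image = torch.randn(3, 384, 384)   # one image, (C, H, W); 384x384 is an assumed input size
batched = image.unsqueeze(0)       # (1, 3, 384, 384): a batch of one for the tower
print(batched.shape)               # torch.Size([1, 3, 384, 384])

Second, the context line `assert image_features.shape[-2] == 729` inside the loop checks the Python list rather than the freshly computed `image_feature`; if that branch runs, `list` has no `.shape` and the line raises AttributeError, so it looks like a pre-existing bug this commit leaves untouched (the non-list branch asserts on a tensor and is fine).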
@@ -553,7 +553,7 @@ class SigLipVisionTower(nn.Module):
 
     @property
     def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
 
     @property
     def dtype(self):
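The context here cuts off right at the `dtype` property, so its body is not shown. In LLaVA-style towers, `device` and `dtype` are conventionally thin properties that report where the wrapped backbone currently lives, which is what makes the `self.device` / `self.dtype` calls above well-defined; a minimal sketch of that convention (assumed, not this repo's exact code):

import torch
import torch.nn as nn

class TowerSketch(nn.Module):
    # Hypothetical wrapper illustrating the assumed device/dtype convention.
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.vision_tower = backbone

    @property
    def dtype(self) -> torch.dtype:
        # dtype of the backbone's parameters (e.g. torch.float16)
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self) -> torch.device:
        # device the backbone currently lives on
        return next(self.vision_tower.parameters()).device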
@@ -685,9 +685,9 @@ class LlavaMetaForCausalLM(ABC):
             image_features = self.encode_images(concat_images)
             split_sizes = [image.shape[0] for image in images]
             image_features = torch.split(image_features, split_sizes, dim=0)
-            image_features = [x.flatten(0, 1).to(
+            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
         else:
-            image_features = self.encode_images(images).to(
+            image_features = self.encode_images(images).to(self.device)
 
         # Let's just add dummy tensors if they do not exist,
         # it is a headache to deal with None all the time.
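For context on the list branch above: crops from all samples are encoded as one batch, split back per sample by `split_sizes`, then each sample's crops are flattened into a single token sequence and moved to `self.device`. A small self-contained illustration; 729 tokens matches the asserts in the first hunk, while the 1152 hidden size is an assumption (SigLIP-so400m), not taken from this file:

import torch

# Two samples contribute 3 and 2 image crops respectively.
split_sizes = [3, 2]
feats = torch.randn(5, 729, 1152)             # (total_crops, tokens_per_crop, hidden)

per_sample = torch.split(feats, split_sizes, dim=0)
flat = [x.flatten(0, 1) for x in per_sample]  # merge crop and token dims per sample

print([tuple(t.shape) for t in flat])         # [(2187, 1152), (1458, 1152)]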