qnguyen3 commited on
Commit
21ed169
1 Parent(s): 17d73ee

Update modeling_llava_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_llava_qwen2.py +2 -2
modeling_llava_qwen2.py CHANGED
@@ -663,8 +663,8 @@ class LlavaMetaForCausalLM(ABC):
663
 
664
  @spaces.GPU
665
  def encode_images(self, images):
666
- image_features = self.get_model().get_vision_tower()(images)
667
- image_features = self.get_model().mm_projector(image_features)
668
  return image_features
669
 
670
  @spaces.GPU
 
663
 
664
  @spaces.GPU
665
  def encode_images(self, images):
666
+ image_features = self.get_model().get_vision_tower().to("cuda:0")(images)
667
+ image_features = self.get_model().mm_projector.to("cuda:0")(image_features)
668
  return image_features
669
 
670
  @spaces.GPU