qnguyen3 commited on
Commit
f93b9e6
1 Parent(s): 21ed169

Update modeling_llava_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_llava_qwen2.py +2 -4
modeling_llava_qwen2.py CHANGED
@@ -661,13 +661,11 @@ class LlavaMetaForCausalLM(ABC):
661
  def get_vision_tower(self):
662
  return self.get_model().get_vision_tower()
663
 
664
- @spaces.GPU
665
  def encode_images(self, images):
666
- image_features = self.get_model().get_vision_tower().to("cuda:0")(images)
667
- image_features = self.get_model().mm_projector.to("cuda:0")(image_features)
668
  return image_features
669
 
670
- @spaces.GPU
671
  def prepare_inputs_labels_for_multimodal(
672
  self, input_ids, position_ids, attention_mask, past_key_values, labels, images
673
  ):
 
661
  def get_vision_tower(self):
662
  return self.get_model().get_vision_tower()
663
 
 
664
  def encode_images(self, images):
665
+ image_features = self.get_model().get_vision_tower().to(decice="cuda:0", dtype=torch.float)(images)
666
+ image_features = self.get_model().mm_projector.to(device="cuda:0", dtype=torch.float)(image_features)
667
  return image_features
668
 
 
669
  def prepare_inputs_labels_for_multimodal(
670
  self, input_ids, position_ids, attention_mask, past_key_values, labels, images
671
  ):