wiusdy committed on
Commit
24eb62a
1 Parent(s): e727785

removing the cuda field for space HF

Browse files
Files changed (1) hide show
  1. inference.py +4 -4
inference.py CHANGED
@@ -9,8 +9,8 @@ class Inference:
9
  self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
10
 
11
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
12
- self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning").to("cuda")
13
- self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning").to("cuda")
14
  logging.set_verbosity_info()
15
  self.logger = logging.get_logger("transformers")
16
 
@@ -33,13 +33,13 @@ class Inference:
33
  return f"{self.vilt_model.config.id2label[idx]}"
34
 
35
  def __inference_saffal_blip(self, image, text):
36
- encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
37
  out = self.blip_model_saffal.generate(**encoding)
38
  generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
39
  return f"{generated_text}"
40
 
41
  def __inference_control_net_blip(self, image, text):
42
- encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
43
  out = self.blip_model_control_net.generate(**encoding)
44
  generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
45
  return f"{generated_text}"
 
9
  self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
10
 
11
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
12
+ self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning")
13
+ self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning")
14
  logging.set_verbosity_info()
15
  self.logger = logging.get_logger("transformers")
16
 
 
33
  return f"{self.vilt_model.config.id2label[idx]}"
34
 
35
  def __inference_saffal_blip(self, image, text):
36
+ encoding = self.blip_processor(image, text, return_tensors="pt")
37
  out = self.blip_model_saffal.generate(**encoding)
38
  generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
39
  return f"{generated_text}"
40
 
41
  def __inference_control_net_blip(self, image, text):
42
+ encoding = self.blip_processor(image, text, return_tensors="pt")
43
  out = self.blip_model_control_net.generate(**encoding)
44
  generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
45
  return f"{generated_text}"