sooh-j commited on
Commit
6ff1d6b
1 Parent(s): 8364853

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -11,7 +11,7 @@ import base64
11
 
12
  class EndpointHandler():
13
  def __init__(self, path=""):
14
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
  self.model_base = "Salesforce/blip2-opt-2.7b"
16
  self.model_name = "sooh-j/blip2-vizwizqa"
17
  # self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)
@@ -101,7 +101,7 @@ class EndpointHandler():
101
  # )
102
 
103
  with torch.no_grad():
104
- out = self.model.generate(**processed)
105
 
106
  result = {}
107
  text_output = self.processor.decode(out[0], skip_special_tokens=True)
 
11
 
12
  class EndpointHandler():
13
  def __init__(self, path=""):
14
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
15
  self.model_base = "Salesforce/blip2-opt-2.7b"
16
  self.model_name = "sooh-j/blip2-vizwizqa"
17
  # self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)
 
101
  # )
102
 
103
  with torch.no_grad():
104
+ out = self.model.generate(**processed).to(self.device)
105
 
106
  result = {}
107
  text_output = self.processor.decode(out[0], skip_special_tokens=True)