Update handler.py
Changed file: handler.py (+2 −2)
--- a/handler.py
+++ b/handler.py
@@ -11,7 +11,7 @@ import base64
 
 class EndpointHandler():
     def __init__(self, path=""):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
         self.model_base = "Salesforce/blip2-opt-2.7b"
         self.model_name = "sooh-j/blip2-vizwizqa"
         # self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)
@@ -101,7 +101,7 @@ class EndpointHandler():
         # )
 
         with torch.no_grad():
-            out = self.model.generate(**processed)
+            out = self.model.generate(**processed).to(self.device)
 
         result = {}
         text_output = self.processor.decode(out[0], skip_special_tokens=True)