sooh-j committed
Commit 50da7fb
Parent: 90c4e38

Update handler.py

Files changed (1):
  handler.py  +8 -56
handler.py CHANGED
@@ -8,15 +8,13 @@ import requests
 import torch
 from io import BytesIO
 import base64
-
+
 class EndpointHandler():
     def __init__(self, path=""):
         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
         print("device:",self.device)
         self.model_base = "Salesforce/blip2-opt-2.7b"
         self.model_name = "sooh-j/blip2-vizwizqa"
-        # self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)
-
         self.processor = AutoProcessor.from_pretrained(self.model_name)
         self.model = Blip2ForConditionalGeneration.from_pretrained(self.model_name,
                                                                    device_map="auto",
@@ -37,69 +35,22 @@ class EndpointHandler():
         # image: await (await fetch('https://placekitten.com/300/300')).blob()
         # }
         # })
-        ###################
+
         inputs = data.get("inputs")
         imageBase64 = inputs.get("image")
-        # imageURL = inputs.get("image")
         question = inputs.get("question")
-        # print(imageURL)
-        # print(text)
-        # image = Image.open(requests.get(imageBase64, stream=True).raw)
-        import base64
-        from PIL import Image
-        # import matplotlib.pyplot as plt
-        #try2
-        # image = Image.open(BytesIO(base64.b64decode(imageBase64)))
-        #try1
-        image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[0].encode())))
-        ###################
-
-        ######################################
-
-        # inputs = data.pop("inputs", data)
-        # parameters = data.pop("parameters", {})
-        # # if isinstance(inputs, Image.Image):
-        # # image = [inputs]
-        # # else:
-        # # try:
-        # # imageBase64 = inputs["image"]
-        # # # image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[1].encode())))
-        # # image = Image.open(BytesIO(base64.b64decode(imageBase64)))
-
-        # # except:
-        # image_url = inputs['image']
-        # image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
-
-
-        # question = inputs["question"]
-        ######################################
-        # data = data.pop("inputs", data)
-        # data = data.pop("image", image)
 
+        # imageURL = inputs.get("image")
         # image = Image.open(requests.get(imageBase64, stream=True).raw)
-        # image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
-        #### https://huggingface.co/SlowPacer/witron-image-captioning/blob/main/handler.py
 
-        # if isinstance(inputs, Image.Image):
-        # image = [inputs]
-        # else:
-        # inputs = isinstance(inputs, str) and [inputs] or inputs
-        # image = [Image.open(BytesIO(base64.b64decode(_img))) for _img in inputs]
-
-        # processed_images = self.processor(images=raw_images, return_tensors="pt")
-        # processed_images["pixel_values"] = processed_images["pixel_values"].to(device)
-        # processed_images = {**processed_images, **parameters}
+        if 'http:' in imageBase64:
+            image = Image.open(requests.get(imageBase64, stream=True).raw)
+        else:
+            image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[0].encode())))
 
-        ####
-
-
         prompt = f"Question: {question}, Answer:"
         processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
 
-        # answer = self._generate_answer(
-        # model_path, prompt, image,
-        # )
-
         with torch.no_grad():
             out = self.model.generate(**processed, max_new_tokens=512).to(self.device)
 
@@ -107,4 +58,5 @@ class EndpointHandler():
         text_output = self.processor.decode(out[0], skip_special_tokens=True)
         result["text_output"] = text_output
         score = 0
+
         return [{"answer":text_output,"score":score}]