Zhibinhong commited on
Commit
ee61ccb
1 Parent(s): f856131

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -26
handler.py CHANGED
@@ -8,39 +8,20 @@ import json
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
11
- # Preload all the elements you are going to need at inference.
12
- # pseudo:
13
- # self.model= load_model(path)
14
  self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
15
  self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda")
16
- #是否能够只传handler一个文件?答案是可以的。。。
17
  def __call__(self, data):
18
- """
19
- data args:
20
- inputs (:obj: `str` | `PIL.Image` | `np.array`)
21
- kwargs
22
- Return:
23
- A :obj:`list` | `dict`: will be serialized and returned
24
- """
25
- inputs=data.pop("inputs",data)
26
- inputs=base64.b64decode(inputs)
27
- raw_images = Image.open(BytesIO(inputs))
28
- # raw_image = Image.open(image_path).convert('RGB')
29
-
30
- # # conditional image captioning
31
- # text = "a photography of"
32
- # inputs = self.processor(raw_image, text, return_tensors="pt").to("cuda", torch.float16)
33
-
34
- # out = self.model.generate(**inputs)
35
- # print(self.processor.decode(out[0], skip_special_tokens=True))
36
- # >>> a photography of a woman and her dog
37
-
38
- # unconditional image captioning
39
  inputs = self.processor(raw_images, return_tensors="pt").to("cuda")
40
 
41
  out = self.model.generate(**inputs)
42
  # print(self.processor.decode(out[0], skip_special_tokens=True))
43
- return json.dumps({'text':self.processor.decode(out[0], skip_special_tokens=True)})
44
 
45
  if __name__=="__main__":
46
  my_handler=EndpointHandler(path='.')
 
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
 
 
 
11
  self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
12
  self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda")
13
+
14
  def __call__(self, data):
15
+ info=data['inputs']
16
+ img=info.pop('image',data)
17
+ image_bytes=base64.b64decode(img)
18
+ raw_images = Image.open(BytesIO(image_bytes))
19
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  inputs = self.processor(raw_images, return_tensors="pt").to("cuda")
21
 
22
  out = self.model.generate(**inputs)
23
  # print(self.processor.decode(out[0], skip_special_tokens=True))
24
+ return {'text':self.processor.decode(out[0], skip_special_tokens=True)}
25
 
26
  if __name__=="__main__":
27
  my_handler=EndpointHandler(path='.')