|
from typing import Dict, List, Any |
|
from transformers import Pipeline |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
from PIL import Image |
|
from io import BytesIO |
|
import base64 |
|
import json |
|
|
|
class EndpointHandler():
    """Callable handler that captions a base64-encoded image with BLIP.

    The processor and model are loaded once in ``__init__``; each call
    decodes one image from the request payload and returns its caption.
    """

    # Checkpoint used for both the processor and the model.
    MODEL_ID = "Salesforce/blip-image-captioning-base"

    def __init__(self, path=""):
        """Load the BLIP processor and model onto the best available device.

        Args:
            path: Accepted for interface compatibility; not used here, the
                checkpoint is always loaded from ``MODEL_ID``.
        """
        # Local import: this file does not import torch at module level.
        import torch

        # Use the GPU when there is one, but stay usable on CPU-only hosts
        # instead of failing at load time.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = BlipProcessor.from_pretrained(self.MODEL_ID)
        self.model = BlipForConditionalGeneration.from_pretrained(
            self.MODEL_ID
        ).to(self.device)

    @staticmethod
    def _extract_image_bytes(data: Dict[str, Any]) -> bytes:
        """Return the decoded image bytes carried by a request payload.

        Two shapes are accepted: ``{"inputs": {"image": <base64>}}`` and
        ``{"inputs": <base64>}``. The payload is never modified.

        Raises:
            ValueError: If the payload has no ``"inputs"`` entry, the inputs
                mapping has no ``"image"`` entry, the image value is not
                ``str``/``bytes``, or the value is not valid base64.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError('request payload must be a dict with an "inputs" entry')

        inputs = data["inputs"]
        if isinstance(inputs, dict):
            if "image" not in inputs:
                raise ValueError('"inputs" must contain an "image" entry')
            # Plain lookup rather than pop(): leave the caller's dict intact.
            encoded = inputs["image"]
        else:
            encoded = inputs

        if not isinstance(encoded, (str, bytes, bytearray)):
            raise ValueError(
                "image must be a base64 string or bytes, got "
                + type(encoded).__name__
            )
        # binascii.Error (invalid base64) is a ValueError subclass.
        return base64.b64decode(encoded)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate a caption for the image in ``data``.

        Args:
            data: Request payload; see ``_extract_image_bytes`` for the
                accepted shapes.

        Returns:
            A dict ``{"text": <caption>}``.

        Raises:
            ValueError: If the payload is malformed or not valid base64.
        """
        image_bytes = self._extract_image_bytes(data)

        # Normalise to RGB so palette / alpha / grayscale files are handled
        # uniformly; harmless if the processor already converts.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        inputs = self.processor(image, return_tensors="pt").to(self.device)
        out = self.model.generate(**inputs)

        return {"text": self.processor.decode(out[0], skip_special_tokens=True)}
|
|
|
if __name__ == "__main__":
    # Manual smoke test: caption one local image through the handler.
    # The handler consumes base64-encoded image data, not a filesystem path,
    # so the file is read and encoded here before building the payload.
    import sys

    # Optional first CLI argument overrides the default sample image.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "/home/ubuntu/guoling/1.png"
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode("ascii")

    my_handler = EndpointHandler(path='.')
    test_payload = {"inputs": {"image": encoded_image}}
    test_result = my_handler(test_payload)
    print(test_result)
|
|