# handler.py
import torch
from transformers import pipeline, AutoProcessor, Blip2ForConditionalGeneration
import os
"""import base64
from io import BytesIO
from PIL import Image"""
# device index for the translation pipeline: 0 = first GPU, -1 = CPU
device = 0 if torch.cuda.is_available() else -1
class EndpointHandler():
    def __init__(self, path=""):
        # processor handles image preprocessing and caption decoding
        self.blip2_proc = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        # load the sharded BLIP-2 checkpoint in 8-bit, spread across available devices
        self.blip2 = Blip2ForConditionalGeneration.from_pretrained(
            os.path.join(path, "sharded"), device_map="auto", load_in_8bit=True
        )
        # English -> German translator, used when the request asks for a German caption
        self.translator = pipeline("translation", model="Helsinki-NLP/opus-mt-en-de", device=device)
def __call__(self, data):
        # deserialize incoming request
        b64_img = data.pop("b64", data)
        lang = data.pop("lang", None)
        decode = data.pop("decode", None)

        # prepare image: decode the base64 payload into an RGB PIL image
        im_bytes = base64.b64decode(b64_img)  # raw image bytes
        im_file = BytesIO(im_bytes)           # file-like object
        image = Image.open(im_file).convert("RGB")

        output = {}
        # move pixel values to the model's device and cast to float16
        inputs = self.blip2_proc(image, return_tensors="pt").to(self.blip2.device, torch.float16)

        # nucleus vs beam sampling
        if decode is None or decode == "beam":
            generated_ids = self.blip2.generate(**inputs, max_new_tokens=20)
            prediction = self.blip2_proc.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            # English vs German caption
            if lang == "de":
                translation = self.translator(prediction)
                output["beam"] = translation[0]
            else:
                output["beam"] = prediction

        if decode == "nucleus":
            # do_sample belongs to generate(), not batch_decode()
            generated_ids = self.blip2.generate(**inputs, max_new_tokens=20, do_sample=True)
            prediction = self.blip2_proc.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            # English vs German caption
            if lang == "de":
                translation = self.translator(prediction)
                output["nucleus"] = translation[0]
            else:
                output["nucleus"] = prediction

        # postprocess the prediction
        return output
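

# A minimal local smoke test, assuming the serving toolkit instantiates
# EndpointHandler and passes the parsed JSON payload to __call__; the image
# file "cat.jpg" and the local "sharded" checkpoint folder are assumptions
# for illustration only.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    with open("cat.jpg", "rb") as f:
        payload = {
            "b64": base64.b64encode(f.read()).decode("utf-8"),
            "lang": "de",         # ask for a German caption
            "decode": "nucleus",  # nucleus sampling instead of beam search
        }
    print(handler(payload))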