|
from typing import Dict, List, Any |
|
from PIL import Image |
|
import requests |
|
import torch |
|
import base64 |
|
from io import BytesIO |
|
from blip import blip_decoder |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode |
|
|
|
# Pick the compute device once at import time: CUDA when a GPU is visible,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
class PreTrainedPipeline:
    """Caption images with BLIP's large captioning checkpoint.

    Construct once (the constructor downloads/loads the model), then call the
    instance with a request payload carrying a base64-encoded image.
    """

    # Checkpoint loaded when no ``model_url`` is given.
    DEFAULT_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
    # Side length (pixels) of the square model input. The same value is given to
    # the model and to the resize transform so the two can never disagree.
    DEFAULT_IMAGE_SIZE = 384
    # Generation settings passed to ``model.generate``. A request may override
    # any of these keys through its "parameters" mapping; other keys are ignored.
    # Treated as read-only: ``__call__`` works on a copy.
    DEFAULT_GENERATE_KWARGS = {'sample': True, 'top_p': 0.9, 'max_length': 20, 'min_length': 5}

    def __init__(self, path: str = "", model_url: str = DEFAULT_MODEL_URL,
                 image_size: int = DEFAULT_IMAGE_SIZE):
        """Load the BLIP decoder and build the image preprocessing transform.

        Args:
            path: Not used by this class. Accepted so the constructor signature
                stays compatible with existing callers (presumably the hosting
                runtime passes a model directory here; confirm).
            model_url: Checkpoint location handed to ``blip_decoder``.
            image_size: Square input resolution for both the model and the
                resize step.
        """
        self.model_url = model_url
        self.image_size = image_size
        # Keep the device on the instance so ``__call__`` moves inputs to the
        # same place the model lives.
        self.device = device

        self.model = blip_decoder(pretrained=model_url, image_size=image_size, vit='large')
        self.model.eval()
        self.model = self.model.to(self.device)

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # Per-channel (R, G, B) mean and std; requires a 3-channel tensor.
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Any) -> List[str]:
        """Generate a caption for one base64-encoded image.

        Args:
            data: Request payload, either
                ``{"inputs": {"image": <base64 str>}, "parameters": {...}}`` or,
                when there is no ``"inputs"`` key, the inputs mapping itself
                (``{"image": <base64 str>}``). ``data`` is not modified.
                The optional ``"parameters"`` dict may override ``sample``,
                ``top_p``, ``max_length`` and ``min_length`` (defaults in
                ``DEFAULT_GENERATE_KWARGS``). Any other keys are ignored.

        Returns:
            The captions produced by BLIP's ``generate``: a list of strings
            with one entry per image in the batch, so a single element here.
            With ``sample=True`` (the default), repeated calls on the same
            image may return different captions.

        Raises:
            KeyError: if the inputs mapping has no ``"image"`` entry.
        """
        # Read with .get() instead of .pop() so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters")

        generate_kwargs = dict(self.DEFAULT_GENERATE_KWARGS)
        if isinstance(parameters, dict):
            # Only known settings are honoured, so payloads that previously carried
            # unrelated parameters keep working unchanged.
            generate_kwargs.update(
                (key, value) for key, value in parameters.items() if key in generate_kwargs
            )

        image_bytes = base64.b64decode(inputs['image'])
        # Force 3 channels: RGBA, greyscale or palette images would otherwise
        # yield a 4- or 1-channel tensor that the 3-channel Normalize rejects.
        with Image.open(BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert('RGB')

        pixel_values = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            caption = self.model.generate(pixel_values, **generate_kwargs)

        return caption
|
|