# gti-coco-en / pipeline.py
import os
from typing import Dict, List, Any
from PIL import Image
import jax
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel, VisionEncoderDecoderModel
import torch


class PreTrainedPipeline:
    def __init__(self, path=""):
        model_dir = path
        # Flax variant kept for reference:
        # self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_dir)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Beam search settings; sequence scores are returned so the caller can report a confidence.
        max_length = 16
        num_beams = 4
        # self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
        self.gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "return_dict_in_generate": True,
            "output_scores": True,
        }

        self.model.to("cpu")
        self.model.eval()

        # @jax.jit
        def _generate(pixel_values):
            with torch.no_grad():
                outputs = self.model.generate(pixel_values, **self.gen_kwargs)
            output_ids = outputs.sequences
            sequences_scores = outputs.sequences_scores
            return output_ids, sequences_scores

        self.generate = _generate

        # Warm-up pass on the bundled validation image (originally meant to trigger jax.jit compilation).
        image_path = os.path.join(path, "val_000000039769.jpg")
        image = Image.open(image_path)
        self(image)
        image.close()

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """
        Args:
            inputs (PIL.Image.Image): the image to caption.
        Return:
            A one-element list of {"label": <decoded caption>, "score": <beam sequence score>}.
        """
        # pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
        pixel_values = self.feature_extractor(images=inputs, return_tensors="pt").pixel_values
        output_ids, sequences_scores = self.generate(pixel_values)
        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        preds = [pred.strip() for pred in preds]
        preds = [{"label": preds[0], "score": float(sequences_scores[0])}]
        return preds