# handler.py: custom handler for facebook/nougat-base on Hugging Face
# Dedicated Inference Endpoints.
import base64
import io
from collections import defaultdict
from typing import Any, Dict

import torch
from PIL import Image
from transformers import (
    NougatProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    VisionEncoderDecoderModel,
)


class RunningVarTorch:
    """Tracks the per-element variance of the last `L` pushed vectors (a sliding window)."""

    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            # Window not full yet: append the new column.
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            # Window full: drop the oldest column, append the new one.
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)
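
# A minimal usage sketch (illustration only, not used by the handler): pushing
# more than `L` vectors keeps only the most recent `L` in the window.
#
#     rv = RunningVarTorch(L=3)
#     for step_scores in (torch.ones(5), torch.zeros(5), torch.ones(5), torch.zeros(5)):
#         rv.push(step_scores)
#     rv.variance()  # per-element variance over the last 3 pushed vectors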


class StoppingCriteriaScores(StoppingCriteria):
    """Stops generation once the variance of the max token scores flattens out,
    a signal that the model has fallen into repetition (heuristic defined by
    the Nougat authors)."""

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    # A stop step is already scheduled for this batch element.
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    # Schedule a stop step a bit past the current one (capped at 4095).
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                # Variance recovered: cancel any scheduled stop.
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0
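
# The criterion is wired into `model.generate` below via
# `stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()])`; it
# relies on `output_scores=True` and `return_dict_in_generate=True` so that
# per-step scores are available to `__call__`.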


class EndpointHandler:
    def __init__(self, path="facebook/nougat-base"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict): The payload with the base64-encoded image under
                "inputs", optional generation parameters under "parameters",
                and an optional "fix_markdown" flag for post-processing.
        """
        # Get inputs
        inputs = data.pop("inputs", None)
        parameters = data.pop("parameters", None) or {}
        fix_markdown = data.pop("fix_markdown", None)
        if inputs is None:
            raise ValueError("Missing image.")

        # Decode the base64 payload into a PIL image and preprocess it.
        binary_data = base64.b64decode(inputs)
        image = Image.open(io.BytesIO(binary_data))
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # Autoregressively generate tokens, with custom stopping criteria
        # (as defined by the Nougat authors).
        outputs = self.model.generate(
            pixel_values=pixel_values.to(self.model.device),
            min_length=1,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
            **parameters,
        )
        generated = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        prediction = self.processor.post_process_generation(generated, fix_markdown=fix_markdown)
        return {"generated_text": prediction}