# Spaces:
# Build error
# Build error
import re
from functools import lru_cache
from typing import List, Tuple, Set, Dict

import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
# Numpy image type
import numpy.typing as npt
from numpy import uint8
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel

from doctrfiles import DetectionResult
# Page images accepted by run_nougat: NumPy arrays of uint8 pixel values.
ImageType = npt.NDArray[uint8]


@lru_cache(maxsize=1)
def _load_nougat(model_name: str, device: str):
    """Load the Nougat processor and model once and move the model to *device*.

    Cached so repeated ``run_nougat`` calls reuse the same weights instead of
    calling ``from_pretrained`` (and re-reading the checkpoint) every time.
    ``maxsize=1`` keeps only the most recently used (model_name, device) pair
    resident, so switching models does not accumulate copies in memory.

    Returns:
        ``(processor, model)`` -- shared objects; callers must not mutate them.
    """
    processor = NougatProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    model.to(device)
    return processor, model


def run_nougat(
    inputs: List[Tuple[int, ImageType]],
    *,
    max_new_tokens: int = 30,
    model_name: str = "facebook/nougat-base",
) -> List[DetectionResult]:
    """Transcribe page images with Nougat and wrap each result in a DetectionResult.

    Args:
        inputs: ``(index, image)`` pairs. ``index`` is copied unchanged into the
            corresponding ``DetectionResult``; ``image`` is converted with
            ``PIL.Image.fromarray`` before preprocessing.
        max_new_tokens: Upper bound on generated tokens per image. The default
            of 30 is kept for backward compatibility but truncates almost any
            real page; raise it for full-page transcription.
        model_name: Hugging Face model id (or local path) used for both the
            processor and the encoder-decoder model.

    Returns:
        One ``DetectionResult`` per input, in input order, with ``score`` fixed
        at 1 and ``text`` set to the post-processed transcription. An empty
        ``inputs`` list returns ``[]`` without loading the model.
    """
    if not inputs:
        # Nothing to transcribe: skip the (expensive) model load entirely.
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor, model = _load_nougat(model_name, device)

    detection_results: List[DetectionResult] = []
    for index, np_img in inputs:
        image = Image.fromarray(np_img)
        pixel_values = processor(image, return_tensors="pt").pixel_values
        outputs = model.generate(
            pixel_values.to(device),
            min_length=1,
            max_new_tokens=max_new_tokens,
            # Forbid the tokenizer's unknown token in the generated output.
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
        )
        # A single image goes in per call, so only the first decoded sequence exists.
        sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
        text = processor.post_process_generation(sequence, fix_markdown=False)
        detection_results.append(DetectionResult(score=1, text=text, index=index))
    return detection_results