Spaces:
Runtime error
Runtime error
import gradio | |
import transformers | |
import types | |
# Locations used by this script: the pretrained checkpoint to load and the
# examples handed to the Gradio interface.
checkpoint_path = "checkpoint"
examples_path = "examples"

# Single module-level holder for the loaded processor, model and tokenizer.
# The processor is loaded first, then the encoder-decoder model, both from
# the same checkpoint.
MODEL = types.SimpleNamespace(
    donut_processor=transformers.DonutProcessor.from_pretrained(
        checkpoint_path
    ),
    encoder_decoder=transformers.VisionEncoderDecoderModel.from_pretrained(
        checkpoint_path
    ),
)
# Convenience alias: the tokenizer is the one owned by the processor.
MODEL.tokenizer = MODEL.donut_processor.tokenizer
def generate_token_strings(images, skip_special_tokens=True) -> list[str]:
    """Generate sequences for ``images`` and decode them to strings.

    Args:
        images: Input forwarded unchanged to the model's ``generate`` call
            (``predict_string`` passes the processor's ``pixel_values``).
        skip_special_tokens: Forwarded to the tokenizer's ``batch_decode``.

    Returns:
        The result of ``batch_decode`` applied to the generated sequences.
    """
    model = MODEL.encoder_decoder
    tokenizer = MODEL.tokenizer

    generated = model.generate(
        images,
        # Length cap is read from the decoder section of the model config.
        max_length=model.config.decoder.max_length,
        eos_token_id=tokenizer.eos_token_id,
        # Needed so the output exposes a ``.sequences`` attribute.
        return_dict_in_generate=True,
    )
    sequences = generated.sequences
    return tokenizer.batch_decode(
        sequences, skip_special_tokens=skip_special_tokens
    )
def predict_string(image) -> str:
    """Return the generated text for a single image.

    Args:
        image: The image supplied by the caller (here, the Gradio "image"
            input); it is handed directly to the Donut processor.

    Returns:
        The first string produced by ``generate_token_strings`` for the
        processed image.
    """
    processed = MODEL.donut_processor(
        image,
        random_padding=False,
        return_tensors="pt",
    )
    # Only one image is processed, so only the first decoded string is used.
    return generate_token_strings(processed.pixel_values)[0]
# Build the web UI: one image input, one text output, backed by
# predict_string, with examples taken from examples_path.
interface = gradio.Interface(
    fn=predict_string,
    inputs="image",
    outputs="text",
    examples=examples_path,
    title="Making graphs accessible",
    description=(
        "Generate textual representation of a graph\n"
        "https://www.kaggle.com/competitions/benetech-making-graphs-accessible"
    ),
)

# Start serving the interface as soon as the script is executed.
interface.launch()