Spaces:
Runtime error
Runtime error
import torch | |
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
from PIL import Image | |
# Use the GPU when torch reports CUDA as available; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the "google/deplot" checkpoint with bfloat16 weights.
model_deplot = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot",
    torch_dtype=torch.bfloat16,
)
if device == "cuda":
    # Move the model to CUDA device index 0.
    model_deplot = model_deplot.to(0)

# Processor matching the same checkpoint; used to build model inputs and
# to decode generated token ids back to text.
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
def add_markup(table):
    """Prefix each line of a linearized table with a human-readable label.

    Every line of ``table`` is stripped of surrounding whitespace. If the
    first line starts with ``TITLE``, the text between the first and second
    ``' | '`` separators becomes a leading ``Title: ...`` line and is
    removed from the rows. The first remaining line is labelled
    ``Header: `` and the following lines ``Row 1: ``, ``Row 2: ``, ...

    Args:
        table: Newline-separated table text with ``' | '`` between cells.

    Returns:
        The labelled table, or ``table`` unchanged when it cannot be parsed
        (for example empty input, a non-string value, or a ``TITLE`` line
        without a ``' | '`` separator).
    """
    try:
        parts = [p.strip() for p in table.splitlines(keepends=False)]
        if parts[0].startswith('TITLE'):
            result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
            rows = parts[1:]
        else:
            result = ''
            rows = parts
        prefixes = ['Header: '] + [f'Row {i+1}: ' for i in range(len(rows) - 1)]
        return result + '\n'.join(prefix + row for prefix, row in zip(prefixes, rows))
    except Exception:
        # Best effort: fall back to the raw table if parsing fails.
        # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt and
        # SystemExit still propagate to the caller.
        return table
def process_image(image):
    """Run the DePlot model on ``image`` and return the decoded table text.

    The image and a fixed instruction prompt are turned into model inputs by
    ``processor_deplot``, cast to bfloat16 and, when ``device`` is
    ``"cuda"``, moved to CUDA device 0. Up to 512 new tokens are generated,
    and the first generated sequence is decoded with special tokens skipped.

    Args:
        image: Image passed to the processor as ``images=`` (the script
            passes a PIL image).

    Returns:
        The decoded text with every literal ``<0x0A>`` replaced by a newline.
    """
    prompt = "Generate the underlying data table for the figure below:"
    batch = processor_deplot(images=image, text=prompt, return_tensors="pt")
    batch = batch.to(torch.bfloat16)
    if device == "cuda":
        batch = batch.to(0)

    generated = model_deplot.generate(**batch, max_new_tokens=512)
    decoded = processor_deplot.decode(generated[0], skip_special_tokens=True)
    # "<0x0A>" is presumably how the tokenizer spells a newline byte; turn it
    # back into real line breaks.
    return decoded.replace("<0x0A>", "\n")
if __name__ == "__main__":
    # Open the sample image in a context manager so the underlying file
    # handle is closed once the table has been extracted, and print the
    # result instead of discarding it.
    with Image.open(r"meat-image.png") as im:
        print(process_image(im))