File size: 1,532 Bytes
069157b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from PIL import Image

# Run on the GPU when CUDA is available; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the "google/deplot" checkpoint with bfloat16 weights. When CUDA is
# available the model is moved to CUDA device index 0.
model_deplot = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot",
    torch_dtype=torch.bfloat16,
)
if device == "cuda":
    model_deplot = model_deplot.to(0)

# Matching processor for the same checkpoint (used to build model inputs and
# to decode generated token ids).
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")



def add_markup(table):
    """Prefix each line of a linearized table with a human-readable label.

    The table is split into lines and each line is stripped.  If the first
    line starts with ``TITLE``, the text after its first ``' | '`` separator
    becomes a ``Title: ...`` line and the remaining lines are treated as
    rows.  The first row is prefixed with ``Header: `` and every following
    row with ``Row N: `` (N starting at 1).

    Args:
        table: The table text, one row per line.

    Returns:
        The labelled table as a single string.  If the input cannot be
        parsed (empty input, a ``TITLE`` line without a ``' | '`` separator,
        or a value that is not a string), ``table`` is returned unchanged.
    """
    try:
        parts = [p.strip() for p in table.splitlines()]
        if parts[0].startswith('TITLE'):
            result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
            rows = parts[1:]
        else:
            result = ''
            rows = parts
        # One 'Header: ' prefix for the first row, then 'Row 1: ', 'Row 2: ', ...
        prefixes = ['Header: '] + [f'Row {i+1}: ' for i in range(len(rows) - 1)]
        return result + '\n'.join(prefix + row for prefix, row in zip(prefixes, rows))
    except (AttributeError, IndexError, TypeError):
        # Best effort: fall back to the raw table if parsing fails.  Only
        # parsing-related errors are caught so that KeyboardInterrupt,
        # SystemExit and unrelated bugs are not silently swallowed.
        return table

def process_image(image):
    """Generate the data-table text for a figure image.

    Uses the module-level ``processor_deplot`` and ``model_deplot``.  The
    processor output is cast to ``torch.bfloat16`` and, when ``device`` is
    ``"cuda"``, moved to CUDA device index 0 before generation.

    Args:
        image: The image passed to the processor as ``images``.

    Returns:
        The decoded text of the first generated sequence, with every
        ``"<0x0A>"`` marker replaced by a newline character.
    """
    prompt = "Generate the underlying data table for the figure below:"
    encoded = processor_deplot(images=image, text=prompt, return_tensors="pt")
    encoded = encoded.to(torch.bfloat16)
    if device == "cuda":
        encoded = encoded.to(0)

    generated = model_deplot.generate(**encoded, max_new_tokens=512)
    decoded = processor_deplot.decode(generated[0], skip_special_tokens=True)
    # The decoded text carries line breaks as the literal "<0x0A>" marker.
    return decoded.replace("<0x0A>", "\n")


if __name__ == "__main__":
    # Manual smoke run on a local sample image.  The context manager closes
    # the underlying file handle once the image has been processed, and the
    # extracted table is printed so the run produces visible output instead
    # of discarding the result.
    with Image.open(r"meat-image.png") as im:
        print(process_image(im))