ShaoXia's picture
初始化代码
069157b
raw
history blame
1.53 kB
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from PIL import Image
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the DePlot checkpoint with bfloat16 weights.
model_deplot = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot",
    torch_dtype=torch.bfloat16,
)
if device == "cuda":
    # Device index 0 is the first CUDA device.
    model_deplot = model_deplot.to(0)

# Processor matching the checkpoint; used below to build model inputs and to
# decode generated token ids.
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
def add_markup(table: str) -> str:
    """Label the lines of a flattened table with title/header/row prefixes.

    Every line of ``table`` is stripped of surrounding whitespace. If the
    first line starts with ``TITLE``, the text between its first and second
    ``' | '`` separators is emitted as a ``Title: ...`` line and that line is
    removed from the rows. The first remaining line is prefixed with
    ``Header: `` and each following line with ``Row N: `` (N counts from 1).

    Args:
        table: Newline-separated table text.

    Returns:
        The annotated table. When the input cannot be parsed (empty text, a
        ``TITLE`` line without a ``' | '`` separator, a non-string argument)
        ``table`` itself is returned unchanged.
    """
    try:
        parts = [line.strip() for line in table.splitlines()]
        if parts[0].startswith("TITLE"):
            result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
            rows = parts[1:]
        else:
            result = ""
            rows = parts
        # One prefix per row: the first row is the header, the rest are
        # numbered from 1. zip() yields nothing when there are no rows.
        prefixes = ["Header: "] + [f"Row {i}: " for i in range(1, len(rows))]
        return result + "\n".join(
            prefix + row for prefix, row in zip(prefixes, rows)
        )
    except Exception:
        # Deliberate best effort: fall back to the raw table if parsing
        # fails. Catching Exception (not a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        return table
def process_image(image):
    """Run the DePlot model on ``image`` and return the decoded text.

    The image is passed to ``processor_deplot`` together with a fixed prompt,
    the resulting inputs are cast to bfloat16 (and moved to device 0 when
    ``device`` is ``"cuda"``), and up to 512 new tokens are generated. The
    first generated sequence is decoded without special tokens, and every
    literal ``<0x0A>`` marker in it is replaced by a newline.

    Args:
        image: The figure to analyse; presumably a PIL image, as in the
            ``__main__`` usage — confirm against other callers.

    Returns:
        The decoded table text.
    """
    prompt = "Generate the underlying data table for the figure below:"
    encoded = processor_deplot(images=image, text=prompt, return_tensors="pt")
    encoded = encoded.to(torch.bfloat16)
    if device == "cuda":
        encoded = encoded.to(0)

    generated = model_deplot.generate(**encoded, max_new_tokens=512)
    decoded = processor_deplot.decode(generated[0], skip_special_tokens=True)
    return decoded.replace("<0x0A>", "\n")
if __name__ == "__main__":
    # Open the sample image in a context manager so the underlying file
    # handle is closed once inference has finished.
    with Image.open(r"meat-image.png") as im:
        # NOTE(review): the returned table is discarded, as in the original;
        # print or store it if this script is meant to show output.
        process_image(im)