File size: 7,132 Bytes
ae2c89d
fdac8dd
ae2c89d
800ade2
f9e72d0
 
 
800ade2
f9e72d0
ae2c89d
f67adcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac0b0ba
 
 
f67adcd
 
 
 
 
 
 
 
 
 
 
 
 
8fc7477
f9e72d0
8fc7477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
000d979
8fc7477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
000d979
ac0b0ba
f67adcd
 
 
 
e9f0c2c
 
ae2c89d
 
 
f67adcd
4a35755
6dc09e2
f67adcd
 
6dc09e2
7eb1431
6dc09e2
ae2c89d
6dc09e2
6663c9a
ae2c89d
 
 
 
6dc09e2
 
 
 
 
 
 
 
 
 
c619767
 
ae2c89d
 
925ce8d
c619767
 
 
0542a87
ae2c89d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
import requests
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import os 
import torch
from peft import PeftModel
import transformers

## CoT prompts

def _add_markup(table):
    parts = [p.strip() for p in table.splitlines(keepends=False)]
    if parts[0].startswith('TITLE'):
        result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
        rows = parts[1:]
    else:
        result = ''
        rows = parts
    prefixes = ['Header: '] + [f'Row {i+1}: ' for i in range(len(rows) - 1)]
    return result + '\n'.join(prefix + row for prefix, row in zip(prefixes, rows))


# Example table used inside the few-shot prompt: one header line followed by
# seven data lines, columns separated by ' | '.
_TABLE = """Year | Democrats | Republicans | Independents
2004 | 68.1% | 45.0% | 53.0%
2006 | 58.0% | 42.0% | 53.0%
2007 | 59.0% | 38.0% | 45.0%
2009 | 72.0% | 49.0% | 60.0%
2011 | 71.0% | 51.2% | 58.0%
2012 | 70.0% | 48.0% | 53.0%
2013 | 72.0% | 41.0% | 60.0%"""

# Instruction sentence that is placed in front of every table in the prompt.
_INSTRUCTION = 'Read the table below to answer the following questions.'


# Few-shot prompt prefix, built once at import time (the f-string calls
# _add_markup(_TABLE) when this module is loaded). It contains the example
# table with five worked question/answer pairs and ends with a second copy of
# _INSTRUCTION, so that a new labelled table and question can be appended.
# NOTE(review): the example answers contain typos ("they [45.0, ...",
# "favor dates", "Demoncrats", "that 3 rows"). They are part of the text sent
# to the model, so they are left untouched here -- confirm before editing.
_TEMPLATE = f"""First read an example then the complete question for the second table.
------------
{_INSTRUCTION}
{_add_markup(_TABLE)}
Q: In which year republicans have the lowest favor rate?
A: Let's find the column of republicans. Then let's extract the favor rates, they [45.0, 42.0, 38.0, 49.0, 51.2, 48.0, 41.0]. The smallest number is 38.0, that's Row 3.  Row 3 is year 2007. The answer is 2007.
Q: What is the sum of Democrats' favor rates of 2004, 2012, and 2013?
A: Let's find the rows of years 2004, 2012, and 2013. We find Row 1, 6, 7. The favor dates of Demoncrats on that 3 rows are 68.1, 70.0, and 72.0. 68.1+70.0+72=210.1. The answer is 210.1.
Q: By how many points do Independents surpass Republicans in the year of 2011?
A: Let's find the row with year = 2011. We find Row 5. We extract Independents and Republicans' numbers. They are 58.0 and 51.2. 58.0-51.2=6.8. The answer is 6.8.
Q: Which group has the overall worst performance?
A: Let's sample a couple of years. In Row 1, year 2004, we find Republicans having the lowest favor rate 45.0 (since 45.0<68.1, 45.0<53.0). In year 2006, Row 2, we find Republicans having the lowest favor rate 42.0 (42.0<58.0, 42.0<53.0). The trend continues to other years. The answer is Republicans.
Q: Which party has the second highest favor rates in 2007?
A: Let's find the row of year 2007, that's Row 3. Let's extract the numbers on Row 3: [59.0, 38.0, 45.0]. 45.0 is the second highest. 45.0 is the number of Independents. The answer is Independents.
{_INSTRUCTION}"""


## alpaca-lora

# The LLaMA classes are only present in sufficiently recent transformers
# builds. Attempt the import directly and turn a failure into an actionable
# error: this avoids indexing the private `transformers._import_structure`
# mapping, and unlike an `assert` it is not stripped when Python runs with -O.
try:
    from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
except ImportError as err:
    raise ImportError(
        "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
    ) from err

# Hub identifiers of the base LLaMA checkpoint and of the LoRA adapter that is
# applied on top of it.
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"

# The tokenizer is loaded from the same checkpoint as the base model, so reuse
# the constant instead of repeating the literal.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Pick the compute device: CUDA when available, otherwise CPU ...
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ... and switch to Apple's MPS backend whenever it reports itself available.
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # torch builds without `torch.backends.mps` land here; keep the device
    # chosen above. Only this expected failure is swallowed, not every error.
    pass

# Per-device keyword arguments for loading the base model and the LoRA adapter.
if device == "cuda":
    _base_kwargs = dict(
        load_in_8bit=False,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    _lora_kwargs = dict(torch_dtype=torch.float16, force_download=True)
elif device == "mps":
    _base_kwargs = dict(device_map={"": device}, torch_dtype=torch.float16)
    _lora_kwargs = dict(device_map={"": device}, torch_dtype=torch.float16)
else:
    _base_kwargs = dict(device_map={"": device}, low_cpu_mem_usage=True)
    _lora_kwargs = dict(device_map={"": device})

# Load the base LLaMA weights, then wrap them with the LoRA adapter.
model = LlamaForCausalLM.from_pretrained(BASE_MODEL, **_base_kwargs)
model = PeftModel.from_pretrained(model, LORA_WEIGHTS, **_lora_kwargs)

# Half precision everywhere except on CPU.
if device != "cpu":
    model.half()
# Inference only: switch to evaluation mode.
model.eval()
# Compile the model when the installed torch reports version 2 or later.
if torch.__version__ >= "2":
    model = torch.compile(model)


def evaluate(
    table,
    question,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    **kwargs,
):
    """Ask the language model a question about a text table.

    The prompt is the few-shot template, the labelled table, the question and
    a trailing ``A:`` cue, joined by newlines.

    Args:
        table: Table text, passed through ``_add_markup`` before prompting.
        question: The question (or instruction) to put after ``Q: ``.
        input: Accepted but not used by this function.
        temperature, top_p, top_k, num_beams: Forwarded to ``GenerationConfig``.
        max_new_tokens: Forwarded to ``model.generate``.
        **kwargs: Extra options forwarded to ``GenerationConfig``.

    Returns:
        The decoded text of the first generated sequence, without further
        post-processing. NOTE(review): for a causal LM this presumably still
        contains the prompt -- confirm against the caller's ``split("A:")``.
    """
    prompt_lines = [_TEMPLATE, _add_markup(table), "Q: " + question, "A:"]
    prompt = "\n".join(prompt_lines)

    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_ids = encoded["input_ids"].to(device)

    gen_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )

    # Pure inference: no gradients are needed.
    with torch.no_grad():
        generated = model.generate(
            input_ids=prompt_ids,
            generation_config=gen_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )

    first_sequence = generated.sequences[0]
    return tokenizer.decode(first_sequence)




# DePlot model and its processor, both loaded from the same checkpoint.
_DEPLOT_CHECKPOINT = "google/deplot"
model_deplot = Pix2StructForConditionalGeneration.from_pretrained(_DEPLOT_CHECKPOINT)
processor_deplot = Pix2StructProcessor.from_pretrained(_DEPLOT_CHECKPOINT)

def process_document(image, question):
    """Turn a chart image into a table and answer a question about it.

    Args:
        image: The chart image, handed to the DePlot processor as ``images``.
        question: The question or instruction passed on to ``evaluate``.

    Returns:
        A two-element list: the table text produced by DePlot, and the part of
        the language-model output that follows the last ``"A:"``.
    """
    # Step 1: chart -> table with DePlot.
    deplot_inputs = processor_deplot(
        images=image,
        text="Generate the underlying data table for the figure below:",
        return_tensors="pt",
    )
    generated = model_deplot.generate(**deplot_inputs, max_new_tokens=512)
    decoded = processor_deplot.decode(generated[0], skip_special_tokens=True)
    # The decoded text uses the literal token "<0x0A>" where line breaks belong.
    table = decoded.replace("<0x0A>", "\n")

    # Step 2: prompt + table -> language model.
    llm_output = evaluate(table, question)
    answer = llm_output.split("A:")[-1]
    return [table, answer]
 
# Texts shown on the demo page: `description` and `article` are passed to the
# interface below under the parameters of the same name.
description = "Demo for DePlot+LLM for QA and summarisation. To use it, simply upload your image and type a question or instruction and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot: One-shot visual language reasoning by plot-to-table translation</a></p>"

# Gradio UI: an image plus a text question in, the intermediate table and the
# model's answer out (matching the two-element list from process_document).
# NOTE(review): `gr.inputs.Textbox` is used for the *output* components and
# `enable_queue` is passed to Interface; both are kept as in the original --
# confirm they are supported by the pinned gradio version.
demo = gr.Interface(
    fn=process_document,
    inputs=["image", "text"],
    outputs=[
        gr.inputs.Textbox(
            lines=8,
            label="Intermediate Table",
        ),
        gr.inputs.Textbox(
            lines=5,
            label="Output",
        ),
    ],
    title="DePlot+LLM (Multimodal chain-of-thought reasoning on plots)",
    description=description,
    article=article,
    enable_queue=True,
    # Each example is [image path, question]. The original was missing the
    # comma after the second entry, which made Python subscript one list with
    # the next and raise TypeError while building this list.
    examples=[
        ["deplot_case_study_m1.png", "What is the sum of numbers of Indonesia and Ireland? Remember to think step by step."],
        ["deplot_case_study_m1.png", "Summarise the chart for me please."],
        ["deplot_case_study_3.png", "By how much did China's growth rate drop? Think step by step."],
        ["deplot_case_study_4.png", "How many papers are submitted in 2020?"],
        ["deplot_case_study_x2.png", "Summarise the chart for me please."],
    ],
    cache_examples=False,
)

demo.launch()