nielsr HF staff commited on
Commit
9855500
1 Parent(s): 7321f63

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +48 -0
README.md ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ tags:
4
+ - tapex
5
+ - table-question-answering
6
+ license: apache-2.0
7
+ datasets:
8
+ - wtq
9
+ inference: false
10
+ ---
11
+
12
+ TAPEX-large model fine-tuned on WTQ. This model was proposed in [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou. Original repo can be found [here](https://github.com/microsoft/Table-Pretraining).
13
+
14
+ To load it and run inference, you can do the following:
15
+
16
+ ```
17
+ from transformers import BartTokenizer, BartForSequenceClassification
18
+ import pandas as pd
19
+
20
+ tokenizer = BartTokenizer.from_pretrained("nielsr/tapex-large-finetuned-tabfact")
21
+ model = BartForSequenceClassification.from_pretrained("nielsr/tapex-large-finetuned-tabfact")
22
+
23
+ # create table
24
+ data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], 'Number of movies': ["87", "53", "69"]}
25
+ table = pd.DataFrame.from_dict(data)
26
+
27
+ # turn into dict
28
+ table_dict = {"header": list(table.columns), "rows": [list(row.values) for i,row in table.iterrows()]}
29
+
30
+ # turn into format TAPEX expects
31
+ # define the linearizer based on this code: https://github.com/microsoft/Table-Pretraining/blob/main/tapex/processor/table_linearize.py
32
+ linearizer = IndexedRowTableLinearize()
33
+ linear_table = linearizer.process_table(table_dict)
34
+
35
+ # add sentence
36
+ sentence = "George Clooney has 69 movies"
37
+ joint_input = sentence + " " + linear_table
38
+
39
+ # encode
40
+ encoding = tokenizer(joint_input, return_tensors="pt")
41
+
42
+ # forward pass
43
+ outputs = model(**encoding)
44
+
45
+ # print prediction
46
+ logits = outputs.logits
47
+ print(logits.argmax(-1))
48
+ ```