Spaces:
Build error
Build error
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering | |
import pandas as pd | |
import re | |
# Pattern for an unsigned decimal number: digits with an optional fractional
# part. Written as a raw string so the backslashes reach the regex engine
# untouched instead of being parsed as (invalid) string escapes.
p = re.compile(r'\d+(\.\d+)?')
def load_model_and_tokenizer(model_name="Meena/table-question-answering-tapas"):
    """
    Load the tokenizer and the table question answering model.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the tokenizer and the model. Defaults to the checkpoint this
            module was originally hard-wired to, so existing callers that pass
            no argument are unaffected.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tokenizer and model are loaded from the same checkpoint name.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTableQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model
def prepare_inputs(table, queries, tokenizer):
    """
    Stringify the first 100 rows of the table and tokenize them with the queries.

    Args:
        table: Table object supporting ``head`` and ``astype`` (the rest of
            this module indexes it with ``.iat``).
        queries: Questions to ask about the table.
        tokenizer: Callable accepting ``table``, ``queries``, ``padding`` and
            ``return_tensors`` keyword arguments.

    Returns:
        A ``(table, inputs)`` tuple: the truncated, string-typed table and
        whatever the tokenizer returned for it.
    """
    # Truncate before converting so only the rows that are kept get stringified;
    # the result is identical to converting first and truncating afterwards.
    table = table.head(100).astype('str')
    inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
    return table, inputs
def generate_predictions(inputs, model, tokenizer):
    """
    Run the model on tokenized inputs and decode its logits.

    Args:
        inputs: Mapping of tokenized inputs; unpacked as keyword arguments
            into the model call.
        model: Callable whose result exposes ``logits`` and
            ``logits_aggregation``.
        tokenizer: Object providing ``convert_logits_to_predictions``.

    Returns:
        A ``(predicted_table_cell_coords, predicted_aggregation_operators)``
        tuple as produced by the tokenizer.
    """
    model_output = model(**inputs)

    # Both logit sets are detached before being handed to the tokenizer.
    cell_logits = model_output.logits.detach()
    aggregation_logits = model_output.logits_aggregation.detach()

    cell_coords, aggregation_ids = tokenizer.convert_logits_to_predictions(
        inputs, cell_logits, aggregation_logits
    )
    return cell_coords, aggregation_ids
def postprocess_predictions(predicted_aggregation_operators, predicted_table_cell_coords, table):
    """
    Translate aggregation indices to names and cell coordinates to cell text.

    Args:
        predicted_aggregation_operators: One aggregation index (0-3) per query.
        predicted_table_cell_coords: One list of cell coordinates per query.
        table: Table indexed through ``table.iat[coordinate]``.

    Returns:
        A ``(aggregation_predictions_string, answers)`` tuple. Each answer is
        the single selected cell's value, or the selected cells' values joined
        with ``", "`` when the number of selected cells is not exactly one.
    """
    operator_names = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}
    aggregation_predictions_string = [
        operator_names[operator_id] for operator_id in predicted_aggregation_operators
    ]

    answers = []
    # Pairing with the operators bounds the loop by the shorter of the two inputs.
    for _, cell_coords in zip(predicted_aggregation_operators, predicted_table_cell_coords):
        if len(cell_coords) == 1:
            # Exactly one cell: keep its value as-is.
            answer = table.iat[cell_coords[0]]
        else:
            # Zero or several cells: comma-separated list of their values.
            answer = ", ".join([table.iat[coord] for coord in cell_coords])
        answers.append(answer)

    return aggregation_predictions_string, answers
# A cell counts as numeric only when the WHOLE cell is an optionally signed
# decimal, so values such as "12abc" or "1,000" are never passed to float().
_NUMERIC_CELL = re.compile(r"[-+]?\d+(\.\d+)?")


def show_answers(queries, answers, aggregation_predictions_string):
    """
    Print each query and return one final answer string per query.

    Args:
        queries: The questions that were asked.
        answers: One answer string per query; several cells are separated
            by ``", "``.
        aggregation_predictions_string: One of ``"NONE"``, ``"SUM"``,
            ``"AVERAGE"`` or ``"COUNT"`` per query.

    Returns:
        A list with one string per query:
        * ``NONE``: the answer itself (also printed).
        * ``COUNT``: the number of selected cells (``"0"`` for an empty answer).
        * ``SUM`` / ``AVERAGE``: the aggregated value when every selected cell
          is numeric, otherwise ``"<OPERATOR> > <answer>"`` as a textual
          fallback.
    """
    numeric_aggregations = {
        "SUM": sum,
        "AVERAGE": lambda numbers: sum(numbers) / len(numbers),
    }
    results = []
    for query, answer, predicted_agg in zip(queries, answers, aggregation_predictions_string):
        print(query)
        if predicted_agg == "NONE":
            print("Predicted answer: " + answer)
            # A plain lookup is a result too; report it instead of dropping it.
            results.append(answer)
            continue

        # Split with the same separator that was used to join the cells; an
        # empty answer means no cell was selected at all.
        cells = answer.split(", ") if answer else []
        if predicted_agg == "COUNT":
            result = str(len(cells))
        elif cells and all(_NUMERIC_CELL.fullmatch(cell) for cell in cells):
            numbers = [float(cell) for cell in cells]
            result = str(numeric_aggregations[predicted_agg](numbers))
        else:
            # Non-numeric (or empty) selection: describe instead of computing.
            result = predicted_agg + " > " + answer
        results.append(result)
    return results
def execute_query(query, table):
    """
    Answer a single question about a table end to end.

    Loads the tokenizer and model, tokenizes the table together with the
    question, runs the model, post-processes the predictions and returns the
    list produced by ``show_answers`` for the one-element query batch.

    Args:
        query: The question to ask.
        table: The table to query; forwarded to ``prepare_inputs``.

    Returns:
        The result of ``show_answers`` for ``[query]``.
    """
    tokenizer, model = load_model_and_tokenizer()

    # The downstream helpers operate on batches, so wrap the single question.
    question_batch = [query]
    prepared_table, encoded = prepare_inputs(table, question_batch, tokenizer)
    cell_coords, operator_ids = generate_predictions(encoded, model, tokenizer)
    operator_names, answers = postprocess_predictions(operator_ids, cell_coords, prepared_table)
    return show_answers(question_batch, answers, operator_names)