|
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer |
|
import pandas as pd |
|
import os |
|
import nltk |
|
import string |
|
import math |
|
import sys |
|
import argparse |
|
import random |
|
|
|
|
|
"""# Modelo T5 |
|
Importamos o modelo preadestrado |
|
""" |
|
|
|
"""# Corpus |
|
Leemos nuestro dataset.
|
""" |
|
|
|
# Load the held-out test split; latin-1 decoding matches how the CSV was
# exported.  reset_index() gives the rows a clean 0..N-1 positional index.
test_split = pd.read_csv('./test-dataset.csv', encoding="latin-1").reset_index()
|
|
|
def generate(text):
    """Generate a description for *text* with the global T5 model.

    Tokenizes the input, runs deterministic beam-search decoding
    (10 beams, up to 50 new tokens) and returns the decoded string
    with special tokens stripped.
    """
    print("Tokenizing sequence...")
    encoded = tokenizer(text, return_tensors='pt', padding=True).to(model.device)
    print("Generating description...")
    output_ids = model.generate(**encoded, do_sample=False, num_beams=10, max_new_tokens=50)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
|
|
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_table", type=int, default=280, required=False, help="Specify data ID")
parser.add_argument("-o", "--output", type=str, default="./", required=False, help="Specify output path")
args = parser.parse_args()

data_id = args.input_table       # row index into the test split (already int via argparse)
output_path = args.output        # directory where the generated text is written

# Validate against the actual dataset size instead of a hard-coded 569, so the
# check (and the error message) stay correct if the test split ever changes.
if not 0 <= data_id < len(test_split):
    sys.exit("ERROR: ID must be in the range [0,{}] (testing IDs)".format(len(test_split) - 1))
|
|
|
|
|
# Load the fine-tuned Galician data-to-text checkpoint and its tokenizer
# (same local directory name for both).
print("Loading model...")
tokenizer = T5Tokenizer.from_pretrained("data2text_gl_v1")
model = T5ForConditionalGeneration.from_pretrained('data2text_gl_v1')
|
|
|
# data_id is already an int (argparse type=int), so the int() casts the
# original code repeated on every lookup are unnecessary.  The database id
# is looked up once and reused for the log line and the report below.
img_id = str(test_split.id[data_id])
print("Loading data... (dataset-id: " + img_id + ")")

data = test_split.table[data_id]       # serialized input table for the model
gold = test_split.caption[data_id]     # human-written reference caption
generation = generate(data)            # model-generated description
|
|
|
# Report template shared by the console preview and the output file.
pattern = "- Test ID: {} (DB id: {})\n- Data table: {}\n- Generated text: {}\n- Gold text: {}"

# Console preview truncates the (potentially long) table to its first 100 chars.
print(pattern.format(data_id, img_id, data[0:100] + "... </table>", generation, gold))

# os.path.join works whether or not output_path ends with a separator (plain
# "+" concatenation produced e.g. "outgenerated_280.txt" for "-o out").  The
# `with` block closes the file automatically — the original's explicit
# close() inside the block was redundant.
with open(os.path.join(output_path, "generated_" + str(data_id) + ".txt"), "w") as output_file:
    output_file.write(pattern.format(data_id, img_id, data, generation, gold))
|
|