File size: 1,973 Bytes
bbaf732 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer
import pandas as pd
import os
import nltk
import string
import math
import sys
import argparse
import random
"""# Modelo T5
Importamos o modelo preadestrado
"""
"""# Corpus
#J# Leemos nuestro dataset.
"""
# Load the held-out test split; latin-1 matches the CSV's encoding.
# reset_index() keeps the original row numbers in an 'index' column.
test_split = pd.read_csv('./test-dataset.csv', encoding="latin-1").reset_index()
def generate(text):
    """Generate a textual description for *text* with the global T5 model.

    Tokenizes the input, then decodes deterministically (no sampling,
    10 beams, at most 50 new tokens) and strips special tokens.

    Relies on the module-level ``tokenizer`` and ``model`` globals.
    """
    print("Tokenizing sequence...")
    encoded = tokenizer(text, return_tensors='pt', padding=True).to(model.device)
    print("Generating description...")
    generated_ids = model.generate(
        **encoded,
        do_sample=False,
        num_beams=10,
        max_new_tokens=50,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Command-line interface: which test-set row to describe, and where to
# write the result.
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_table", type=int, default=280, required=False, help="Specify data ID")
parser.add_argument("-o", "--output", type=str, default="./", required=False, help="Specify output path")
args = parser.parse_args()

data_id = args.input_table
output_path = args.output

# Only IDs 0..568 exist in the test split.
if not 0 <= data_id <= 568:
    sys.exit("ERROR: ID must be in the range [0,568] (testing IDs)")
#J# Load the chosen pretrained model together with its tokenizer.
print("Loading model...")
model = T5ForConditionalGeneration.from_pretrained('data2text_gl_v1')
tokenizer = T5Tokenizer.from_pretrained("data2text_gl_v1")

# data_id is already an int (argparse type=int), so no int() conversion
# is needed when indexing the DataFrame columns.
print("Loading data... (dataset-id: " + str(test_split.id[data_id]) + ")")
data = test_split.table[data_id]
gold = test_split.caption[data_id]
generation = generate(data)
img_id = str(test_split.id[data_id])

pattern = "- Test ID: {} (DB id: {})\n- Data table: {}\n- Generated text: {}\n- Gold text: {}"
# Console output truncates the table to its first 100 characters.
print(pattern.format(data_id, img_id, data[0:100] + "... </table>", generation, gold))

# os.path.join handles output paths with or without a trailing slash
# (plain string concatenation silently produced a wrong filename).
# The with-block closes the file; an explicit close() was redundant.
out_path = os.path.join(output_path, "generated_" + str(data_id) + ".txt")
with open(out_path, "w") as output_file:
    output_file.write(pattern.format(data_id, img_id, data, generation, gold))
|