Titobsala's picture
Upload 19 files
0ab2514 verified
raw
history blame
1.46 kB
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
def test_model(input_text, model, tokenizer, max_length=128):
# Tokenize a entrada
inputs = tokenizer(input_text, return_tensors="pt", max_length=max_length, truncation=True, padding="max_length")
# Mover para GPU se disponível
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Gerar a saída com parâmetros ajustados
outputs = model.generate(
**inputs,
max_length=max_length,
num_return_sequences=1,
no_repeat_ngram_size=2,
temperature=0.7,
top_k=50,
top_p=0.95,
)
# Decodificar a saída
decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
return decoded_output
# Carregar o modelo e tokenizer salvos
model_path = "./meu_modelo_treinado"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Exemplo de uso
input_texts = ["workers", "employee", "labor"]
for input_text in input_texts:
output = test_model(input_text, model, tokenizer)
print(f"Input: {input_text}")
print(f"Output: {output}")
print("-" * 30)
# Imprimir informações sobre o modelo
print(f"Tamanho do vocabulário: {len(tokenizer)}")
print(f"Número de parâmetros do modelo: {model.num_parameters()}")