t5m_pocet / init2.py
makprgmax
3 git ignore
969f9a3
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
# Укажите директорию с моделью и токенизатором
model_directory = './'
# Загрузка токенизатора
tokenizer = MT5Tokenizer.from_pretrained(model_directory, sp_model_kwargs={"model_file": f"{model_directory}/spiece.model"})
# Загрузка модели
model = MT5ForConditionalGeneration.from_pretrained(model_directory, local_files_only=True)
# Входной текст
# context = "Вы бот PocketBook Reader. Вы говорите только о PocketBook Reader, и вам не разрешается сравнивать PocketBook с другими брендами, такими как Kindle, Tolino, Kobo и т. д."
context = "You are a PocketBook Reader bot. You only talk about PocketBook Reader and you are not allowed to compare PocketBook with other brands like Kindle, Tolino, Kobo, etc."
#question = "Как включить PocketBook?"
question = "How to turn it on PocketBook?"
input_text = f"question: {question} context: {context}"
# Токенизация
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
# Генерация ответа с указанием параметров для контроля длины
# outputs = model.generate(input_ids, max_new_tokens=50)
outputs = model.generate(input_ids, max_new_tokens=50, num_beams=5, temperature=0.7, top_k=50)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Вывод
print(answer)
print(answer)