# -*- coding: utf-8 -*- | |
"""Gradio_GPT_bot.ipynb | |
import os | |
os.system ('export TRANSFORMERS_CACHE = /my/cache/dir') | |
Automatically generated by Colaboratory. | |
Original file is located at | |
https://colab.research.google.com/drive/18CH6wtcr46hWqBqpzieH_oBOmJHecOVl | |
# Imports | |
""" | |
# Commented out IPython magic to ensure Python compatibility. | |
# %%capture | |
# # установка gradio для написания веб интерефейса | |
# # установка transformers для использования языковых моделей с платформы HF | |
#!pip install gradio transformers | |
import random | |
import time | |
from typing import List, Dict, Tuple, Union | |
#from IPython import display | |
import gradio as gr | |
import torch | |
import transformers | |
"""# Tokenizer and Model | |
**Инициализация модели** | |
Страница модели | |
https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2 | |
""" | |
from transformers import pipeline | |
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig | |
# # инициализация через pipeline | |
# model_name = "ai-forever/rugpt3medium_based_on_gpt2" | |
# pipe = pipeline("text-generation", model=model_name) | |
# sample = pipeline('test test', pad_token_id=generator.tokenizer.eos_token_id) | |
model_name = "ai-forever/rugpt3medium_based_on_gpt2" | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
tokenizer.pad_token_id = tokenizer.eos_token_id | |
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
model = model.to(DEVICE) | |
"""Генерация текста""" | |
# Commented out IPython magic to ensure Python compatibility. | |
# %%time | |
# | |
# # токенизация текста в индексы токенов и маски внимания | |
# text_promt = 'меня засосала опасная трясина ' | |
# inputs = tokenizer(text_promt, return_tensors="pt").to(DEVICE) | |
# | |
# # конфиг словарь для генерации текста | |
# gen_config_dict = dict( | |
# do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение) | |
# max_new_tokens=30, # сколько максимум новых токенов надо генерировать | |
# top_k=50, # семплировать только из top_k самых вероятных токенов | |
# top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p | |
# temperature=2.0, # температура для софтмакса | |
# num_beams=3, # параметр алгоритма Beam search | |
# repetition_penalty=2.0, # штраф за то что модель повторяется | |
# pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения | |
# ) | |
# # конфиг для генерации текста из словаря | |
# generation_config = GenerationConfig(**gen_config_dict) | |
# | |
# # генерация текста (индексы токенов) | |
# output = model.generate(**inputs, generation_config=generation_config) | |
# | |
# # сопостовление идексам токенов слов из словаря токенайзера | |
# generated_text = tokenizer.decode(output[0], skip_special_tokens=False) | |
# | |
# # удаление исходного промта из ответа потому что он тоже возвращается | |
# generated_text = generated_text[len(text_promt):] | |
# generated_text | |
"""Функция для генерации""" | |
# функция принимает текстовый запрос и словарь параметров генерации | |
def generate_text(text_promt: str, gen_config_dict: Dict[str, Union[float, int]]) -> str: | |
inputs = tokenizer(text_promt, return_tensors="pt").to(DEVICE) | |
generation_config = GenerationConfig(**gen_config_dict) | |
output = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, generation_config=generation_config) | |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
generated_text = generated_text[len(text_promt):] | |
return generated_text | |
# конфиг словарь для генерации текста | |
gen_config_dict = dict( | |
do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение) | |
max_new_tokens=20, # сколько максимум новых токенов надо генерировать | |
top_k=50, # семплировать только из top_k самых вероятных токенов | |
top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p | |
temperature=2.0, # температура для софтмакса | |
num_beams=3, # параметр алгоритма Beam search | |
repetition_penalty=2.0, # штраф за то что модель повторяется | |
pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения | |
) | |
text_promt = 'в небесной канцелярии выходной' | |
generated_text = generate_text(text_promt, gen_config_dict) | |
generated_text | |
"""# Gradio App | |
## Новый интерфейс Чат-бота | |
Вариант с системным промтом и разными входными аргументами и настройками | |
""" | |
import gradio as gr | |
# функция будет вызыватся при нажатии на Enter в окошке вовода текста | |
# кроме обычных аргументов - сообщения пользователя и истории - принимает еще параметры для конфига генерации | |
def generate(message, history, *components): | |
# print(system_promt) | |
# обновление словаря новыми агрументами и создание конфига генерации текста | |
gen_config.update(dict(zip(gen_config.keys(), components))) | |
gen_config['top_k'] = int(gen_config['top_k']) | |
gen_config['num_beams'] = int(gen_config['num_beams']) | |
generation_config = GenerationConfig(**gen_config) | |
# добавить системный промт в начало запроса и сгенерировать текст | |
promt = message | |
inputs = tokenizer(promt, return_tensors="pt").to(DEVICE) | |
output = model.generate(**inputs, generation_config=generation_config, pad_token_id=tokenizer.eos_token_id) | |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
generated_text = generated_text[len(promt):] | |
# имитация набора сообщения чат-ботом (посимвольня генерация через yield в цикле) | |
for i in range(len(generated_text)): | |
time.sleep(0.05) # задержка с которой бот вводит текст | |
yield generated_text[:i+1] | |
# словарь для конфига генерации текста | |
gen_config = dict( | |
do_sample=False, | |
max_length=60, | |
top_k=50, | |
top_p=0.9, | |
temperature=2.0, | |
num_beams=3, | |
repetition_penalty=2.0, | |
) | |
# компоненты настройки конфига генерации текста | |
components = [ | |
gr.Checkbox(label="do_sample", value=gen_config["do_sample"]), | |
gr.Slider(label="max_length", value=gen_config["max_length"], minimum=1, maximum=300, step=10), | |
gr.Number(label="top_k", value=gen_config["top_k"], minimum=0, maximum=50, step=10), | |
gr.Number(label="top_p", value=gen_config["top_p"], minimum=0, maximum=1, step=0.1), | |
gr.Number(label="temperature", value=gen_config["temperature"], minimum=0, maximum=10, step=0.1), | |
gr.Number(label="num_beams", value=gen_config["num_beams"], minimum=0, maximum=5, step=1), | |
gr.Number(label="repetition_penalty", value=gen_config["repetition_penalty"], minimum=0, maximum=5, step=0.1), | |
] | |
# при нажатии Enter в чате будет вызыватся функция generate | |
interface = gr.ChatInterface( | |
generate, | |
chatbot=gr.Chatbot(height=300), # вход для функции generate: message | |
textbox=gr.Textbox(placeholder="Задайте любой вопрос", container=False, scale=2), # выходной бокс для текста | |
# дополнительные входы для функции generate (*components) | |
additional_inputs=components, | |
# настройки оформления | |
title="Чат-бот T10", # название страницы | |
description="Окно переписки с ботом", # описание окошка переписки | |
theme="Glass", # темы: Glass, Monochrome, Soft | |
# examples=["Hello", "Am I cool?", "Are tomatoes vegetables?"], # примеры должны быть множественными если аргументов много | |
# cache_examples=True, # кешировать примеры | |
# дполнительные кнопки (если не нужна какая либо кнопка ставим None) | |
submit_btn='Отправить', | |
retry_btn='Повторить вопрос', | |
undo_btn="Удалить предыдущий вопрос", | |
clear_btn="Очистить историю", | |
) | |