|
|
|
import html
import random

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util_funcs import getLengthParam, calcAnswerLengthByProbability, cropContext
|
|
|
def _build_context(user_message, history):
    """Assemble the model input for the next bot turn.

    Args:
        user_message: Text the user just typed.
        history: Previous turns as ``(user_msg, answer, token_ids)`` tuples,
            where ``token_ids`` is the nested list saved by ``chat_function``.

    Returns:
        A tensor of token ids: the saved conversation, the user turn prefixed
        with ``|0|<length_id>|`` and terminated by the EOS token, then the bot
        turn prefix ``|1|<next_len>|``, all passed through ``cropContext``.
        Before cropping the tensor is ``(1, seq_len)``; presumably cropping
        keeps that shape (TODO confirm in util_funcs).
    """
    if history:
        # Token ids (context + reply) saved together with the previous turn.
        context_ids = torch.tensor(history[-1][2], dtype=torch.long)
    else:
        # Empty start; torch.long matches tokenizer.encode(..., return_tensors="pt")
        # so the torch.cat calls below do not mix int32 and int64.
        context_ids = torch.zeros((1, 0), dtype=torch.long)

    # NOTE(review): getLengthParam presumably buckets the message length.
    length_id = getLengthParam(user_message, tokenizer)
    user_ids = tokenizer.encode(
        f"|0|{length_id}|" + user_message + tokenizer.eos_token,
        return_tensors="pt",
    )
    context_ids = torch.cat([context_ids, user_ids], dim=-1)

    # Open the bot turn; the requested reply length is presumably sampled by
    # calcAnswerLengthByProbability.
    next_len = calcAnswerLengthByProbability(length_id)
    bot_prefix_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
    context_ids = torch.cat([context_ids, bot_prefix_ids], dim=-1)

    # NOTE(review): 10 is presumably the number of turns kept; see util_funcs.
    return cropContext(context_ids, 10)


def _generate_answer(context_ids, temperature=0.6, min_temperature=0.1,
                     temperature_step=0.1, max_attempts=10):
    """Sample a bot reply for ``context_ids``.

    A reply that is empty or ends with ``,`` or ``:`` (presumably an
    unfinished sentence) is rejected and sampled again from the same context.
    Each retry lowers the temperature by ``temperature_step``, but never below
    ``min_temperature``. After ``max_attempts`` tries the last reply is
    returned as-is, so a context that can no longer produce an acceptable
    reply cannot hang the request. This may happen, for instance, when the
    context already approaches ``max_length=512``.

    Returns:
        ``(answer, output_ids)``: the decoded reply text and the ids returned
        by ``model.generate`` (context followed by the generated reply).
    """
    input_len = context_ids.shape[-1]
    answer, output_ids = "", context_ids
    for _ in range(max_attempts):
        output_ids = model.generate(
            context_ids,
            num_return_sequences=1,
            min_length=2,
            max_length=512,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=temperature,
            mask_token_id=tokenizer.mask_token_id,
            eos_token_id=tokenizer.eos_token_id,
            unk_token_id=tokenizer.unk_token_id,
            pad_token_id=tokenizer.pad_token_id,
            device='cpu'
        )
        # Decode only the newly generated tokens, not the prompt.
        answer = tokenizer.decode(
            output_ids[:, input_len:][0], skip_special_tokens=True
        )
        if answer and answer[-1] not in (",", ":"):
            break
        # Cool down for the next try; clamp so temperature stays positive.
        temperature = max(temperature - temperature_step, min_temperature)
    return answer, output_ids


def _render_history(history):
    """Render every turn as HTML, skipping texts that are exactly ``'-'``.

    Message text is HTML-escaped so user input (and model output) is shown
    literally instead of being interpreted as markup.
    """
    # Container class must match the `.chatbox` rule in the interface CSS,
    # which makes it a flex column so `align-self` on messages takes effect.
    parts = ["<div class='chatbox'>"]
    for user_msg, resp_msg, _ in history:
        if user_msg != '-':
            parts.append(f"<div class='user_msg'>{html.escape(user_msg)}</div>")
        if resp_msg != '-':
            parts.append(f"<div class='resp_msg'>{html.escape(resp_msg)}</div>")
    parts.append("</div>")
    return "".join(parts)


def chat_function(Message, History):
    """Gradio callback: answer one user message and re-render the chat.

    Args:
        Message: Text from the input textbox.
        History: Gradio session state, i.e. a list of
            ``(user_msg, answer, token_ids)`` tuples from earlier turns, or a
            falsy value (e.g. ``None``) on the first call.

    Returns:
        ``(html_markup, history)``: HTML for the whole conversation and the
        history list with this turn appended.
    """
    input_user = Message
    history = History or []

    context_ids = _build_context(input_user, history)
    # Debug aid: log the full prompt on the server console.
    print(tokenizer.decode(context_ids[-1]))

    answer, chat_history_ids = _generate_answer(context_ids)

    # Store the full token ids so the next turn continues this conversation.
    history.append((input_user, answer, chat_history_ids.tolist()))
    return _render_history(history), history
|
|
|
|
|
|
|
# Fine-tuned ruDialoGpt3-medium checkpoint, fetched via transformers.
checkpoint = "avorozhko/ruDialoGpt3-medium-finetuned-context"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# eval() switches the module out of training mode (returns the same model),
# since the app only runs inference.
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()
|
|
|
|
|
# User-facing texts shown around the Gradio interface (in Russian).
title = "Чат-бот для поднятия настроения"

# Same text as a triple-quoted block with a leading and trailing newline.
description = "\n".join([
    "",
    "Данный бот постарается поднять вам настроение, так как он знает 26700 анекдотов.",
    "Но чувство юмора у него весьма специфичное.",
    "Бот не знает матерных слов и откровенных пошлостей, но кто такой Вовочка и Поручик Ржевский знает )",
    "",
])

# Footer link to the model card of the checkpoint used by the bot.
article = (
    "<p style='text-align: center'>"
    "<a href='https://huggingface.co/avorozhko/ruDialoGpt3-medium-finetuned-context'>"
    "Бот на основе дообученной GPT-3</a>"
    "</p>"
)
|
|
|
# Stylesheet for the transcript markup produced by chat_function plus
# tweaks to the Gradio panel layout (input/output stacked vertically).
_CHAT_CSS = """
.chatbox {display:flex;flex-direction:column}
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.user_msg {background-color:#1e4282;color:white;align-self:start}
.resp_msg {background-color:#552a2a;align-self:self-end}
.panels.unaligned {flex-direction: column !important;align-items: initial!important;}
.panels.unaligned :last-child {order: -1 !important;}
"""

# Multi-line box for the user's message.
_message_input = gr.inputs.Textbox(lines=3, placeholder="Что вы хотите сказать боту...")

# NOTE(review): gr.inputs, allow_screenshot and the 'dark-grass' theme look like
# the legacy Gradio 2.x API; confirm the pinned gradio version before upgrading.
iface = gr.Interface(
    fn=chat_function,
    # "state" carries the per-session history between calls.
    inputs=[_message_input, "state"],
    outputs=["html", "state"],
    title=title,
    description=description,
    article=article,
    theme='dark-grass',
    css=_CHAT_CSS,
    allow_screenshot=False,
    allow_flagging='never',
)
|
|
|
if __name__ == "__main__":
    # Serve locally only (no public share link), with debug output enabled.
    iface.launch(share=False, debug=True)
|
|