|
import os |
|
|
|
import gradio as gr |
|
from openai import OpenAI |
|
|
|
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им." |
|
BASE_URL = os.getenv("BASE_URL") |
|
API_KEY = os.getenv("API_KEY") |
|
MODEL_NAME = "IlyaGusev/saiga_nemo_12b_gptq_8bit" |
|
CLIENT = OpenAI(base_url=BASE_URL, api_key=API_KEY) |
|
|
|
|
|
def user(message, history): |
|
new_history = history + [[message, None]] |
|
return "", new_history |
|
|
|
|
|
def bot( |
|
history, |
|
system_prompt, |
|
top_p, |
|
temp |
|
): |
|
messages = [{"role": "system", "content": SYSTEM_PROMPT}] |
|
|
|
for user_message, bot_message in history[:-1]: |
|
messages.append({"role": "user", "content": user_message}) |
|
if bot_message: |
|
messages.append({"role": "assistant", "content": bot_message}) |
|
|
|
last_user_message = history[-1][0] |
|
messages.append({"role": "user", "content": last_user_message}) |
|
|
|
response = CLIENT.chat.completions.create( |
|
model=MODEL_NAME, |
|
messages=messages, |
|
temperature=temp, |
|
top_p=top_p, |
|
stream=True, |
|
) |
|
|
|
partial_text = "" |
|
for chunk in response: |
|
content = chunk.choices[0].delta.content |
|
partial_text += content |
|
history[-1][1] = partial_text |
|
yield history |
|
|
|
|
|
with gr.Blocks( |
|
theme=gr.themes.Soft() |
|
) as demo: |
|
favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">' |
|
gr.Markdown( |
|
f"""<h1><center>{favicon}Saiga Nemo 12B GPTQ 8 bit</center></h1> |
|
|
|
This is a demo of a **Russian**-speaking Mistral Nemo based model. |
|
|
|
Это демонстрационная версия [квантованной Сайги Немо с 12 миллиардами параметров](https://huggingface.co/IlyaGusev/saiga_nemo_12b). |
|
""" |
|
) |
|
with gr.Row(): |
|
with gr.Column(scale=5): |
|
system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False) |
|
chatbot = gr.Chatbot(label="Диалог") |
|
with gr.Column(min_width=80, scale=1): |
|
with gr.Tab(label="Параметры генерации"): |
|
top_p = gr.Slider( |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=0.9, |
|
step=0.05, |
|
interactive=True, |
|
label="Top-p", |
|
) |
|
temp = gr.Slider( |
|
minimum=0.0, |
|
maximum=2.0, |
|
value=0.01, |
|
step=0.01, |
|
interactive=True, |
|
label="Температура" |
|
) |
|
with gr.Row(): |
|
with gr.Column(): |
|
msg = gr.Textbox( |
|
label="Отправить сообщение", |
|
placeholder="Отправить сообщение", |
|
show_label=False, |
|
) |
|
with gr.Column(): |
|
with gr.Row(): |
|
submit = gr.Button("Отправить") |
|
stop = gr.Button("Остановить") |
|
clear = gr.Button("Очистить") |
|
with gr.Row(): |
|
gr.Markdown( |
|
"""ПРЕДУПРЕЖДЕНИЕ: Модель может генерировать фактически или этически некорректные тексты. Мы не несём за это ответственность.""" |
|
) |
|
|
|
|
|
submit_event = msg.submit( |
|
fn=user, |
|
inputs=[msg, chatbot], |
|
outputs=[msg, chatbot], |
|
queue=False, |
|
).success( |
|
fn=bot, |
|
inputs=[ |
|
chatbot, |
|
system_prompt, |
|
top_p, |
|
temp |
|
], |
|
outputs=chatbot, |
|
queue=True, |
|
) |
|
|
|
|
|
submit_click_event = submit.click( |
|
fn=user, |
|
inputs=[msg, chatbot], |
|
outputs=[msg, chatbot], |
|
queue=False, |
|
).success( |
|
fn=bot, |
|
inputs=[ |
|
chatbot, |
|
system_prompt, |
|
top_p, |
|
temp |
|
], |
|
outputs=chatbot, |
|
queue=True, |
|
) |
|
|
|
|
|
stop.click( |
|
fn=None, |
|
inputs=None, |
|
outputs=None, |
|
cancels=[submit_event, submit_click_event], |
|
queue=False, |
|
) |
|
|
|
|
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
demo.queue(max_size=128) |
|
demo.launch(show_error=True) |
|
|