vk / app.py
MesonWarrior's picture
Update app.py
7dea28d
raw
history blame
4.02 kB
import gradio as gr
from transformers import pipeline
def generate_text(
model_name,
text,
min_length,
max_length,
temperature,
top_k,
top_p
):
models_map = {
"Мои любимые юморески": "gpt2-vk-aneki",
"бугро тред": "gpt2-vk-bugro",
"Калик)": "gpt2-vk-kalik"
}
model = "MesonWarrior/" + models_map[model_name]
pipe = pipeline(
'text-generation',
model=model,
tokenizer=model,
min_length=min_length,
max_length=max_length
)
return pipe(text, temperature=temperature, top_k=top_k, top_p=top_p, do_sample=True)[0]['generated_text']
def interface():
with gr.Row():
with gr.Column():
with gr.Row():
model = gr.Dropdown(
["Мои любимые юморески", "бугро тред", "Калик)"],
label="Модель (Текст какого паблика генерировать)",
value="Мои любимые юморески",
)
text = gr.Textbox(lines=7, label="Входной текст", placeholder="Введите текст который продолжит нейросеть...")
output = gr.Textbox(lines=12, label="Выходной текст", placeholder="Здесь будет текст сгенерированный нейросетью...")
with gr.Row():
with gr.Column():
min_length = gr.Slider(
minimum=0, maximum=100, value=32, step=1,
label="Минимальная длина",
info="Минимальное количество символов в выходном тексте."
)
max_length = gr.Slider(
minimum=0, maximum=200, value=64, step=1,
label="Максимальная длина",
info="Максимальное количество символов в выходном тексте."
)
temperature = gr.Slider(
minimum=0.05, maximum=1.95, value=0.9, step=0.05,
label="Температура",
info="Чем выше тем рандомнее, чем ниже тем больше повторений."
)
top_k = gr.Slider(
minimum=0, maximum=100, value=50, step=0.05,
label="Top K",
)
top_p = gr.Slider(
minimum=0, maximum=1, value=0.9, step=0.05,
label="Top P",
)
with gr.Column():
with gr.Row():
generate_btn = gr.Button(
"Сгенерировать", variant="primary", label="Generate",
)
generate_btn.click(
fn=generate_text,
inputs=[
model,
text,
min_length,
max_length,
temperature,
top_k,
top_p
],
outputs=output,
)
with gr.Blocks(
title="GPT2 VK") as demo:
gr.Markdown("""
# GPT2 VK
Файнтюны [этой](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) модели по вашим любимым пабликам ВКонтакте.
#### Паблики представленные в моделях:
- [Мои любимые юморески 🎩](https://huggingface.co/MesonWarrior/gpt2-vk-aneki)
- [бугро тред 💥](https://huggingface.co/MesonWarrior/gpt2-vk-bugro)
- [Калик) 🍏🍎💨](https://huggingface.co/MesonWarrior/gpt2-vk-kalik) <sub><sup>(Обучено на спорном датасете из постов и комментариев, надо бы переобучить на данных получше)</sup></sub>
""")
interface()
demo.queue().launch()