import gradio as gr from transformers import pipeline def generate_text( model_name, text, min_length, max_length, temperature, top_k, top_p ): models_map = { "Мои любимые юморески": "gpt2-vk-aneki", "бугро тред": "gpt2-vk-bugro", "Калик)": "gpt2-vk-kalik" } model = "MesonWarrior/" + models_map[model_name] pipe = pipeline( 'text-generation', model=model, tokenizer=model, min_length=min_length, max_length=max_length ) return pipe(text, temperature=temperature, top_k=top_k, top_p=top_p, do_sample=True)[0]['generated_text'] def interface(): with gr.Row(): with gr.Column(): with gr.Row(): model = gr.Dropdown( ["Мои любимые юморески", "бугро тред", "Калик)"], label="Модель (Текст какого паблика генерировать)", value="Мои любимые юморески", ) text = gr.Textbox(lines=7, label="Входной текст", placeholder="Введите текст который продолжит нейросеть...") output = gr.Textbox(lines=12, label="Выходной текст", placeholder="Здесь будет текст сгенерированный нейросетью...") with gr.Row(): with gr.Column(): min_length = gr.Slider( minimum=0, maximum=100, value=32, step=1, label="Минимальная длина", info="Минимальное количество символов в выходном тексте." ) max_length = gr.Slider( minimum=0, maximum=200, value=64, step=1, label="Максимальная длина", info="Максимальное количество символов в выходном тексте." ) temperature = gr.Slider( minimum=0.05, maximum=1.95, value=0.9, step=0.05, label="Температура", info="Чем выше тем рандомнее, чем ниже тем больше повторений." ) top_k = gr.Slider( minimum=0, maximum=100, value=50, step=0.05, label="Top K", ) top_p = gr.Slider( minimum=0, maximum=1, value=0.9, step=0.05, label="Top P", ) with gr.Column(): with gr.Row(): generate_btn = gr.Button( "Сгенерировать", variant="primary", label="Generate", ) generate_btn.click( fn=generate_text, inputs=[ model, text, min_length, max_length, temperature, top_k, top_p ], outputs=output, ) with gr.Blocks( title="GPT2 VK") as demo: gr.Markdown(""" # GPT2 VK Файнтюны [этой](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) модели по вашим любимым пабликам ВКонтакте. #### Паблики представленные в моделях: - [Мои любимые юморески 🎩](https://huggingface.co/MesonWarrior/gpt2-vk-aneki) - [бугро тред 💥](https://huggingface.co/MesonWarrior/gpt2-vk-bugro) - [Калик) 🍏🍎💨](https://huggingface.co/MesonWarrior/gpt2-vk-kalik) (Обучено на спорном датасете из постов и комментариев, надо бы переобучить на данных получше) """) interface() demo.queue().launch()