|
import streamlit as st |
|
import torch |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
|
@st.cache_resource |
|
def load_model(): |
|
pushkin = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2', output_attentions = False, output_hidden_states = False) |
|
tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2") |
|
return pushkin, tokenizer |
|
|
|
pushkin, tokenizer = load_model() |
|
pushkin.load_state_dict(torch.load('pushkin_weights.pt', map_location=torch.device('cpu'))) |
|
pushkin.to('cpu') |
|
pushkin.eval() |
|
|
|
def generate_response(text, temperature, length, top_p): |
|
input_ids = tokenizer.encode(text, return_tensors="pt") |
|
with torch.no_grad(): |
|
out = pushkin.generate(input_ids, do_sample=True, num_beams=2, temperature=float(temperature), top_p=float(top_p), max_length=length) |
|
generated_text = list(map(tokenizer.decode, out))[0] |
|
last_full_stop_index = generated_text.rfind('.') |
|
st.write(generated_text[:last_full_stop_index + 1]) |
|
|
|
st.title('Александр Сергеевич Пушкин') |
|
st.image('pushkin.jpg', use_column_width=True) |
|
st.write('Напишите подсказку на русском языке, и модель на основе GPT отобразит текст Пушкина.') |
|
|
|
|
|
|
|
with st.expander("Описание"): |
|
st.write("""sberbank-ai/rugpt3small_based_on_gpt2 - это нейронная сеть, специально обученный на большом количестве текстов на русском языке. |
|
Модель может использоваться для создания автоматических ответов, разговорных систем и даже создания |
|
субтитров для видео.""") |
|
st.write("""Мой Dataset состоял из 103_000 слов и обучался 5 эпох (1 час)""") |
|
st.write("""Интересные факты:""") |
|
st.write("""* Модель содержит около 124 миллионов параметров""") |
|
st.write("""* Отличительной особенностью этой модели является ее способность генерировать тексты на различные темы и стили""") |
|
st.write("""* Модель показала высокую точность и удовлетворенность при оценке на разных задачах, таких как вопросы-ответы и перевод текста""") |
|
|
|
|
|
|
|
st.write('Определяем параметры генерации:') |
|
with st.expander("Параметры генерации"): |
|
temperature = st.slider('Температура (Более высокая может способствовать генерации более разнообразных, но менее четких и согласованных фраз)', value=1.5, min_value=1.0, max_value=5.0, step=0.1) |
|
length = st.slider('Длина (определяет ожидаемую длину генерируемого текста)', value=50, min_value=20, max_value=150, step=1) |
|
top_p = st.slider('Значение top-p (более высокое значение top-p, мы получаем более консервативную генерацию, в то время как более низкое значение top-p даёт более разнообразный текст)', value=0.9, min_value=0.5, max_value=1.0, step=0.05) |
|
|
|
|
|
user_input = st.text_area("Введите текст:") |
|
if st.button("Отправить"): |
|
if user_input: |
|
generate_response(user_input, temperature, length, top_p) |
|
else: |
|
st.warning("Пожалуйста, введите текст.") |
|
|