Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
import torch | |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# Загрузка обученной модели и токенизатора | |
model_path = "finetuned" | |
tokenizer = GPT2Tokenizer.from_pretrained(model_path) | |
model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE) | |
def generate_jokes(prompt, temperature, top_p, max_length, num_return_sequences): | |
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE) | |
# Генерируем несколько шуток | |
outputs = model.generate( | |
input_ids=input_ids, | |
do_sample=True, | |
# num_beams=5, | |
temperature=temperature, | |
top_p=top_p, | |
max_length=max_length, | |
num_return_sequences=num_return_sequences | |
) | |
# Обработка всех сгенерированных шуток | |
jokes = [] | |
for output in outputs: | |
generated_text = tokenizer.decode(output, skip_special_tokens=True) | |
# Обрезаем текст после первой точки | |
if '…' in generated_text: | |
generated_text = generated_text.split('…')[0] + '.' | |
elif '.' in generated_text: | |
generated_text = generated_text.split('.')[0] + '.' | |
elif '!' in generated_text: | |
generated_text = generated_text.split('!')[0] + '.' | |
jokes.append(generated_text) | |
return jokes | |
# Создание интерфейса Streamlit | |
st.title('GPT-2, как генератор сомнительных шуток') | |
# Ввод промта | |
prompt = st.text_input('Введите свой промт:', 'Народная мудрость гласит') | |
# Регулировка параметров генерации | |
max_length = st.slider('Максимальная длина последовательности:', min_value=10, max_value=100, value=35) | |
num_return_sequences = st.slider('Число генераций текста:', min_value=1, max_value=5, value=3) | |
temperature = st.slider('Температура (дисперсия):', min_value=0.1, max_value=2.0, value=1.0, step=0.1) | |
top_p = st.slider('Top-p (ядро):', min_value=0.1, max_value=1.0, value=0.9, step=0.1) | |
# Генерация текста | |
if st.button('Сгенерировать'): | |
with st.spinner('Генерация текста...'): | |
generated_texts = generate_jokes(prompt,temperature, top_p, max_length, num_return_sequences) | |
for i, text in enumerate(generated_texts): | |
st.subheader(f'Генерация {i + 1}:') | |
st.write(text) |