import streamlit as st from transformers import GPT2LMHeadModel, GPT2Tokenizer import torch DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Загрузка обученной модели и токенизатора model_path = "finetuned" tokenizer = GPT2Tokenizer.from_pretrained(model_path) model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE) def generate_jokes(prompt, temperature, top_p, max_length, num_return_sequences): input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE) # Генерируем несколько шуток outputs = model.generate( input_ids=input_ids, do_sample=True, # num_beams=5, temperature=temperature, top_p=top_p, max_length=max_length, num_return_sequences=num_return_sequences ) # Обработка всех сгенерированных шуток jokes = [] for output in outputs: generated_text = tokenizer.decode(output, skip_special_tokens=True) # Обрезаем текст после первой точки if '…' in generated_text: generated_text = generated_text.split('…')[0] + '.' elif '.' in generated_text: generated_text = generated_text.split('.')[0] + '.' elif '!' in generated_text: generated_text = generated_text.split('!')[0] + '.' jokes.append(generated_text) return jokes # Создание интерфейса Streamlit st.title('GPT-2, как генератор сомнительных шуток') # Ввод промта prompt = st.text_input('Введите свой промт:', 'Народная мудрость гласит') # Регулировка параметров генерации max_length = st.slider('Максимальная длина последовательности:', min_value=10, max_value=100, value=35) num_return_sequences = st.slider('Число генераций текста:', min_value=1, max_value=5, value=3) temperature = st.slider('Температура (дисперсия):', min_value=0.1, max_value=2.0, value=1.0, step=0.1) top_p = st.slider('Top-p (ядро):', min_value=0.1, max_value=1.0, value=0.9, step=0.1) # Генерация текста if st.button('Сгенерировать'): with st.spinner('Генерация текста...'): generated_texts = generate_jokes(prompt,temperature, top_p, max_length, num_return_sequences) for i, text in enumerate(generated_texts): st.subheader(f'Генерация {i + 1}:') st.write(text)