import transformers import streamlit as st from transformers import GPT2LMHeadModel, GPT2Tokenizer import numpy as np from PIL import Image import torch st.title(""" History Mistery """) # image = Image.open('data-scins.jpeg') # st.image(image, caption='Current mood') # Добавление слайдера temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0) # Загрузка модели и токенизатора # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') # #Задаем класс модели (уже в streamlit/tg_bot) model = GPT2LMHeadModel.from_pretrained( 'sberbank-ai/rugpt3small_based_on_gpt2', output_attentions = False, output_hidden_states = False, ) tokenizer = GPT2Tokenizer.from_pretrained( 'sberbank-ai/rugpt3small_based_on_gpt2', output_attentions = False, output_hidden_states = False, ) # # Вешаем сохраненные веса на нашу модель model.load_state_dict(torch.load('model_history.pt',map_location=torch.device('cpu'))) # Функция для генерации текста def generate_text(prompt): # Преобразование входной строки в токены input_ids = tokenizer.encode(prompt, return_tensors='pt') # Генерация текста output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True, temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3, num_return_sequences=3) # Декодирование сгенерированного текста generated_text = tokenizer.decode(output[0], skip_special_tokens=True) return generated_text # Streamlit приложение def main(): st.write(""" # GPT-3 генерация текста """) # Ввод строки пользователем prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси") # # Генерация текста по введенной строке # generated_text = generate_text(prompt) # Создание кнопки "Сгенерировать" generate_button = st.button("За работу!") # Обработка события нажатия кнопки if generate_button: # Вывод сгенерированного текста generated_text = generate_text(prompt) st.subheader("Продолжение:") st.write(generated_text) if __name__ == "__main__": main()