SaviAnna commited on
Commit
fa9f9fe
1 Parent(s): 3c6ea5b

Create History.py

Browse files
Files changed (1) hide show
  1. pages/History.py +72 -0
pages/History.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import streamlit as st
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+
8
+ st.title("""
9
+ History Mistery
10
+ """)
11
+ # image = Image.open('data-scins.jpeg')
12
+
13
+ # st.image(image, caption='Current mood')
14
+ # Добавление слайдера
15
+ temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
16
+ max_length = st.slider("Длина сгенерированного отрывка",40, 120, 40)
17
+ # Загрузка модели и токенизатора
18
+ # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
19
+ # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
20
+ # #Задаем класс модели (уже в streamlit/tg_bot)
21
+ model = GPT2LMHeadModel.from_pretrained(
22
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
23
+ output_attentions = False,
24
+ output_hidden_states = False,
25
+ )
26
+ tokenizer = GPT2Tokenizer.from_pretrained(
27
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
28
+ output_attentions = False,
29
+ output_hidden_states = False,
30
+ )
31
+
32
+ # # Вешаем сохраненные веса на нашу модель
33
+ model.load_state_dict(torch.load('model_history.pt',map_location=torch.device('cpu')))
34
+ # Функция для генерации текста
35
+ def generate_text(prompt):
36
+ # Преобразование входной строки в токены
37
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
38
+
39
+ # Генерация текста
40
+ output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
41
+ temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
42
+ num_return_sequences=3)
43
+
44
+ # Декодирование сгенерированного текста
45
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
46
+
47
+ return generated_text
48
+
49
+ # Streamlit приложение
50
+ def main():
51
+ st.write("""
52
+ # GPT-3 генерация текста
53
+ """)
54
+
55
+ # Ввод строки пользователем
56
+ prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")
57
+
58
+ # # Генерация текста по введенной строке
59
+ # generated_text = generate_text(prompt)
60
+ # Создание кнопки "Сгенерировать"
61
+ generate_button = st.button("За работу!")
62
+ # Обработка события нажатия кнопки
63
+ if generate_button:
64
+ # Вывод сгенерированного текста
65
+ generated_text = generate_text(prompt)
66
+ st.subheader("Продолжение:")
67
+ st.write(generated_text)
68
+
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()