SaviAnna commited on
Commit
3712cee
1 Parent(s): fa9f9fe

Update pages/📖History_Mystery.py

Browse files
Files changed (1) hide show
  1. pages/📖History_Mystery.py +19 -13
pages/📖History_Mystery.py CHANGED
@@ -9,33 +9,38 @@ st.title("""
9
  History Mystery
10
  """)
11
  # Добавление слайдера
12
- temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
13
- max_length = st.slider(" Длина сгенерированного отрывка", 40, 120, 2)
14
  # Загрузка модели и токенизатора
15
  # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
  # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
17
  # #Задаем класс модели (уже в streamlit/tg_bot)
18
- model_GPT = GPT2LMHeadModel.from_pretrained(
19
- 'sberbank-ai/rugpt3small_based_on_gpt2',
20
- output_attentions = False,
21
- output_hidden_states = False,
22
- )
23
- tokenizer_GPT = GPT2Tokenizer.from_pretrained(
24
  'sberbank-ai/rugpt3small_based_on_gpt2',
25
  output_attentions = False,
26
  output_hidden_states = False,
27
- )
 
 
 
 
 
 
 
28
 
29
  # # Вешаем сохраненные веса на нашу модель
30
- model_GPT.load_state_dict(torch.load('model_history.pt', map_location=torch.device('cpu')))
31
  # Функция для генерации текста
32
- def generate_text(prompt):
33
  # Преобразование входной строки в токены
34
  input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
35
 
36
  # Генерация текста
37
  output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
38
- temperature=5, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
39
  num_return_sequences=3)
40
 
41
  # Декодирование сгенерированного текста
@@ -45,6 +50,7 @@ def generate_text(prompt):
45
 
46
  # Streamlit приложение
47
  def main():
 
48
  st.write("""
49
  # GPT-3 генерация текста
50
  """)
@@ -59,7 +65,7 @@ def main():
59
  # Обработка события нажатия кнопки
60
  if generate_button:
61
  # Вывод сгенерированного текста
62
- generated_text = generate_text(prompt)
63
  st.subheader("Продолжение:")
64
  st.write(generated_text)
65
 
 
9
  History Mystery
10
  """)
11
  # Добавление слайдера
12
+ temperature = st.slider("Градус дичи", 1, 20, 1)
13
+ max_length = st.slider(" Длина сгенерированного отрывка", 60, 120, 2)
14
  # Загрузка модели и токенизатора
15
  # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
  # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
17
  # #Задаем класс модели (уже в streamlit/tg_bot)
18
+
19
+ @st.cache
20
+ def load_gpt():
21
+ model_GPT = GPT2LMHeadModel.from_pretrained(
 
 
22
  'sberbank-ai/rugpt3small_based_on_gpt2',
23
  output_attentions = False,
24
  output_hidden_states = False,
25
+ )
26
+ tokenizer_GPT = GPT2Tokenizer.from_pretrained(
27
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
28
+ output_attentions = False,
29
+ output_hidden_states = False,
30
+ )
31
+ model_GPT.load_state_dict(torch.load('model_history.pt', map_location=torch.device('cpu')))
32
+ return model_GPT, tokenizer_GPT
33
 
34
  # # Вешаем сохраненные веса на нашу модель
35
+
36
  # Функция для генерации текста
37
+ def generate_text(model_GPT, tokenizer_GPT, prompt):
38
  # Преобразование входной строки в токены
39
  input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
40
 
41
  # Генерация текста
42
  output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
43
+ temperature=1., top_k=50, top_p=0.6, no_repeat_ngram_size=3,
44
  num_return_sequences=3)
45
 
46
  # Декодирование сгенерированного текста
 
50
 
51
  # Streamlit приложение
52
  def main():
53
+ model_GPT, tokenizer_GPT = load_gpt()
54
  st.write("""
55
  # GPT-3 генерация текста
56
  """)
 
65
  # Обработка события нажатия кнопки
66
  if generate_button:
67
  # Вывод сгенерированного текста
68
+ generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
69
  st.subheader("Продолжение:")
70
  st.write(generated_text)
71