vvv-knyazeva commited on
Commit
1752023
1 Parent(s): 95a8da8

Update pages/gpt_v1.py

Browse files
Files changed (1) hide show
  1. pages/gpt_v1.py +45 -35
pages/gpt_v1.py CHANGED
@@ -1,47 +1,57 @@
1
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
- import torch
3
  import streamlit as st
4
- import base64
5
- import plotly
6
- import plotly.express as px
7
 
8
- st.markdown('## Генерация текста GPT-моделью по пользовательскому prompt')
9
 
 
10
  model = GPT2LMHeadModel.from_pretrained(
11
  'sberbank-ai/rugpt3small_based_on_gpt2',
12
  output_attentions = False,
13
  output_hidden_states = False,
14
  )
15
- tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
- model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
17
-
18
-
19
- prompt = st.text_input('Введите текст prompt:')
20
- length = st.slider('Длина генерируемой последовательности:', 8, 256, 15)
21
- num_samples = st.slider('Число генераций:', 1, 10, 1)
22
- temperature = st.slider('Температура:', 1.0, 10.0, 2.0)
23
- top_k = st.slider('Количество наиболее вероятных слов генерации:', 10, 200, 50)
24
- top_p = st.slider('Минимальная суммарная вероятность топовых слов:', 0.4, 1.0, 0.9)
25
-
26
-
27
- if st.button('Сгенерировать текст'):
28
-
29
- with torch.inference_mode():
30
- prompt = tokenizer.encode(prompt, return_tensors='pt')
31
- out = model.generate(
32
- input_ids=prompt,
33
- max_length=length,
34
- num_beams=5,
35
- do_sample=True,
36
- temperature=temperature,
37
- top_k=top_k,
38
- top_p=top_p,
39
- no_repeat_ngram_size=3,
40
- num_return_sequences=num_samples,
41
- ).cpu().numpy()
42
- for i, out_ in enumerate(out):
43
- st.write(f'Текст {i+1}:')
44
- st.write(textwrap.fill(tokenizer.decode(out_), 100))
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
 
 
1
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
2
  import streamlit as st
3
+ import torch
4
+ import textwrap
5
+
6
 
7
+ st.markdown('## Генерация текста GPT-моделью')
8
 
9
+ tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
10
  model = GPT2LMHeadModel.from_pretrained(
11
  'sberbank-ai/rugpt3small_based_on_gpt2',
12
  output_attentions = False,
13
  output_hidden_states = False,
14
  )
15
+ # Вешаем сохраненные веса на нашу модель
16
+ model.load_state_dict(torch.load('modelgpt.pt', map_location=torch.device('cpu')))
17
+
18
+
19
+ col1, col2, col3 = st.columns([4, 3, 10])
20
+
21
+ with col1:
22
+
23
+ length = st.slider('Длина генерируемой последовательности:', 8, 256, 15)
24
+ num_samples = st.slider('Число генераций:', 1, 10, 1)
25
+ temperature = st.slider('Температура:', 1.0, 10.0, 2.0)
26
+ top_k = st.slider('Количество наиболее вероятных слов генерации:', 10, 200, 50)
27
+ top_p = st.slider('Минимальная суммарная вероятность топовых слов:', 0.4, 1.0, 0.9)
28
+
29
+ with col2:
30
+ pass
31
+
32
+ with col3:
33
+
34
+ prompt = st.text_input('Введите текст:')
35
+
36
+ if st.button('Сгенерировать текст'):
37
+
38
+ with torch.inference_mode():
39
+ prompt = tokenizer.encode(prompt, return_tensors='pt')
40
+ out = model.generate(
41
+ input_ids=prompt,
42
+ max_length=length,
43
+ num_beams=5,
44
+ do_sample=True,
45
+ temperature=temperature,
46
+ top_k=top_k,
47
+ top_p=top_p,
48
+ no_repeat_ngram_size=3,
49
+ num_return_sequences=num_samples,
50
+ ).cpu().numpy()
51
+ for i, out_ in enumerate(out):
52
+ st.write(f'Текст {i+1}:')
53
+ st.write(textwrap.fill(tokenizer.decode(out_), 100))
54
+
55
 
56
 
57