vvv-knyazeva commited on
Commit
13aa874
1 Parent(s): ddeff1f

Update pages/gpt_v1.py

Browse files
Files changed (1) hide show
  1. pages/gpt_v1.py +29 -88
pages/gpt_v1.py CHANGED
@@ -4,56 +4,6 @@ import torch
4
  import textwrap
5
  import plotly.express as px
6
 
7
- df = px.data.iris()
8
-
9
- @st.cache_data
10
- def get_img_as_base64(file):
11
- with open(file, "rb") as f:
12
- data = f.read()
13
- return base64.b64encode(data).decode()
14
-
15
-
16
- #img = get_img_as_base64("https://catherineasquithgallery.com/uploads/posts/2021-02/1612739741_65-p-goluboi-fon-tsifri-110.jpg")
17
-
18
- page_bg_img = f"""
19
- <style>
20
- [data-testid="stAppViewContainer"] > .main {{
21
- background-image: url("https://i.pinimg.com/originals/9f/57/bd/9f57bd45d33eb906fdb3d7ffe22e2058.png");
22
- background-size: 70%;
23
- background-position: top left;
24
- background-repeat: no-repeat;
25
- background-attachment: local;
26
- }}
27
-
28
- # [data-testid="stSidebar"] > div:first-child {{
29
- # background-image: url("https://catherineasquithgallery.com/uploads/posts/2021-02/1614542041_37-p-fon-belii-tekstura-43.jpg");
30
- # background-size: 100%;
31
- # background-position: center;
32
- # background-repeat: no-repeat;
33
- # background-attachment: fixed;
34
- # }}
35
-
36
- [data-testid="stHeader"] {{
37
- background: rgba(0,0,0,0);
38
- }}
39
-
40
- [data-testid="stToolbar"] {{
41
- right: 2rem;
42
- }}
43
-
44
- div.css-1n76uvr.esravye0 {{
45
- background-color: rgba(238, 238, 238, 0.5);
46
- border: 10px solid #EEEEEE;
47
- padding: 5% 5% 5% 10%;
48
- border-radius: 5px;
49
- }}
50
-
51
- </style>
52
- """
53
-
54
- st.markdown(page_bg_img, unsafe_allow_html=True)
55
-
56
-
57
  st.markdown('## Генерация текста GPT-моделью')
58
 
59
  tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
@@ -63,44 +13,35 @@ model = GPT2LMHeadModel.from_pretrained(
63
  output_hidden_states = False,
64
  )
65
  # Вешаем сохраненные веса на нашу модель
66
- model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
67
-
68
-
69
- col1, col2, col3 = st.columns([5, 2, 12])
70
-
71
- with col1:
72
-
73
- length = st.slider('Длина генерируемой последовательности:', 8, 256, 16)
74
- num_samples = st.slider('Число генераций:', 1, 10, 1)
75
- temperature = st.slider('Температура:', 1.0, 10.0, 2.0)
76
- top_k = st.slider('Количество наиболее вероятных слов генерации:', 10, 200, 50)
77
- top_p = st.slider('Минимальная суммарная вероятность топовых слов:', 0.4, 1.0, 0.9)
78
-
79
- with col2:
80
- pass
81
-
82
- with col3:
83
-
84
- prompt = st.text_input('Введите текст:')
85
-
86
- if st.button('Сгенерировать текст'):
87
-
88
- with torch.inference_mode():
89
- prompt = tokenizer.encode(prompt, return_tensors='pt')
90
- out = model.generate(
91
- input_ids=prompt,
92
- max_length=length,
93
- num_beams=5,
94
- do_sample=True,
95
- temperature=temperature,
96
- top_k=top_k,
97
- top_p=top_p,
98
- no_repeat_ngram_size=3,
99
- num_return_sequences=num_samples,
100
- ).cpu().numpy()
101
- for i, out_ in enumerate(out):
102
- st.write(f'Текст {i+1}:')
103
- st.write(textwrap.fill(tokenizer.decode(out_), 100))
104
 
105
 
106
 
 
4
  import textwrap
5
  import plotly.express as px
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  st.markdown('## Генерация текста GPT-моделью')
8
 
9
  tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
 
13
  output_hidden_states = False,
14
  )
15
  # Вешаем сохраненные веса на нашу модель
16
+ model.load_state_dict(torch.load('models/modelgpt.pt', map_location=torch.device('cpu')))
17
+
18
+
19
+ length = st.sidebar.slider('Длина генерируемой последовательности:', 8, 256, 15)
20
+ num_samples = st.sidebar.slider('Число генераций:', 1, 10, 1)
21
+ temperature = st.sidebar.slider('Температура:', 1.0, 10.0, 2.0)
22
+ top_k = st.sidebar.slider('Количество наиболее вероятных слов генерации:', 10, 200, 50)
23
+ top_p = st.sidebar.slider('Минимальная суммарная вероятность топовых слов:', 0.4, 1.0, 0.9)
24
+
25
+
26
+ prompt = st.text_input('Введите текст 👇:')
27
+ if st.button('Сгенерировать текст'):
28
+
29
+ with torch.inference_mode():
30
+ prompt = tokenizer.encode(prompt, return_tensors='pt')
31
+ out = model.generate(
32
+ input_ids=prompt,
33
+ max_length=length,
34
+ num_beams=5,
35
+ do_sample=True,
36
+ temperature=temperature,
37
+ top_k=top_k,
38
+ top_p=top_p,
39
+ no_repeat_ngram_size=3,
40
+ num_return_sequences=num_samples,
41
+ ).cpu().numpy()
42
+ for i, out_ in enumerate(out):
43
+ st.write(f'Текст {i+1}:')
44
+ st.write(textwrap.fill(tokenizer.decode(out_), 100))
 
 
 
 
 
 
 
 
 
45
 
46
 
47