SaviAnna commited on
Commit
8e698f5
1 Parent(s): 9a91c85

Update pages/✨third.py

Browse files
Files changed (1) hide show
  1. pages/✨third.py +69 -111
pages/✨third.py CHANGED
@@ -1,114 +1,72 @@
 
1
  import streamlit as st
2
- import base64
3
- import streamlit as st
4
- import plotly.express as px
5
-
6
- df = px.data.iris()
7
-
8
- @st.cache_data
9
- def get_img_as_base64(file):
10
- with open(file, "rb") as f:
11
- data = f.read()
12
- return base64.b64encode(data).decode()
13
-
14
-
15
- #img = get_img_as_base64("https://catherineasquithgallery.com/uploads/posts/2021-02/1612739741_65-p-goluboi-fon-tsifri-110.jpg")
16
-
17
- page_bg_img = f"""
18
- <style>
19
- [data-testid="stAppViewContainer"] > .main {{
20
- background-image: url("https://wallpapercave.com/wp/wp11966930.jpg");
21
- background-size: 115%;
22
- background-position: top left;
23
- background-repeat: no-repeat;
24
- background-attachment: local;
25
- }}
26
-
27
- [data-testid="stSidebar"] > div:first-child {{
28
- background-image: url("https://ibb.co/ZBkdJRg");
29
- background-size: 115%;
30
- background-position: center;
31
- background-repeat: no-repeat;
32
- background-attachment: fixed;
33
- }}
34
-
35
- [data-testid="stHeader"] {{
36
- background: rgba(0,0,0,0);
37
- }}
38
-
39
- [data-testid="stToolbar"] {{
40
- right: 2rem;
41
- }}
42
-
43
- div.css-1n76uvr.e1tzin5v0 {{
44
- background-color: rgba(238, 238, 238, 0.5);
45
- border: 10px solid #EEEEEE;
46
- padding: 5% 5% 5% 10%;
47
- border-radius: 5px;
48
- }}
49
-
50
- </style>
51
- """
52
- st.markdown(page_bg_img, unsafe_allow_html=True)
53
-
54
- import tensorflow as tf
55
- from tensorflow import keras
56
  import numpy as np
57
- import matplotlib.pyplot as plt
58
-
59
- ################################################################################################
60
- #Тут нужно будет добаить модель. Ниже пример:
61
-
62
- # # Загрузка модели
63
- # model = keras.models.load_model('cgan_model.h5')
64
-
65
- # # Задание размерностей входных данных модели
66
- # latent_dim = 128
67
- # num_classes = 10
68
-
69
- # # Функция для генерации изображения
70
- # def generate_image(number):
71
- # random_latent_vector = tf.random.normal(shape=(1, latent_dim))
72
- # one_hot_label = tf.one_hot([number], num_classes)
73
- # input_data = tf.concat([random_latent_vector, one_hot_label], axis=1)
74
-
75
- # generated_image = model.predict(input_data)
76
- # generated_image = generated_image.reshape(28, 28)
77
- # generated_image = tf.image.resize(generated_image[None, ...], (28, 28))[0] # Добавлено [None, ...] для добавления измерения
78
- # return generated_image
79
-
80
- ################################################################################################
81
-
82
- #Оформление
83
-
84
- col1, col2, col3 = st.columns([1,5,1])
85
- with col2:
86
-
87
- st.title('Название модели')
88
-
89
- col1, col2, col3 = st.columns([2,5,2])
90
- with col2:
91
-
92
- number = st.slider('Выберите число:', 0, 9, step=1)
93
-
94
- ################################################################################################
95
- # Часть, отображаемая на странице
96
-
97
- # number = st.slider('Выберите число:', 0, 9, step=1)
98
-
99
-
100
- # #col1.subheader("Гистограмма total_bill:")
101
-
102
- # # Генерация и отображение изображения
103
- # generated_image = generate_image(number)
104
- # generated_image_np = generated_image.numpy() # Преобразование в массив NumPy
105
- # fig, ax = plt.subplots()
106
- # ax.scatter([1, 2], [1, 2], color='black')
107
- # plt.imshow(generated_image_np, cmap='gray')
108
- # plt.axis('off')
109
- # fig.set_size_inches(3, 3)
110
- # st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- ################################################################################################
113
- #st.markdown("<div style='text-align: center; font-size: 25px;'> ", unsafe_allow_html=True)
114
- #st.markdown("<div style='text-align: center; font-size: 25px;'> ", unsafe_allow_html=True)
 
1
+ import transformers
2
  import streamlit as st
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+
8
+ st.title("""
9
+ History Mistery
10
+ """)
11
+ # image = Image.open('data-scins.jpeg')
12
+
13
+ # st.image(image, caption='Current mood')
14
+ # Добавление слайдера
15
+ temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
16
+ # Загрузка модели и токенизатора
17
+ # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
18
+ # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
19
+ # #Задаем класс модели (уже в streamlit/tg_bot)
20
+ model = GPT2LMHeadModel.from_pretrained(
21
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
22
+ output_attentions = False,
23
+ output_hidden_states = False,
24
+ )
25
+ tokenizer = GPT2Tokenizer.from_pretrained(
26
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
27
+ output_attentions = False,
28
+ output_hidden_states = False,
29
+ )
30
+
31
+ # # Вешаем сохраненные веса на нашу модель
32
+ model.load_state_dict(torch.load('model_history.pt',map_location=torch.device('cpu')))
33
+ # Функция для генерации текста
34
+ def generate_text(prompt):
35
+ # Преобразование входной строки в токены
36
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
37
+
38
+ # Генерация текста
39
+ output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
40
+ temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
41
+ num_return_sequences=3)
42
+
43
+ # Декодирование сгенерированного текста
44
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
45
+
46
+ return generated_text
47
+
48
+ # Streamlit приложение
49
+ def main():
50
+ st.write("""
51
+ # GPT-3 генерация текста
52
+ """)
53
+
54
+ # Ввод строки пользователем
55
+ prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")
56
+
57
+ # # Генерация текста по введенной строке
58
+ # generated_text = generate_text(prompt)
59
+ # Создание кнопки "Сгенерировать"
60
+ generate_button = st.button("За работу!")
61
+ # Обработка события нажатия кнопки
62
+ if generate_button:
63
+ # Вывод сгенерированного текста
64
+ generated_text = generate_text(prompt)
65
+ st.subheader("Продолжение:")
66
+ st.write(generated_text)
67
+
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()
72