MARI-posa commited on
Commit
77204c0
·
1 Parent(s): 0dc6cc7

Upload stri.py

Browse files
Files changed (1) hide show
  1. pages/stri.py +74 -0
pages/stri.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from PIL import Image
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import re
8
+ import pickle
9
+ import requests
10
+ from io import BytesIO
11
+
12
+ st.title("Книжные рекомендации")
13
+
14
+ # Загрузка модели и токенизатора
15
+ model_name = "cointegrated/rubert-tiny2"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
18
+
19
+ # Загрузка датасета и аннотаций к книгам
20
+ books = pd.read_csv('all+++.csv')
21
+ books['author'].fillna('other', inplace=True)
22
+
23
+ annot = books['annotation']
24
+
25
+ # Получение эмбеддингов аннотаций каждой книги в датасете
26
+ length = 256
27
+
28
+ # Определение запроса пользователя
29
+ query = st.text_input("Введите запрос")
30
+
31
+ num_books_per_page = st.selectbox("Количество книг на странице:", [3, 5, 10], index=0)
32
+
33
+ col1, col2 = st.columns(2)
34
+ generate_button = col1.button('Сгенерировать')
35
+
36
+ if generate_button:
37
+ with open("book_embeddings256xxx.pkl", "rb") as f:
38
+ book_embeddings = pickle.load(f)
39
+
40
+ query_tokens = tokenizer.encode_plus(
41
+ query,
42
+ add_special_tokens=True,
43
+ max_length=length, # Ограничение на максимальную длину входной последовательности
44
+ pad_to_max_length=True, # Дополним последовательность нулями до максимальной длины
45
+ return_tensors='pt' # Вернём тензоры PyTorch
46
+ )
47
+
48
+ with torch.no_grad():
49
+ query_outputs = model(**query_tokens)
50
+ query_hidden_states = query_outputs.hidden_states[-1][:, 0, :]
51
+ query_hidden_states = torch.nn.functional.normalize(query_hidden_states)
52
+
53
+ # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
54
+ cosine_similarities = torch.nn.functional.cosine_similarity(
55
+ query_hidden_states.squeeze(0),
56
+ torch.stack(book_embeddings)
57
+ )
58
+
59
+ cosine_similarities = cosine_similarities.numpy()
60
+
61
+ indices = np.argsort(cosine_similarities)[::-1] # Сортировка по убыванию
62
+
63
+ for i in indices[:num_books_per_page]:
64
+ cols = st.columns(2) # Создание двух столбцов для размещения информации и изображения
65
+ cols[1].write("## " + books['title'][i])
66
+ cols[1].markdown("**Автор:** " + books['author'][i])
67
+ cols[1].markdown("**Аннотация:** " + books['annotation'][i])
68
+ image_url = books['image_url'][i]
69
+
70
+ response = requests.get(image_url)
71
+ image = Image.open(BytesIO(response.content))
72
+ cols[0].image(image)
73
+ cols[0].write(cosine_similarities[i])
74
+ cols[1].write("---")