Spaces:
Runtime error
Runtime error
Update stri.py
Browse files
stri.py
CHANGED
@@ -17,10 +17,10 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
17 |
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
|
18 |
|
19 |
# Загрузка датасета и аннотаций к книгам
|
20 |
-
books = pd.read_csv('
|
21 |
books.dropna(inplace=True)
|
22 |
|
23 |
-
books = books[books['annotation'].apply(lambda x: len(x.split()) >=
|
24 |
books.drop_duplicates(subset='title', keep='first', inplace=True)
|
25 |
books = books.reset_index(drop=True)
|
26 |
|
@@ -39,7 +39,7 @@ for i in ['author', 'title', 'annotation']:
|
|
39 |
annot = books['annotation']
|
40 |
|
41 |
# Получение эмбеддингов аннотаций каждой книги в датасете
|
42 |
-
max_len =
|
43 |
|
44 |
# Определение запроса пользователя
|
45 |
query = st.text_input("Введите запрос")
|
@@ -58,9 +58,11 @@ if st.button('Сгенерировать'):
|
|
58 |
query_padded = torch.tensor(query_padded, dtype=torch.long)
|
59 |
query_mask = torch.tensor(query_mask, dtype=torch.long)
|
60 |
|
61 |
-
with torch.
|
62 |
query_embedding = model(query_padded.unsqueeze(0), query_mask.unsqueeze(0))
|
63 |
-
query_embedding = query_embedding[0][:,
|
|
|
|
|
64 |
|
65 |
# Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
|
66 |
cosine_similarities = torch.nn.functional.cosine_similarity(
|
@@ -83,4 +85,5 @@ if st.button('Сгенерировать'):
|
|
83 |
response = requests.get(image_url)
|
84 |
image = Image.open(BytesIO(response.content))
|
85 |
cols[0].image(image)
|
|
|
86 |
cols[1].write("---")
|
|
|
17 |
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
|
18 |
|
19 |
# Загрузка датасета и аннотаций к книгам
|
20 |
+
books = pd.read_csv('all+.csv')
|
21 |
books.dropna(inplace=True)
|
22 |
|
23 |
+
books = books[books['annotation'].apply(lambda x: len(x.split()) >= 40)]
|
24 |
books.drop_duplicates(subset='title', keep='first', inplace=True)
|
25 |
books = books.reset_index(drop=True)
|
26 |
|
|
|
39 |
annot = books['annotation']
|
40 |
|
41 |
# Получение эмбеддингов аннотаций каждой книги в датасете
|
42 |
+
max_len = 256
|
43 |
|
44 |
# Определение запроса пользователя
|
45 |
query = st.text_input("Введите запрос")
|
|
|
58 |
query_padded = torch.tensor(query_padded, dtype=torch.long)
|
59 |
query_mask = torch.tensor(query_mask, dtype=torch.long)
|
60 |
|
61 |
+
with torch.inference_mode():
|
62 |
query_embedding = model(query_padded.unsqueeze(0), query_mask.unsqueeze(0))
|
63 |
+
query_embedding = query_embedding[0][:,0,:]
|
64 |
+
query_embedding = torch.nn.functional.normalize(query_embedding)
|
65 |
+
|
66 |
|
67 |
# Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
|
68 |
cosine_similarities = torch.nn.functional.cosine_similarity(
|
|
|
85 |
response = requests.get(image_url)
|
86 |
image = Image.open(BytesIO(response.content))
|
87 |
cols[0].image(image)
|
88 |
+
cols[0].write(cosine_similarities[i]:.2f)
|
89 |
cols[1].write("---")
|