MARI-posa commited on
Commit
ff099dc
1 Parent(s): 8c11da5

Update stri.py

Browse files
Files changed (1) hide show
  1. stri.py +14 -15
stri.py CHANGED
@@ -45,23 +45,22 @@ max_len = 256
45
  query = st.text_input("Введите запрос")
46
 
47
  if st.button('Сгенерировать'):
48
- with open("book_embeddings.pkl", "rb") as f:
49
  book_embeddings = pickle.load(f)
50
 
51
- query_tokens = tokenizer.encode(query, add_special_tokens=True,
52
- truncation=True, max_length=max_len)
53
-
54
- query_padded = np.array(query_tokens + [0] * (max_len - len(query_tokens)))
55
- query_mask = np.where(query_padded != 0, 1, 0)
56
-
57
- # Переведем numpy массивы в тензоры PyTorch
58
- query_padded = torch.tensor(query_padded, dtype=torch.long)
59
- query_mask = torch.tensor(query_mask, dtype=torch.long)
60
-
61
- with torch.inference_mode():
62
- query_embedding = model(query_padded.unsqueeze(0), query_mask.unsqueeze(0))
63
- query_embedding = query_embedding[0][:,0,:]
64
- query_embedding = torch.nn.functional.normalize(query_embedding)
65
 
66
 
67
  # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
 
45
  query = st.text_input("Введите запрос")
46
 
47
  if st.button('Сгенерировать'):
48
+ with open("book_embeddings512.pkl", "rb") as f:
49
  book_embeddings = pickle.load(f)
50
 
51
+ query_tokens = tokenizer.encode_plus(
52
+ query,
53
+ add_special_tokens=True,
54
+ max_length=length, # Ограничение на максимальную длину входной последовательности
55
+ pad_to_max_length=True, # Дополним последовательность нулями до максимальной длины
56
+ return_tensors='pt' # Вернём тензоры PyTorch
57
+ )
58
+
59
+ with torch.no_grad():
60
+ query_outputs = model(**query_tokens)
61
+ query_hidden_states = query_outputs.hidden_states
62
+ query_last_hidden_state = query_hidden_states[-1] # Используем предпоследний слой для эмбеддинга
63
+ query_embedding = torch.mean(query_last_hidden_state, dim=1).squeeze()
 
64
 
65
 
66
  # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией