ACMCMC commited on
Commit
b0c103a
1 Parent(s): 646d392

restore file

Browse files
Files changed (3) hide show
  1. Intersystems.ipynb +1782 -0
  2. database.ipynb +6 -6
  3. graph_visualization.mlapp +0 -0
Intersystems.ipynb ADDED
@@ -0,0 +1,1782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "9db3b813-22dc-4209-86d2-42e935f5f5dd",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from langchain_community.document_loaders.csv_loader import CSVLoader\n",
11
+ "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
12
+ "import pandas as pd\n",
13
+ "import langchain\n",
14
+ "import os\n",
15
+ "import openai\n",
16
+ "import ast\n",
17
+ "from langchain import OpenAI\n",
18
+ "from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain\n",
19
+ "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
20
+ "from langchain.document_loaders import UnstructuredURLLoader\n",
21
+ "from langchain.embeddings import OpenAIEmbeddings\n",
22
+ "from langchain.vectorstores import FAISS"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 9,
28
+ "id": "fb48dccd-37e5-484a-a2fc-c482839b9ed9",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# loader = CSVLoader(file_path=\"trials/brief_summaries.csv\")\n",
33
+ "# data = loader.load()"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 10,
39
+ "id": "cdcc3107-e6a8-47ca-bd89-e381fbcf9b9e",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "# df= pd.read_csv(\"trials/brief_summaries.txt\", delimiter=\"|\")\n",
44
+ "# df.shape"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 11,
50
+ "id": "c6a94c14-b197-4eff-9941-1ca52069cd5c",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "# df.head(20)"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 2,
60
+ "id": "95c000ff-bf0c-4489-8643-93a238db41dc",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "os.environ['OPENAI_API_KEY']=\"sk-proj-CG2E98bSWs53X2eWO0Z4T3BlbkFJLm7H1vfkbua0zP548CKQ\""
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 3,
70
+ "id": "e2e03936-fcce-4287-bfe6-e31f1b69f693",
71
+ "metadata": {},
72
+ "outputs": [
73
+ {
74
+ "name": "stderr",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "/Users/aldan.creo/miniconda3/envs/hackupc/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `OpenAI` was deprecated in LangChain 0.0.10 and will be removed in 0.2.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import OpenAI`.\n",
78
+ " warn_deprecated(\n"
79
+ ]
80
+ }
81
+ ],
82
+ "source": [
83
+ "llm= OpenAI(\n",
84
+ " temperature=0.6, max_tokens=500\n",
85
+ ")"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 14,
91
+ "id": "aede3d75-2441-44f6-b3cf-2f86f050da24",
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "name": "stdout",
96
+ "output_type": "stream",
97
+ "text": [
98
+ "(440517, 2)\n"
99
+ ]
100
+ },
101
+ {
102
+ "data": {
103
+ "text/html": [
104
+ "<div>\n",
105
+ "<style scoped>\n",
106
+ " .dataframe tbody tr th:only-of-type {\n",
107
+ " vertical-align: middle;\n",
108
+ " }\n",
109
+ "\n",
110
+ " .dataframe tbody tr th {\n",
111
+ " vertical-align: top;\n",
112
+ " }\n",
113
+ "\n",
114
+ " .dataframe thead th {\n",
115
+ " text-align: right;\n",
116
+ " }\n",
117
+ "</style>\n",
118
+ "<table border=\"1\" class=\"dataframe\">\n",
119
+ " <thead>\n",
120
+ " <tr style=\"text-align: right;\">\n",
121
+ " <th></th>\n",
122
+ " <th>desease_condition</th>\n",
123
+ " <th>text</th>\n",
124
+ " </tr>\n",
125
+ " </thead>\n",
126
+ " <tbody>\n",
127
+ " <tr>\n",
128
+ " <th>0</th>\n",
129
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
130
+ " <td>nct_id: NCT03055377\\nsummary: This is a 12-wee...</td>\n",
131
+ " </tr>\n",
132
+ " <tr>\n",
133
+ " <th>1</th>\n",
134
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
135
+ " <td>nct_id: NCT03055377\\nsummary: This is a 12-wee...</td>\n",
136
+ " </tr>\n",
137
+ " <tr>\n",
138
+ " <th>2</th>\n",
139
+ " <td>['tuberculosis', 'latent tuberculosis', 'infec...</td>\n",
140
+ " <td>nct_id: NCT03042754\\nsummary: Early diagnosis ...</td>\n",
141
+ " </tr>\n",
142
+ " <tr>\n",
143
+ " <th>3</th>\n",
144
+ " <td>['heart failure', 'heart diseases', 'cardiovas...</td>\n",
145
+ " <td>nct_id: NCT03035123\\nsummary: The EduStra-HF s...</td>\n",
146
+ " </tr>\n",
147
+ " <tr>\n",
148
+ " <th>4</th>\n",
149
+ " <td>['lymphoma', 'neoplasms by histologic type', '...</td>\n",
150
+ " <td>nct_id: NCT02272751\\nsummary: This study will ...</td>\n",
151
+ " </tr>\n",
152
+ " </tbody>\n",
153
+ "</table>\n",
154
+ "</div>"
155
+ ],
156
+ "text/plain": [
157
+ " desease_condition \\\n",
158
+ "0 ['marijuana abuse', 'substance-related disorde... \n",
159
+ "1 ['marijuana abuse', 'substance-related disorde... \n",
160
+ "2 ['tuberculosis', 'latent tuberculosis', 'infec... \n",
161
+ "3 ['heart failure', 'heart diseases', 'cardiovas... \n",
162
+ "4 ['lymphoma', 'neoplasms by histologic type', '... \n",
163
+ "\n",
164
+ " text \n",
165
+ "0 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
166
+ "1 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
167
+ "2 nct_id: NCT03042754\\nsummary: Early diagnosis ... \n",
168
+ "3 nct_id: NCT03035123\\nsummary: The EduStra-HF s... \n",
169
+ "4 nct_id: NCT02272751\\nsummary: This study will ... "
170
+ ]
171
+ },
172
+ "execution_count": 14,
173
+ "metadata": {},
174
+ "output_type": "execute_result"
175
+ }
176
+ ],
177
+ "source": [
178
+ "df_trials= pd.read_csv(\"clinical_trials.csv\")\n",
179
+ "print(df_trials.shape)\n",
180
+ "df_trials.head()"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 15,
186
+ "id": "65589fc4-4f55-4b72-9935-c3b163fde0e2",
187
+ "metadata": {},
188
+ "outputs": [
189
+ {
190
+ "data": {
191
+ "text/html": [
192
+ "<div>\n",
193
+ "<style scoped>\n",
194
+ " .dataframe tbody tr th:only-of-type {\n",
195
+ " vertical-align: middle;\n",
196
+ " }\n",
197
+ "\n",
198
+ " .dataframe tbody tr th {\n",
199
+ " vertical-align: top;\n",
200
+ " }\n",
201
+ "\n",
202
+ " .dataframe thead th {\n",
203
+ " text-align: right;\n",
204
+ " }\n",
205
+ "</style>\n",
206
+ "<table border=\"1\" class=\"dataframe\">\n",
207
+ " <thead>\n",
208
+ " <tr style=\"text-align: right;\">\n",
209
+ " <th></th>\n",
210
+ " <th>desease_condition</th>\n",
211
+ " </tr>\n",
212
+ " </thead>\n",
213
+ " <tbody>\n",
214
+ " <tr>\n",
215
+ " <th>0</th>\n",
216
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <th>1</th>\n",
220
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
221
+ " </tr>\n",
222
+ " <tr>\n",
223
+ " <th>2</th>\n",
224
+ " <td>['tuberculosis', 'latent tuberculosis', 'infec...</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <th>3</th>\n",
228
+ " <td>['heart failure', 'heart diseases', 'cardiovas...</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <th>4</th>\n",
232
+ " <td>['lymphoma', 'neoplasms by histologic type', '...</td>\n",
233
+ " </tr>\n",
234
+ " </tbody>\n",
235
+ "</table>\n",
236
+ "</div>"
237
+ ],
238
+ "text/plain": [
239
+ " desease_condition\n",
240
+ "0 ['marijuana abuse', 'substance-related disorde...\n",
241
+ "1 ['marijuana abuse', 'substance-related disorde...\n",
242
+ "2 ['tuberculosis', 'latent tuberculosis', 'infec...\n",
243
+ "3 ['heart failure', 'heart diseases', 'cardiovas...\n",
244
+ "4 ['lymphoma', 'neoplasms by histologic type', '..."
245
+ ]
246
+ },
247
+ "execution_count": 15,
248
+ "metadata": {},
249
+ "output_type": "execute_result"
250
+ }
251
+ ],
252
+ "source": [
253
+ "df_trials_filtered=df_trials[['desease_condition']]\n",
254
+ "df_trials_filtered.head()"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 16,
260
+ "id": "88e28056-e340-416c-a1a9-4a6c29556dc7",
261
+ "metadata": {},
262
+ "outputs": [
263
+ {
264
+ "data": {
265
+ "text/plain": [
266
+ "\"['marijuana abuse', 'substance-related disorders', 'chemically-induced disorders', 'mental disorders']\""
267
+ ]
268
+ },
269
+ "execution_count": 16,
270
+ "metadata": {},
271
+ "output_type": "execute_result"
272
+ }
273
+ ],
274
+ "source": [
275
+ "df_trials_filtered['desease_condition'].iloc[0]"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": 17,
281
+ "id": "5bd8f876-0480-40a5-a32f-ca7ec137a70f",
282
+ "metadata": {},
283
+ "outputs": [
284
+ {
285
+ "name": "stderr",
286
+ "output_type": "stream",
287
+ "text": [
288
+ "C:\\Users\\ariji\\AppData\\Local\\Temp\\ipykernel_22340\\16068817.py:4: SettingWithCopyWarning: \n",
289
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
290
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
291
+ "\n",
292
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
293
+ " df_trials_filtered['desease_condition']= df_trials_filtered['desease_condition'].apply(list_to_string)\n"
294
+ ]
295
+ },
296
+ {
297
+ "data": {
298
+ "text/html": [
299
+ "<div>\n",
300
+ "<style scoped>\n",
301
+ " .dataframe tbody tr th:only-of-type {\n",
302
+ " vertical-align: middle;\n",
303
+ " }\n",
304
+ "\n",
305
+ " .dataframe tbody tr th {\n",
306
+ " vertical-align: top;\n",
307
+ " }\n",
308
+ "\n",
309
+ " .dataframe thead th {\n",
310
+ " text-align: right;\n",
311
+ " }\n",
312
+ "</style>\n",
313
+ "<table border=\"1\" class=\"dataframe\">\n",
314
+ " <thead>\n",
315
+ " <tr style=\"text-align: right;\">\n",
316
+ " <th></th>\n",
317
+ " <th>desease_condition</th>\n",
318
+ " </tr>\n",
319
+ " </thead>\n",
320
+ " <tbody>\n",
321
+ " <tr>\n",
322
+ " <th>0</th>\n",
323
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
324
+ " </tr>\n",
325
+ " <tr>\n",
326
+ " <th>1</th>\n",
327
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
328
+ " </tr>\n",
329
+ " <tr>\n",
330
+ " <th>2</th>\n",
331
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
332
+ " </tr>\n",
333
+ " <tr>\n",
334
+ " <th>3</th>\n",
335
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
336
+ " </tr>\n",
337
+ " <tr>\n",
338
+ " <th>4</th>\n",
339
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
340
+ " </tr>\n",
341
+ " <tr>\n",
342
+ " <th>...</th>\n",
343
+ " <td>...</td>\n",
344
+ " </tr>\n",
345
+ " <tr>\n",
346
+ " <th>440512</th>\n",
347
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
348
+ " </tr>\n",
349
+ " <tr>\n",
350
+ " <th>440513</th>\n",
351
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <th>440514</th>\n",
355
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
356
+ " </tr>\n",
357
+ " <tr>\n",
358
+ " <th>440515</th>\n",
359
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
360
+ " </tr>\n",
361
+ " <tr>\n",
362
+ " <th>440516</th>\n",
363
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
364
+ " </tr>\n",
365
+ " </tbody>\n",
366
+ "</table>\n",
367
+ "<p>440517 rows × 1 columns</p>\n",
368
+ "</div>"
369
+ ],
370
+ "text/plain": [
371
+ " desease_condition\n",
372
+ "0 marijuana abuse, substance-related disorders, ...\n",
373
+ "1 marijuana abuse, substance-related disorders, ...\n",
374
+ "2 tuberculosis, latent tuberculosis, infections,...\n",
375
+ "3 heart failure, heart diseases, cardiovascular ...\n",
376
+ "4 lymphoma, neoplasms by histologic type, neopla...\n",
377
+ "... ...\n",
378
+ "440512 obesity, overweight, overnutrition, nutrition ...\n",
379
+ "440513 obesity, overweight, overnutrition, nutrition ...\n",
380
+ "440514 obesity, overweight, overnutrition, nutrition ...\n",
381
+ "440515 autistic disorder, autism spectrum disorder, c...\n",
382
+ "440516 autistic disorder, autism spectrum disorder, c...\n",
383
+ "\n",
384
+ "[440517 rows x 1 columns]"
385
+ ]
386
+ },
387
+ "execution_count": 17,
388
+ "metadata": {},
389
+ "output_type": "execute_result"
390
+ }
391
+ ],
392
+ "source": [
393
+ "def list_to_string(disease_list):\n",
394
+ " disease_list= ast.literal_eval(disease_list)\n",
395
+ " return ', '.join(disease_list)\n",
396
+ "df_trials_filtered['desease_condition']= df_trials_filtered['desease_condition'].apply(list_to_string)\n",
397
+ "df_trials_filtered"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 18,
403
+ "id": "bbbb22e6-4883-4869-8ccc-95696bc67b1b",
404
+ "metadata": {},
405
+ "outputs": [
406
+ {
407
+ "data": {
408
+ "text/plain": [
409
+ "0 marijuana abuse, substance-related disorders, ...\n",
410
+ "1 marijuana abuse, substance-related disorders, ...\n",
411
+ "2 tuberculosis, latent tuberculosis, infections,...\n",
412
+ "3 heart failure, heart diseases, cardiovascular ...\n",
413
+ "4 lymphoma, neoplasms by histologic type, neopla...\n",
414
+ "Name: desease_condition, dtype: object"
415
+ ]
416
+ },
417
+ "execution_count": 18,
418
+ "metadata": {},
419
+ "output_type": "execute_result"
420
+ }
421
+ ],
422
+ "source": [
423
+ "df_trials_filtered['desease_condition'].head()"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": 19,
429
+ "id": "7f4e7ceb-8bfd-4294-a850-8935f88b6555",
430
+ "metadata": {},
431
+ "outputs": [],
432
+ "source": [
433
+ "df_trials_filtered.to_csv(\"diseases.csv\", index=False)"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": 20,
439
+ "id": "af1c5c2b-24a0-44a1-9e5d-7ee89ca4cccf",
440
+ "metadata": {},
441
+ "outputs": [
442
+ {
443
+ "data": {
444
+ "text/plain": [
445
+ "440517"
446
+ ]
447
+ },
448
+ "execution_count": 20,
449
+ "metadata": {},
450
+ "output_type": "execute_result"
451
+ }
452
+ ],
453
+ "source": [
454
+ "loader= CSVLoader(file_path=\"./diseases.csv\", encoding=\"utf-8\")\n",
455
+ "data = loader.load()\n",
456
+ "len(data)"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "id": "cab89218-41ca-4048-886d-bc2c1c9b30bc",
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": [
466
+ "embeddings = OpenAIEmbeddings()\n",
467
+ "vectorstore = FAISS.from_documents(data, embeddings)"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": null,
473
+ "id": "225ade4a-d004-44cc-a5ff-22ce2bfcac32",
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": [
477
+ "file_path= \"vector_index.pkl\"\n",
478
+ "with open(file_path, \"wb\") as f:\n",
479
+ " pickle.dump(vectorstore, f)"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": 98,
485
+ "id": "11912a93-ad02-41cb-8bce-2750c947fa74",
486
+ "metadata": {},
487
+ "outputs": [
488
+ {
489
+ "name": "stdout",
490
+ "output_type": "stream",
491
+ "text": [
492
+ "(440517, 2)\n"
493
+ ]
494
+ },
495
+ {
496
+ "data": {
497
+ "text/html": [
498
+ "<div>\n",
499
+ "<style scoped>\n",
500
+ " .dataframe tbody tr th:only-of-type {\n",
501
+ " vertical-align: middle;\n",
502
+ " }\n",
503
+ "\n",
504
+ " .dataframe tbody tr th {\n",
505
+ " vertical-align: top;\n",
506
+ " }\n",
507
+ "\n",
508
+ " .dataframe thead th {\n",
509
+ " text-align: right;\n",
510
+ " }\n",
511
+ "</style>\n",
512
+ "<table border=\"1\" class=\"dataframe\">\n",
513
+ " <thead>\n",
514
+ " <tr style=\"text-align: right;\">\n",
515
+ " <th></th>\n",
516
+ " <th>desease_condition</th>\n",
517
+ " <th>text</th>\n",
518
+ " </tr>\n",
519
+ " </thead>\n",
520
+ " <tbody>\n",
521
+ " <tr>\n",
522
+ " <th>0</th>\n",
523
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
524
+ " <td>nct_id: NCT03055377\\nsummary: This is a 12-wee...</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <th>1</th>\n",
528
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
529
+ " <td>nct_id: NCT03055377\\nsummary: This is a 12-wee...</td>\n",
530
+ " </tr>\n",
531
+ " <tr>\n",
532
+ " <th>2</th>\n",
533
+ " <td>['tuberculosis', 'latent tuberculosis', 'infec...</td>\n",
534
+ " <td>nct_id: NCT03042754\\nsummary: Early diagnosis ...</td>\n",
535
+ " </tr>\n",
536
+ " <tr>\n",
537
+ " <th>3</th>\n",
538
+ " <td>['heart failure', 'heart diseases', 'cardiovas...</td>\n",
539
+ " <td>nct_id: NCT03035123\\nsummary: The EduStra-HF s...</td>\n",
540
+ " </tr>\n",
541
+ " <tr>\n",
542
+ " <th>4</th>\n",
543
+ " <td>['lymphoma', 'neoplasms by histologic type', '...</td>\n",
544
+ " <td>nct_id: NCT02272751\\nsummary: This study will ...</td>\n",
545
+ " </tr>\n",
546
+ " </tbody>\n",
547
+ "</table>\n",
548
+ "</div>"
549
+ ],
550
+ "text/plain": [
551
+ " desease_condition \\\n",
552
+ "0 ['marijuana abuse', 'substance-related disorde... \n",
553
+ "1 ['marijuana abuse', 'substance-related disorde... \n",
554
+ "2 ['tuberculosis', 'latent tuberculosis', 'infec... \n",
555
+ "3 ['heart failure', 'heart diseases', 'cardiovas... \n",
556
+ "4 ['lymphoma', 'neoplasms by histologic type', '... \n",
557
+ "\n",
558
+ " text \n",
559
+ "0 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
560
+ "1 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
561
+ "2 nct_id: NCT03042754\\nsummary: Early diagnosis ... \n",
562
+ "3 nct_id: NCT03035123\\nsummary: The EduStra-HF s... \n",
563
+ "4 nct_id: NCT02272751\\nsummary: This study will ... "
564
+ ]
565
+ },
566
+ "execution_count": 98,
567
+ "metadata": {},
568
+ "output_type": "execute_result"
569
+ }
570
+ ],
571
+ "source": [
572
+ "df_trials= pd.read_csv(\"clinical_trials.csv\")\n",
573
+ "print(df_trials.shape)\n",
574
+ "df_trials.head()"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "code",
579
+ "execution_count": 99,
580
+ "id": "a8113876-f38e-4c7a-891e-17dc51e2bacf",
581
+ "metadata": {},
582
+ "outputs": [
583
+ {
584
+ "data": {
585
+ "text/plain": [
586
+ "\"nct_id: NCT03055377\\nsummary: This is a 12-week randomized, placebo-controlled trial of N-acetylcysteine for cannabis use disorder (CUD) in youth (N=192). Participants will be randomized to double-blind NAC or PBO, yielding two equally-allocated treatment groups. All participants will receive brief weekly cannabis cessation counseling and medication management. The primary efficacy outcome will be the proportion of negative urine cannabinoid tests during the 12-week active treatment, compared between groups.\\nintervention_type: Drug\\nintervention_name: N-acetyl cysteine\\nintervention_description: N-acetylcysteine 1200 mg twice daily for 12 weeks (administered orally)\\nkeywords: ['cannabis', 'marijuana', 'youth', 'adolescent', 'pharmacotherapy', 'medication', 'n-acetylcysteine', 'trial']\""
587
+ ]
588
+ },
589
+ "execution_count": 99,
590
+ "metadata": {},
591
+ "output_type": "execute_result"
592
+ }
593
+ ],
594
+ "source": [
595
+ "df_trials.iloc[0].text"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "code",
600
+ "execution_count": 100,
601
+ "id": "c9ace03c-cf71-4605-adeb-9c6ee10e158a",
602
+ "metadata": {},
603
+ "outputs": [
604
+ {
605
+ "data": {
606
+ "text/plain": [
607
+ "(1000, 2)"
608
+ ]
609
+ },
610
+ "execution_count": 100,
611
+ "metadata": {},
612
+ "output_type": "execute_result"
613
+ }
614
+ ],
615
+ "source": [
616
+ "df= df_trials[:1000]\n",
617
+ "df.shape"
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": null,
623
+ "id": "85a985c5-845b-464e-9527-e806b6883741",
624
+ "metadata": {},
625
+ "outputs": [],
626
+ "source": [
627
+ "embeddings= OpenAIEmbeddings()\n",
628
+ "vectorindex_openapi= FAISS.from"
629
+ ]
630
+ },
631
+ {
632
+ "cell_type": "code",
633
+ "execution_count": null,
634
+ "id": "25b31e55-2961-474d-92d8-5963f2c6bf84",
635
+ "metadata": {},
636
+ "outputs": [],
637
+ "source": []
638
+ },
639
+ {
640
+ "cell_type": "code",
641
+ "execution_count": null,
642
+ "id": "918c078c-46fe-4d7b-9748-88c52a5b004a",
643
+ "metadata": {},
644
+ "outputs": [],
645
+ "source": []
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "execution_count": null,
650
+ "id": "36e83202-97ad-425d-95ae-075a1e26a34e",
651
+ "metadata": {},
652
+ "outputs": [],
653
+ "source": []
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": null,
658
+ "id": "5d705875-5dd7-4c71-8d94-99c101020ac0",
659
+ "metadata": {},
660
+ "outputs": [],
661
+ "source": []
662
+ },
663
+ {
664
+ "cell_type": "code",
665
+ "execution_count": null,
666
+ "id": "50c15b4d-9f65-4385-9aa6-3f782bf57775",
667
+ "metadata": {},
668
+ "outputs": [],
669
+ "source": []
670
+ },
671
+ {
672
+ "cell_type": "code",
673
+ "execution_count": 6,
674
+ "id": "f2818abe-1a43-4d7d-92a7-7562812bf43d",
675
+ "metadata": {},
676
+ "outputs": [
677
+ {
678
+ "data": {
679
+ "text/html": [
680
+ "<div>\n",
681
+ "<style scoped>\n",
682
+ " .dataframe tbody tr th:only-of-type {\n",
683
+ " vertical-align: middle;\n",
684
+ " }\n",
685
+ "\n",
686
+ " .dataframe tbody tr th {\n",
687
+ " vertical-align: top;\n",
688
+ " }\n",
689
+ "\n",
690
+ " .dataframe thead th {\n",
691
+ " text-align: right;\n",
692
+ " }\n",
693
+ "</style>\n",
694
+ "<table border=\"1\" class=\"dataframe\">\n",
695
+ " <thead>\n",
696
+ " <tr style=\"text-align: right;\">\n",
697
+ " <th></th>\n",
698
+ " <th>desease_condition</th>\n",
699
+ " </tr>\n",
700
+ " </thead>\n",
701
+ " <tbody>\n",
702
+ " <tr>\n",
703
+ " <th>0</th>\n",
704
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
705
+ " </tr>\n",
706
+ " <tr>\n",
707
+ " <th>1</th>\n",
708
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
709
+ " </tr>\n",
710
+ " <tr>\n",
711
+ " <th>2</th>\n",
712
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
713
+ " </tr>\n",
714
+ " <tr>\n",
715
+ " <th>3</th>\n",
716
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
717
+ " </tr>\n",
718
+ " <tr>\n",
719
+ " <th>4</th>\n",
720
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
721
+ " </tr>\n",
722
+ " <tr>\n",
723
+ " <th>...</th>\n",
724
+ " <td>...</td>\n",
725
+ " </tr>\n",
726
+ " <tr>\n",
727
+ " <th>440512</th>\n",
728
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
729
+ " </tr>\n",
730
+ " <tr>\n",
731
+ " <th>440513</th>\n",
732
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
733
+ " </tr>\n",
734
+ " <tr>\n",
735
+ " <th>440514</th>\n",
736
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
737
+ " </tr>\n",
738
+ " <tr>\n",
739
+ " <th>440515</th>\n",
740
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
741
+ " </tr>\n",
742
+ " <tr>\n",
743
+ " <th>440516</th>\n",
744
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
745
+ " </tr>\n",
746
+ " </tbody>\n",
747
+ "</table>\n",
748
+ "<p>440517 rows × 1 columns</p>\n",
749
+ "</div>"
750
+ ],
751
+ "text/plain": [
752
+ " desease_condition\n",
753
+ "0 marijuana abuse, substance-related disorders, ...\n",
754
+ "1 marijuana abuse, substance-related disorders, ...\n",
755
+ "2 tuberculosis, latent tuberculosis, infections,...\n",
756
+ "3 heart failure, heart diseases, cardiovascular ...\n",
757
+ "4 lymphoma, neoplasms by histologic type, neopla...\n",
758
+ "... ...\n",
759
+ "440512 obesity, overweight, overnutrition, nutrition ...\n",
760
+ "440513 obesity, overweight, overnutrition, nutrition ...\n",
761
+ "440514 obesity, overweight, overnutrition, nutrition ...\n",
762
+ "440515 autistic disorder, autism spectrum disorder, c...\n",
763
+ "440516 autistic disorder, autism spectrum disorder, c...\n",
764
+ "\n",
765
+ "[440517 rows x 1 columns]"
766
+ ]
767
+ },
768
+ "execution_count": 6,
769
+ "metadata": {},
770
+ "output_type": "execute_result"
771
+ }
772
+ ],
773
+ "source": [
774
+ "df_trials_filtered = pd.read_csv(\"diseases.csv\")\n",
775
+ "df_trials_filtered"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": 49,
781
+ "id": "c89e3cf6-a376-4029-9c04-0f5e664a2237",
782
+ "metadata": {},
783
+ "outputs": [
784
+ {
785
+ "data": {
786
+ "text/plain": [
787
+ "(100, 1)"
788
+ ]
789
+ },
790
+ "execution_count": 49,
791
+ "metadata": {},
792
+ "output_type": "execute_result"
793
+ }
794
+ ],
795
+ "source": [
796
+ "df2= df_trials_filtered[:100]\n",
797
+ "df2.shape"
798
+ ]
799
+ },
800
+ {
801
+ "cell_type": "code",
802
+ "execution_count": 7,
803
+ "id": "c5012bcf-3e25-4f21-a29c-6bdbdafbb8c7",
804
+ "metadata": {},
805
+ "outputs": [],
806
+ "source": [
807
+ "from openai import OpenAI\n",
808
+ "client= OpenAI()"
809
+ ]
810
+ },
811
+ {
812
+ "cell_type": "code",
813
+ "execution_count": 8,
814
+ "id": "40a480bd-6754-40b6-870c-42d10ce9a960",
815
+ "metadata": {},
816
+ "outputs": [],
817
+ "source": [
818
+ "def get_embeddings(text):\n",
819
+ " response= client.embeddings.create(\n",
820
+ " input= text,\n",
821
+ " dimensions= 128,\n",
822
+ " model= \"text-embedding-3-small\"\n",
823
+ " )\n",
824
+ " return response.data[0].embedding"
825
+ ]
826
+ },
827
+ {
828
+ "cell_type": "code",
829
+ "execution_count": 9,
830
+ "id": "ef6d6b62-de0b-4bc6-a6eb-847ab8e99da5",
831
+ "metadata": {},
832
+ "outputs": [
833
+ {
834
+ "ename": "ModuleNotFoundError",
835
+ "evalue": "No module named 'sentence_transformers'",
836
+ "output_type": "error",
837
+ "traceback": [
838
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
839
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
840
+ "File \u001b[0;32m<timed exec>:1\u001b[0m\n",
841
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'sentence_transformers'"
842
+ ]
843
+ }
844
+ ],
845
+ "source": [
846
+ "%%time\n",
847
+ "from sentence_transformers import SentenceTransformer\n",
848
+ "\n",
849
+ "encoder= SentenceTransformer(\"all-mpnet-base-v2\")\n",
850
+ "vectors= encoder.encode(df2.desease_condition)\n",
851
+ "vectors.shape"
852
+ ]
853
+ },
854
+ {
855
+ "cell_type": "code",
856
+ "execution_count": 66,
857
+ "id": "7966d754-56d7-4555-a6c6-6a13772fb000",
858
+ "metadata": {},
859
+ "outputs": [
860
+ {
861
+ "name": "stdout",
862
+ "output_type": "stream",
863
+ "text": [
864
+ "CPU times: total: 62.5 ms\n",
865
+ "Wall time: 24.7 s\n"
866
+ ]
867
+ },
868
+ {
869
+ "name": "stderr",
870
+ "output_type": "stream",
871
+ "text": [
872
+ "<timed exec>:1: SettingWithCopyWarning: \n",
873
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
874
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
875
+ "\n",
876
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n"
877
+ ]
878
+ },
879
+ {
880
+ "data": {
881
+ "text/html": [
882
+ "<div>\n",
883
+ "<style scoped>\n",
884
+ " .dataframe tbody tr th:only-of-type {\n",
885
+ " vertical-align: middle;\n",
886
+ " }\n",
887
+ "\n",
888
+ " .dataframe tbody tr th {\n",
889
+ " vertical-align: top;\n",
890
+ " }\n",
891
+ "\n",
892
+ " .dataframe thead th {\n",
893
+ " text-align: right;\n",
894
+ " }\n",
895
+ "</style>\n",
896
+ "<table border=\"1\" class=\"dataframe\">\n",
897
+ " <thead>\n",
898
+ " <tr style=\"text-align: right;\">\n",
899
+ " <th></th>\n",
900
+ " <th>desease_condition</th>\n",
901
+ " <th>embeddings</th>\n",
902
+ " </tr>\n",
903
+ " </thead>\n",
904
+ " <tbody>\n",
905
+ " <tr>\n",
906
+ " <th>0</th>\n",
907
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
908
+ " <td>[-0.05811865255236626, -0.023393018171191216, ...</td>\n",
909
+ " </tr>\n",
910
+ " <tr>\n",
911
+ " <th>1</th>\n",
912
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
913
+ " <td>[-0.05811865255236626, -0.023393018171191216, ...</td>\n",
914
+ " </tr>\n",
915
+ " <tr>\n",
916
+ " <th>2</th>\n",
917
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
918
+ " <td>[-0.03460180386900902, -0.084668830037117, 0.2...</td>\n",
919
+ " </tr>\n",
920
+ " <tr>\n",
921
+ " <th>3</th>\n",
922
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
923
+ " <td>[-0.08236236125230789, -0.1235777735710144, 0....</td>\n",
924
+ " </tr>\n",
925
+ " <tr>\n",
926
+ " <th>4</th>\n",
927
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
928
+ " <td>[-0.1227850392460823, 0.07155642658472061, 0.1...</td>\n",
929
+ " </tr>\n",
930
+ " </tbody>\n",
931
+ "</table>\n",
932
+ "</div>"
933
+ ],
934
+ "text/plain": [
935
+ " desease_condition \\\n",
936
+ "0 marijuana abuse, substance-related disorders, ... \n",
937
+ "1 marijuana abuse, substance-related disorders, ... \n",
938
+ "2 tuberculosis, latent tuberculosis, infections,... \n",
939
+ "3 heart failure, heart diseases, cardiovascular ... \n",
940
+ "4 lymphoma, neoplasms by histologic type, neopla... \n",
941
+ "\n",
942
+ " embeddings \n",
943
+ "0 [-0.05811865255236626, -0.023393018171191216, ... \n",
944
+ "1 [-0.05811865255236626, -0.023393018171191216, ... \n",
945
+ "2 [-0.03460180386900902, -0.084668830037117, 0.2... \n",
946
+ "3 [-0.08236236125230789, -0.1235777735710144, 0.... \n",
947
+ "4 [-0.1227850392460823, 0.07155642658472061, 0.1... "
948
+ ]
949
+ },
950
+ "execution_count": 66,
951
+ "metadata": {},
952
+ "output_type": "execute_result"
953
+ }
954
+ ],
955
+ "source": [
956
+ "%%time\n",
957
+ "df2['embeddings']= df2['desease_condition'].apply(get_embeddings)\n",
958
+ "df2.head()"
959
+ ]
960
+ },
961
+ {
962
+ "cell_type": "code",
963
+ "execution_count": 65,
964
+ "id": "2711980a-d1c0-441e-ae9a-531500b7b7cd",
965
+ "metadata": {},
966
+ "outputs": [
967
+ {
968
+ "data": {
969
+ "text/plain": [
970
+ "(128,)"
971
+ ]
972
+ },
973
+ "execution_count": 65,
974
+ "metadata": {},
975
+ "output_type": "execute_result"
976
+ }
977
+ ],
978
+ "source": [
979
+ "import numpy as np\n",
980
+ "np.array(df2.iloc[0].embeddings).shape"
981
+ ]
982
+ },
983
+ {
984
+ "cell_type": "code",
985
+ "execution_count": 23,
986
+ "id": "13e40184-999d-4363-be44-18ba9fb6745a",
987
+ "metadata": {},
988
+ "outputs": [
989
+ {
990
+ "name": "stderr",
991
+ "output_type": "stream",
992
+ "text": [
993
+ "C:\\Users\\ariji\\Anaconda3\\envs\\python310\\lib\\site-packages\\huggingface_hub\\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
994
+ " warnings.warn(\n"
995
+ ]
996
+ },
997
+ {
998
+ "ename": "KeyboardInterrupt",
999
+ "evalue": "",
1000
+ "output_type": "error",
1001
+ "traceback": [
1002
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
1003
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1004
+ "Input \u001b[1;32mIn [23]\u001b[0m, in \u001b[0;36m<cell line: 4>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msentence_transformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SentenceTransformer\n\u001b[0;32m 3\u001b[0m encoder\u001b[38;5;241m=\u001b[39m SentenceTransformer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mall-mpnet-base-v2\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 4\u001b[0m vectors\u001b[38;5;241m=\u001b[39m \u001b[43mencoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf_trials_filtered\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdesease_condition\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m vectors\u001b[38;5;241m.\u001b[39mshape\n",
1005
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\sentence_transformers\\SentenceTransformer.py:371\u001b[0m, in \u001b[0;36mSentenceTransformer.encode\u001b[1;34m(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings)\u001b[0m\n\u001b[0;32m 368\u001b[0m features\u001b[38;5;241m.\u001b[39mupdate(extra_features)\n\u001b[0;32m 370\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m--> 371\u001b[0m out_features \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 372\u001b[0m out_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence_embedding\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m truncate_embeddings(\n\u001b[0;32m 373\u001b[0m out_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence_embedding\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtruncate_dim\n\u001b[0;32m 374\u001b[0m )\n\u001b[0;32m 376\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_value \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_embeddings\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
1006
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 138\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 139\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 140\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
1007
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1008
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\sentence_transformers\\models\\Transformer.py:98\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, features)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m features:\n\u001b[0;32m 96\u001b[0m trans_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m---> 98\u001b[0m output_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_model(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mtrans_features, return_dict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 99\u001b[0m output_tokens \u001b[38;5;241m=\u001b[39m output_states[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 101\u001b[0m features\u001b[38;5;241m.\u001b[39mupdate({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_embeddings\u001b[39m\u001b[38;5;124m\"\u001b[39m: output_tokens, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m: features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]})\n",
1009
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1010
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:548\u001b[0m, in \u001b[0;36mMPNetModel.forward\u001b[1;34m(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m 546\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[0;32m 547\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(input_ids\u001b[38;5;241m=\u001b[39minput_ids, position_ids\u001b[38;5;241m=\u001b[39mposition_ids, inputs_embeds\u001b[38;5;241m=\u001b[39minputs_embeds)\n\u001b[1;32m--> 548\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 549\u001b[0m \u001b[43m \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 550\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 551\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 552\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 553\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 554\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 555\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 556\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 557\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler(sequence_output) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
1011
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1012
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:338\u001b[0m, in \u001b[0;36mMPNetEncoder.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m 335\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states:\n\u001b[0;32m 336\u001b[0m all_hidden_states \u001b[38;5;241m=\u001b[39m all_hidden_states \u001b[38;5;241m+\u001b[39m (hidden_states,)\n\u001b[1;32m--> 338\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m layer_module(\n\u001b[0;32m 339\u001b[0m hidden_states,\n\u001b[0;32m 340\u001b[0m attention_mask,\n\u001b[0;32m 341\u001b[0m head_mask[i],\n\u001b[0;32m 342\u001b[0m position_bias,\n\u001b[0;32m 343\u001b[0m output_attentions\u001b[38;5;241m=\u001b[39moutput_attentions,\n\u001b[0;32m 344\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[0;32m 345\u001b[0m )\n\u001b[0;32m 346\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 348\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
1013
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1014
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:308\u001b[0m, in \u001b[0;36mMPNetLayer.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, position_bias, output_attentions, **kwargs)\u001b[0m\n\u001b[0;32m 305\u001b[0m outputs \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m1\u001b[39m:] \u001b[38;5;66;03m# add self attentions if we output attention weights\u001b[39;00m\n\u001b[0;32m 307\u001b[0m intermediate_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mintermediate(attention_output)\n\u001b[1;32m--> 308\u001b[0m layer_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput\u001b[49m\u001b[43m(\u001b[49m\u001b[43mintermediate_output\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_output\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 309\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (layer_output,) \u001b[38;5;241m+\u001b[39m outputs\n\u001b[0;32m 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
1015
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1016
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:275\u001b[0m, in \u001b[0;36mMPNetOutput.forward\u001b[1;34m(self, hidden_states, input_tensor)\u001b[0m\n\u001b[0;32m 274\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor, input_tensor: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m--> 275\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdense\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 276\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout(hidden_states)\n\u001b[0;32m 277\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLayerNorm(hidden_states \u001b[38;5;241m+\u001b[39m input_tensor)\n",
1017
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1018
+ "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
1019
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
1020
+ ]
1021
+ }
1022
+ ],
1023
+ "source": [
1024
+ "from sentence_transformers import SentenceTransformer\n",
1025
+ "\n",
1026
+ "encoder= SentenceTransformer(\"all-mpnet-base-v2\")\n",
1027
+ "vectors= encoder.encode(df2.desease_condition)\n",
1028
+ "vectors.shape"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "cell_type": "code",
1033
+ "execution_count": null,
1034
+ "id": "5f626382-75f1-4864-8574-bbcd183b926a",
1035
+ "metadata": {},
1036
+ "outputs": [],
1037
+ "source": []
1038
+ },
1039
+ {
1040
+ "cell_type": "code",
1041
+ "execution_count": null,
1042
+ "id": "4e064857-2480-4ac2-911a-1c5a2ffb62d6",
1043
+ "metadata": {},
1044
+ "outputs": [],
1045
+ "source": []
1046
+ },
1047
+ {
1048
+ "cell_type": "code",
1049
+ "execution_count": 102,
1050
+ "id": "c0a719a6-ee06-4f6d-a64b-8fc74c7afbca",
1051
+ "metadata": {},
1052
+ "outputs": [
1053
+ {
1054
+ "name": "stdout",
1055
+ "output_type": "stream",
1056
+ "text": [
1057
+ "[[ 0.01245328 0.07695839 0.01802037 ... 0.0093326 -0.03474615\n",
1058
+ " -0.02757339]\n",
1059
+ " [ 0.01846956 0.08282936 0.01921537 ... 0.01068991 -0.03402989\n",
1060
+ " -0.02075185]\n",
1061
+ " [ 0.01822413 -0.05593034 -0.00288358 ... 0.03525289 -0.05427228\n",
1062
+ " -0.03371295]\n",
1063
+ " ...\n",
1064
+ " [ 0.04031958 0.0040107 0.02156032 ... 0.01568209 -0.04320977\n",
1065
+ " -0.02990234]\n",
1066
+ " [ 0.0399131 0.00027251 0.02207735 ... 0.01440835 -0.04246744\n",
1067
+ " -0.02869584]\n",
1068
+ " [ 0.03773859 -0.00315346 0.0207725 ... 0.01205995 -0.04628598\n",
1069
+ " -0.02870333]]\n"
1070
+ ]
1071
+ }
1072
+ ],
1073
+ "source": [
1074
+ "print(vectors)"
1075
+ ]
1076
+ },
1077
+ {
1078
+ "cell_type": "code",
1079
+ "execution_count": 103,
1080
+ "id": "f21c08d5-596a-4fc4-ba72-2de8624e6c50",
1081
+ "metadata": {},
1082
+ "outputs": [
1083
+ {
1084
+ "data": {
1085
+ "text/plain": [
1086
+ "768"
1087
+ ]
1088
+ },
1089
+ "execution_count": 103,
1090
+ "metadata": {},
1091
+ "output_type": "execute_result"
1092
+ }
1093
+ ],
1094
+ "source": [
1095
+ "dim= vectors.shape[1]\n",
1096
+ "dim "
1097
+ ]
1098
+ },
1099
+ {
1100
+ "cell_type": "code",
1101
+ "execution_count": 106,
1102
+ "id": "8ed985f4-9402-431f-bfba-1236ba16b895",
1103
+ "metadata": {},
1104
+ "outputs": [
1105
+ {
1106
+ "data": {
1107
+ "text/plain": [
1108
+ "<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x0000023493347480> >"
1109
+ ]
1110
+ },
1111
+ "execution_count": 106,
1112
+ "metadata": {},
1113
+ "output_type": "execute_result"
1114
+ }
1115
+ ],
1116
+ "source": [
1117
+ "import faiss\n",
1118
+ "\n",
1119
+ "index= faiss.IndexFlatL2(dim)\n",
1120
+ "index"
1121
+ ]
1122
+ },
1123
+ {
1124
+ "cell_type": "code",
1125
+ "execution_count": 107,
1126
+ "id": "c0f1dde8-c173-45ba-957b-f4622fe6f0ee",
1127
+ "metadata": {},
1128
+ "outputs": [],
1129
+ "source": [
1130
+ "index.add(vectors)"
1131
+ ]
1132
+ },
1133
+ {
1134
+ "cell_type": "code",
1135
+ "execution_count": 108,
1136
+ "id": "71dad860-1f19-4166-a309-c9ce15f24792",
1137
+ "metadata": {
1138
+ "scrolled": true
1139
+ },
1140
+ "outputs": [
1141
+ {
1142
+ "name": "stdout",
1143
+ "output_type": "stream",
1144
+ "text": [
1145
+ "(768,)\n"
1146
+ ]
1147
+ },
1148
+ {
1149
+ "data": {
1150
+ "text/plain": [
1151
+ "array([-8.82369652e-03, 8.50650743e-02, 2.08267733e-03, 6.77651772e-03,\n",
1152
+ " -2.86661759e-02, -8.71188380e-03, 6.99447095e-02, 5.04214764e-02,\n",
1153
+ " 3.58386151e-02, 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n",
1154
+ " 2.27009598e-03, 2.10810862e-02, 2.66138893e-02, 1.90623086e-02,\n",
1155
+ " 4.44708914e-02, 2.96202525e-02, 5.42085357e-02, -2.34859088e-03,\n",
1156
+ " -9.87798795e-02, -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n",
1157
+ " 5.31156994e-02, -1.37044629e-02, 2.92537250e-02, -2.61334293e-02,\n",
1158
+ " -1.21854078e-04, -2.36813519e-02, -3.81283499e-02, -1.79494768e-02,\n",
1159
+ " -6.29265187e-03, 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n",
1160
+ " -1.28973828e-04, 4.01791967e-02, 4.21229303e-02, -8.72302521e-03,\n",
1161
+ " 7.44823692e-03, 7.68032745e-02, 6.50246907e-03, 3.40298638e-02,\n",
1162
+ " -1.80711355e-02, -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n",
1163
+ " -3.34868580e-02, 1.05205458e-02, 2.08975170e-02, 4.36686277e-02,\n",
1164
+ " 3.47612537e-02, -4.99080680e-02, 4.44446988e-02, 5.57280704e-03,\n",
1165
+ " -2.31200755e-02, -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n",
1166
+ " 4.67903316e-02, 1.91651955e-02, -5.12171052e-02, 2.46807020e-02,\n",
1167
+ " -5.52081019e-02, 3.50596346e-02, -7.01438356e-03, -3.36519890e-02,\n",
1168
+ " -1.41502097e-02, -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n",
1169
+ " -1.30138136e-02, 5.91567121e-02, 3.37168351e-02, 3.01467292e-02,\n",
1170
+ " -4.59552221e-02, 1.37365120e-03, 1.00179566e-02, 6.98126853e-04,\n",
1171
+ " 3.58139984e-02, 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n",
1172
+ " 4.75908853e-02, 5.48331346e-03, -6.41460950e-03, -1.23906611e-02,\n",
1173
+ " 5.82688041e-02, -1.60842277e-02, -2.95833423e-04, 6.97355811e-03,\n",
1174
+ " 2.48331465e-02, -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n",
1175
+ " 1.52637456e-02, 7.01832073e-03, 5.50601333e-02, 4.35096538e-03,\n",
1176
+ " 2.36319732e-02, -1.38118947e-02, -7.24233836e-02, 9.39742289e-03,\n",
1177
+ " -2.66901590e-02, 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n",
1178
+ " 3.08373477e-03, 7.12765753e-02, 7.13613164e-03, 3.62721197e-02,\n",
1179
+ " -4.53250594e-02, 2.54001115e-02, -2.54253373e-02, -1.23151275e-03,\n",
1180
+ " -1.34750446e-02, -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n",
1181
+ " -7.31003610e-03, 2.65329964e-02, -2.79857730e-03, 4.20840643e-02,\n",
1182
+ " 3.20205763e-02, -1.19518824e-02, -5.77116087e-02, 9.88688134e-03,\n",
1183
+ " 1.86814573e-02, -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n",
1184
+ " -3.33383717e-02, -5.52933291e-02, 5.64148054e-02, 4.92153503e-02,\n",
1185
+ " 3.33690383e-02, -3.92963700e-02, -6.91099390e-02, 3.79911740e-03,\n",
1186
+ " -1.74410697e-02, -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n",
1187
+ " 2.61192098e-02, -2.74193864e-02, 6.92490395e-03, -4.64810384e-03,\n",
1188
+ " -8.99862905e-04, 1.02159111e-02, -4.81114909e-02, 1.22787328e-02,\n",
1189
+ " -9.32844076e-03, -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n",
1190
+ " -1.60810221e-02, -2.20200215e-02, 2.32051890e-02, -5.07331975e-02,\n",
1191
+ " -1.01248249e-02, 5.62567115e-02, -2.60966737e-03, 9.27545596e-03,\n",
1192
+ " 5.32410555e-02, 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n",
1193
+ " -2.12969314e-02, 9.82244611e-02, -2.47648880e-02, 7.06253499e-02,\n",
1194
+ " 8.71159416e-03, -2.73140483e-02, 5.59884915e-03, -2.14829091e-02,\n",
1195
+ " -6.67077005e-02, 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n",
1196
+ " 3.77488993e-02, -1.37352264e-02, -2.85069812e-02, 1.81708820e-02,\n",
1197
+ " -4.07746173e-02, -4.71230270e-03, -1.59605164e-02, -1.25815195e-03,\n",
1198
+ " -6.59954594e-03, -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n",
1199
+ " 3.07933297e-02, 1.27626080e-02, -4.34489138e-02, 9.91576444e-03,\n",
1200
+ " 1.25470785e-02, -8.67356583e-02, 1.26097840e-03, 3.24709825e-02,\n",
1201
+ " -6.92409948e-02, -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n",
1202
+ " -2.35359464e-02, -2.95022167e-02, 2.88009271e-02, -3.26618887e-02,\n",
1203
+ " 7.09307985e-03, 3.09435464e-03, -5.09097055e-02, 3.54242921e-02,\n",
1204
+ " 5.37336655e-02, 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n",
1205
+ " 2.93767708e-03, 2.27999203e-02, 1.02668423e-02, 3.35033536e-02,\n",
1206
+ " -8.28316063e-02, -4.17127199e-02, -1.23034064e-02, 2.38543525e-02,\n",
1207
+ " -3.72257493e-02, 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n",
1208
+ " 5.74519299e-03, 2.89844945e-02, -2.21337453e-02, 2.34603398e-02,\n",
1209
+ " 6.33142609e-03, -2.24104542e-02, 1.47326495e-02, 1.98041964e-02,\n",
1210
+ " 3.05697713e-02, -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n",
1211
+ " 3.88860740e-02, -3.97440195e-02, -4.70216498e-02, 1.02172708e-02,\n",
1212
+ " -3.37972888e-03, -8.54947045e-03, 4.81354557e-02, 4.99849804e-02,\n",
1213
+ " 7.11378129e-03, -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n",
1214
+ " 5.24284095e-02, 2.16708388e-02, -4.00698110e-02, 5.15380092e-02,\n",
1215
+ " -6.03203699e-02, -6.50304696e-03, -1.03860423e-02, -7.47132823e-02,\n",
1216
+ " 3.59848235e-03, -4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n",
1217
+ " -2.88047381e-02, -2.81904116e-02, 1.52729014e-02, -1.55570190e-02,\n",
1218
+ " 1.34619148e-02, 2.34364290e-02, 3.10326237e-02, -4.70464528e-02,\n",
1219
+ " -2.43550166e-02, -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n",
1220
+ " -5.30204549e-03, 5.52049950e-02, 4.50828709e-02, -7.30262510e-03,\n",
1221
+ " 5.56289777e-02, -9.46066808e-03, -3.37345451e-02, -1.87659152e-02,\n",
1222
+ " 3.57284099e-02, 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n",
1223
+ " 2.96422077e-04, 4.22447585e-02, 4.97253910e-02, 6.03130311e-02,\n",
1224
+ " 1.32281650e-02, 2.35939436e-02, -1.59284715e-02, 4.46444489e-02,\n",
1225
+ " -1.68315917e-02, 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n",
1226
+ " 8.99049267e-03, 4.74606343e-02, 6.70041004e-03, -1.15184486e-03,\n",
1227
+ " 2.69540539e-03, -2.77549177e-02, -1.33260442e-02, 2.60788556e-02,\n",
1228
+ " 4.35438640e-02, -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n",
1229
+ " 2.93240137e-02, 1.82274636e-03, -1.40310880e-02, -1.91633645e-02,\n",
1230
+ " 1.18790809e-02, -4.65121269e-02, -4.19883654e-02, -2.69681774e-02,\n",
1231
+ " -3.23035605e-02, -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n",
1232
+ " -2.55833156e-02, -5.73152229e-02, 3.30126472e-02, -7.90146552e-03,\n",
1233
+ " -1.08651863e-02, 1.10474667e-02, 3.03509296e-03, 1.55274626e-02,\n",
1234
+ " 1.05599947e-02, -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n",
1235
+ " 3.77239436e-02, 9.44003314e-02, -4.80610691e-02, 4.73537892e-02,\n",
1236
+ " 3.40655483e-02, 7.88806472e-03, -2.84915343e-02, 7.96849206e-02,\n",
1237
+ " 1.57442074e-02, -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n",
1238
+ " -1.72730908e-01, -8.72075930e-02, 2.86346450e-02, 2.16962174e-02,\n",
1239
+ " -4.80199270e-02, 6.49317261e-03, 1.67240556e-02, -2.56227311e-02,\n",
1240
+ " 2.19670162e-02, -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n",
1241
+ " -2.89566331e-02, 1.19498251e-02, -2.33849231e-02, -2.69133616e-02,\n",
1242
+ " -1.46602485e-02, 1.18886270e-02, 1.64973717e-02, -3.90495770e-02,\n",
1243
+ " -3.45575088e-03, 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n",
1244
+ " 2.10017413e-02, 2.74998210e-02, 3.03551817e-04, -1.15796946e-01,\n",
1245
+ " -4.66962112e-03, -4.80118394e-02, -3.55160870e-02, -4.72528581e-03,\n",
1246
+ " -4.29739058e-02, -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n",
1247
+ " 1.98413953e-02, -7.27679394e-03, 2.27117930e-02, -2.59338003e-02,\n",
1248
+ " 4.31442596e-02, 1.07885078e-02, -2.47129947e-02, -4.14506458e-02,\n",
1249
+ " 4.40958813e-02, 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n",
1250
+ " 1.13289580e-02, -5.57265691e-02, 1.71151303e-03, -1.24145029e-02,\n",
1251
+ " -3.57853901e-03, -4.86295968e-02, -5.14956787e-02, 4.79425713e-02,\n",
1252
+ " -3.24050151e-02, 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n",
1253
+ " 8.20766483e-03, -6.27530292e-02, -1.30661400e-02, -3.52081768e-02,\n",
1254
+ " 4.83807474e-02, 9.81860235e-03, 1.14539362e-01, -1.88471414e-02,\n",
1255
+ " 6.07751869e-02, -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n",
1256
+ " 2.64345529e-03, 3.07400171e-02, -4.31060083e-02, -6.19985871e-02,\n",
1257
+ " 5.50477020e-03, 1.62547994e-02, -8.26352183e-03, 7.56437238e-03,\n",
1258
+ " -4.79784003e-03, 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n",
1259
+ " 1.41595434e-02, 5.31185642e-02, 6.78585656e-03, 6.56357184e-02,\n",
1260
+ " -5.06135784e-02, -3.05179805e-02, 7.06539825e-02, -3.55644710e-02,\n",
1261
+ " -4.92612133e-03, 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n",
1262
+ " -1.86746120e-02, 2.49281265e-02, -4.92450967e-03, 1.66887734e-02,\n",
1263
+ " 4.62210961e-02, 4.07794118e-02, 2.52511259e-02, -2.83305068e-02,\n",
1264
+ " -2.78001893e-02, -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n",
1265
+ " 1.09969089e-02, 1.69700030e-02, -8.59475043e-03, 4.70476560e-02,\n",
1266
+ " 3.64770554e-02, 2.09835749e-02, 1.01236468e-02, 2.75151283e-02,\n",
1267
+ " 4.33402918e-02, -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n",
1268
+ " -6.10819347e-02, -2.86280159e-02, 4.68054451e-02, 1.29892454e-02,\n",
1269
+ " -1.71940885e-02, -2.52429228e-02, 3.86423096e-02, -1.35919163e-02,\n",
1270
+ " -5.27431667e-02, 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n",
1271
+ " 3.23252901e-02, 5.03172688e-02, -4.45654802e-02, 2.90075876e-02,\n",
1272
+ " -1.35373492e-02, 6.78209821e-03, -5.89249916e-02, 4.28890549e-02,\n",
1273
+ " -2.36034058e-02, -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n",
1274
+ " 1.45543357e-02, 1.07806427e-02, -6.06855676e-02, -4.95252907e-02,\n",
1275
+ " 1.02004781e-02, 4.60227691e-02, -1.08090881e-02, 4.42408510e-02,\n",
1276
+ " 4.15152796e-02, 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n",
1277
+ " -2.70090066e-02, 2.68773828e-02, -1.97812133e-02, 2.25932393e-02,\n",
1278
+ " -1.33560598e-02, -1.50896851e-02, -3.14053567e-03, 1.54051669e-02,\n",
1279
+ " 1.86488125e-02, -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n",
1280
+ " -2.37891261e-04, 1.84722953e-02, 3.60381305e-02, -5.85213909e-03,\n",
1281
+ " 4.44293395e-02, -1.11264118e-03, -4.79441285e-02, 3.46464328e-02,\n",
1282
+ " -2.53370814e-02, -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n",
1283
+ " -4.38152434e-04, 4.08602282e-02, -2.29470823e-02, -1.89938806e-02,\n",
1284
+ " -1.52037974e-04, 1.05516789e-02, 2.08601039e-02, -6.98119551e-02,\n",
1285
+ " 3.66246551e-02, -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n",
1286
+ " 6.51817098e-02, 4.29646857e-02, 2.56071109e-02, -3.28080021e-02,\n",
1287
+ " 1.20534413e-02, 3.56224040e-03, -1.01593453e-02, -1.96505673e-04,\n",
1288
+ " 4.33485657e-02, -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n",
1289
+ " -1.40319867e-02, -3.63940969e-02, -3.09983976e-02, -4.19548260e-33,\n",
1290
+ " 7.11604580e-02, 4.78382297e-02, 1.89297704e-03, -1.60731785e-02,\n",
1291
+ " 2.53787991e-02, -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n",
1292
+ " 1.68679946e-03, 1.92391127e-02, -2.20667192e-04, 1.32907527e-02,\n",
1293
+ " 5.99487219e-03, 2.75156219e-02, -5.06000873e-03, -3.58465910e-02,\n",
1294
+ " 8.20948277e-03, -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n",
1295
+ " -1.09853260e-01, -3.66037302e-02, 3.55480015e-02, 4.23291475e-02,\n",
1296
+ " 1.48312682e-02, 5.68749309e-02, 3.57767567e-02, 1.40728084e-02,\n",
1297
+ " -4.00471613e-02, 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n",
1298
+ " 1.24238459e-02, 1.20237898e-02, -7.69484974e-03, -3.30727436e-02,\n",
1299
+ " -1.45808076e-02, 3.43246050e-02, 3.21143419e-02, -4.96741422e-02,\n",
1300
+ " -5.27968369e-02, 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n",
1301
+ " 2.77636852e-02, 5.90689071e-02, -2.61273161e-02, -6.95008039e-02,\n",
1302
+ " -3.15576978e-02, -5.62214339e-03, -7.93884136e-03, -3.62196900e-02,\n",
1303
+ " -8.26047733e-03, 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n",
1304
+ " -2.52235290e-02, -3.88054736e-02, -2.00710595e-02, 1.50789914e-03,\n",
1305
+ " -5.51338419e-02, -8.35673045e-03, -1.61523875e-02, -8.79513845e-02,\n",
1306
+ " -5.28004877e-02, -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n",
1307
+ " 4.44932319e-02, 8.69598426e-03, -1.14432694e-02, 4.47212979e-02,\n",
1308
+ " 2.70624813e-02, -3.86100151e-02, -3.07358261e-02, 2.75634117e-02,\n",
1309
+ " 1.48464069e-02, -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n",
1310
+ " 8.05836394e-02, -1.69498641e-02, 4.44465503e-02, -2.09145956e-02,\n",
1311
+ " -3.37407738e-02, 3.85780074e-02, -7.44559616e-02, 1.17512364e-02,\n",
1312
+ " 1.01964204e-02, -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n",
1313
+ " 2.54592765e-02, -1.46158040e-02, 5.46646416e-02, 1.43051194e-03,\n",
1314
+ " 2.99116820e-02, 2.24273186e-02, -5.79927117e-03, -1.33864526e-02,\n",
1315
+ " -2.52460372e-02, -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n",
1316
+ " 3.38429734e-02, -2.11539529e-02, 7.17787817e-02, -7.78904185e-02,\n",
1317
+ " -4.04084288e-02, 4.90567498e-02, -2.61603445e-02, 1.97753590e-02,\n",
1318
+ " 4.97209951e-02, -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n",
1319
+ " 2.68440694e-02, 3.29160057e-02, -8.24410375e-03, -1.33646047e-02,\n",
1320
+ " -6.22822754e-02, -1.13362661e-02, -3.79339382e-02, -6.56360280e-05,\n",
1321
+ " -1.08087100e-02, 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n",
1322
+ " -2.54666172e-02, -3.05371322e-02, -1.53249800e-02, -9.87035502e-03,\n",
1323
+ " 1.95337094e-07, -1.76476724e-02, 5.71432859e-02, -2.49180794e-02,\n",
1324
+ " 5.85253723e-02, 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n",
1325
+ " 4.07801419e-02, 4.13940698e-02, 2.55707726e-02, 2.18985360e-02,\n",
1326
+ " -3.04434425e-03, -3.77355106e-02, -6.24866784e-02, -1.17468778e-02,\n",
1327
+ " -4.82194684e-02, -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n",
1328
+ " -2.48471629e-02, 8.05181568e-04, -4.85844910e-03, -5.16015477e-03,\n",
1329
+ " 7.53483502e-03, -9.46175400e-03, -2.39896346e-02, -3.14654633e-02,\n",
1330
+ " 1.50111094e-02, -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n",
1331
+ " 3.08971256e-02, 1.72299352e-02, 5.93419448e-02, -5.74274361e-02,\n",
1332
+ " -8.16087723e-02, -4.80572283e-02, -2.68838424e-02, -1.96331330e-02,\n",
1333
+ " -9.15831141e-03, 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n",
1334
+ " 9.21937004e-02, 1.37132118e-02, -1.19096776e-02, -4.09874134e-02,\n",
1335
+ " 3.37628126e-02, -4.64820908e-03, -2.50304434e-02, 6.25852346e-02,\n",
1336
+ " -1.24449311e-02, 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n",
1337
+ " 5.08641489e-02, 2.53822445e-03, 5.25634140e-02, 1.14882430e-02,\n",
1338
+ " 5.01894541e-02, -3.55215147e-02, -3.31749097e-02, -3.02003417e-03,\n",
1339
+ " -5.36288768e-02, -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n",
1340
+ " 9.56887701e-35, 2.55127084e-02, -1.44770980e-04, 1.96710341e-02,\n",
1341
+ " -1.33620016e-02, -1.51910949e-02, -3.28495577e-02, -1.52465852e-03,\n",
1342
+ " -2.65272055e-02, -4.35708016e-02, -1.75950192e-02, -2.20594816e-02],\n",
1343
+ " dtype=float32)"
1344
+ ]
1345
+ },
1346
+ "execution_count": 108,
1347
+ "metadata": {},
1348
+ "output_type": "execute_result"
1349
+ }
1350
+ ],
1351
+ "source": [
1352
+ "search_query= \"clinical trials related to alzheimers\"\n",
1353
+ "vec= encoder.encode(search_query)\n",
1354
+ "print(vec.shape)\n",
1355
+ "vec"
1356
+ ]
1357
+ },
1358
+ {
1359
+ "cell_type": "code",
1360
+ "execution_count": 109,
1361
+ "id": "613fd415-4194-45e6-b9f3-9a7707845ad5",
1362
+ "metadata": {
1363
+ "scrolled": true
1364
+ },
1365
+ "outputs": [
1366
+ {
1367
+ "name": "stdout",
1368
+ "output_type": "stream",
1369
+ "text": [
1370
+ "(1, 768)\n"
1371
+ ]
1372
+ },
1373
+ {
1374
+ "data": {
1375
+ "text/plain": [
1376
+ "array([[-8.82369652e-03, 8.50650743e-02, 2.08267733e-03,\n",
1377
+ " 6.77651772e-03, -2.86661759e-02, -8.71188380e-03,\n",
1378
+ " 6.99447095e-02, 5.04214764e-02, 3.58386151e-02,\n",
1379
+ " 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n",
1380
+ " 2.27009598e-03, 2.10810862e-02, 2.66138893e-02,\n",
1381
+ " 1.90623086e-02, 4.44708914e-02, 2.96202525e-02,\n",
1382
+ " 5.42085357e-02, -2.34859088e-03, -9.87798795e-02,\n",
1383
+ " -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n",
1384
+ " 5.31156994e-02, -1.37044629e-02, 2.92537250e-02,\n",
1385
+ " -2.61334293e-02, -1.21854078e-04, -2.36813519e-02,\n",
1386
+ " -3.81283499e-02, -1.79494768e-02, -6.29265187e-03,\n",
1387
+ " 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n",
1388
+ " -1.28973828e-04, 4.01791967e-02, 4.21229303e-02,\n",
1389
+ " -8.72302521e-03, 7.44823692e-03, 7.68032745e-02,\n",
1390
+ " 6.50246907e-03, 3.40298638e-02, -1.80711355e-02,\n",
1391
+ " -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n",
1392
+ " -3.34868580e-02, 1.05205458e-02, 2.08975170e-02,\n",
1393
+ " 4.36686277e-02, 3.47612537e-02, -4.99080680e-02,\n",
1394
+ " 4.44446988e-02, 5.57280704e-03, -2.31200755e-02,\n",
1395
+ " -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n",
1396
+ " 4.67903316e-02, 1.91651955e-02, -5.12171052e-02,\n",
1397
+ " 2.46807020e-02, -5.52081019e-02, 3.50596346e-02,\n",
1398
+ " -7.01438356e-03, -3.36519890e-02, -1.41502097e-02,\n",
1399
+ " -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n",
1400
+ " -1.30138136e-02, 5.91567121e-02, 3.37168351e-02,\n",
1401
+ " 3.01467292e-02, -4.59552221e-02, 1.37365120e-03,\n",
1402
+ " 1.00179566e-02, 6.98126853e-04, 3.58139984e-02,\n",
1403
+ " 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n",
1404
+ " 4.75908853e-02, 5.48331346e-03, -6.41460950e-03,\n",
1405
+ " -1.23906611e-02, 5.82688041e-02, -1.60842277e-02,\n",
1406
+ " -2.95833423e-04, 6.97355811e-03, 2.48331465e-02,\n",
1407
+ " -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n",
1408
+ " 1.52637456e-02, 7.01832073e-03, 5.50601333e-02,\n",
1409
+ " 4.35096538e-03, 2.36319732e-02, -1.38118947e-02,\n",
1410
+ " -7.24233836e-02, 9.39742289e-03, -2.66901590e-02,\n",
1411
+ " 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n",
1412
+ " 3.08373477e-03, 7.12765753e-02, 7.13613164e-03,\n",
1413
+ " 3.62721197e-02, -4.53250594e-02, 2.54001115e-02,\n",
1414
+ " -2.54253373e-02, -1.23151275e-03, -1.34750446e-02,\n",
1415
+ " -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n",
1416
+ " -7.31003610e-03, 2.65329964e-02, -2.79857730e-03,\n",
1417
+ " 4.20840643e-02, 3.20205763e-02, -1.19518824e-02,\n",
1418
+ " -5.77116087e-02, 9.88688134e-03, 1.86814573e-02,\n",
1419
+ " -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n",
1420
+ " -3.33383717e-02, -5.52933291e-02, 5.64148054e-02,\n",
1421
+ " 4.92153503e-02, 3.33690383e-02, -3.92963700e-02,\n",
1422
+ " -6.91099390e-02, 3.79911740e-03, -1.74410697e-02,\n",
1423
+ " -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n",
1424
+ " 2.61192098e-02, -2.74193864e-02, 6.92490395e-03,\n",
1425
+ " -4.64810384e-03, -8.99862905e-04, 1.02159111e-02,\n",
1426
+ " -4.81114909e-02, 1.22787328e-02, -9.32844076e-03,\n",
1427
+ " -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n",
1428
+ " -1.60810221e-02, -2.20200215e-02, 2.32051890e-02,\n",
1429
+ " -5.07331975e-02, -1.01248249e-02, 5.62567115e-02,\n",
1430
+ " -2.60966737e-03, 9.27545596e-03, 5.32410555e-02,\n",
1431
+ " 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n",
1432
+ " -2.12969314e-02, 9.82244611e-02, -2.47648880e-02,\n",
1433
+ " 7.06253499e-02, 8.71159416e-03, -2.73140483e-02,\n",
1434
+ " 5.59884915e-03, -2.14829091e-02, -6.67077005e-02,\n",
1435
+ " 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n",
1436
+ " 3.77488993e-02, -1.37352264e-02, -2.85069812e-02,\n",
1437
+ " 1.81708820e-02, -4.07746173e-02, -4.71230270e-03,\n",
1438
+ " -1.59605164e-02, -1.25815195e-03, -6.59954594e-03,\n",
1439
+ " -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n",
1440
+ " 3.07933297e-02, 1.27626080e-02, -4.34489138e-02,\n",
1441
+ " 9.91576444e-03, 1.25470785e-02, -8.67356583e-02,\n",
1442
+ " 1.26097840e-03, 3.24709825e-02, -6.92409948e-02,\n",
1443
+ " -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n",
1444
+ " -2.35359464e-02, -2.95022167e-02, 2.88009271e-02,\n",
1445
+ " -3.26618887e-02, 7.09307985e-03, 3.09435464e-03,\n",
1446
+ " -5.09097055e-02, 3.54242921e-02, 5.37336655e-02,\n",
1447
+ " 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n",
1448
+ " 2.93767708e-03, 2.27999203e-02, 1.02668423e-02,\n",
1449
+ " 3.35033536e-02, -8.28316063e-02, -4.17127199e-02,\n",
1450
+ " -1.23034064e-02, 2.38543525e-02, -3.72257493e-02,\n",
1451
+ " 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n",
1452
+ " 5.74519299e-03, 2.89844945e-02, -2.21337453e-02,\n",
1453
+ " 2.34603398e-02, 6.33142609e-03, -2.24104542e-02,\n",
1454
+ " 1.47326495e-02, 1.98041964e-02, 3.05697713e-02,\n",
1455
+ " -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n",
1456
+ " 3.88860740e-02, -3.97440195e-02, -4.70216498e-02,\n",
1457
+ " 1.02172708e-02, -3.37972888e-03, -8.54947045e-03,\n",
1458
+ " 4.81354557e-02, 4.99849804e-02, 7.11378129e-03,\n",
1459
+ " -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n",
1460
+ " 5.24284095e-02, 2.16708388e-02, -4.00698110e-02,\n",
1461
+ " 5.15380092e-02, -6.03203699e-02, -6.50304696e-03,\n",
1462
+ " -1.03860423e-02, -7.47132823e-02, 3.59848235e-03,\n",
1463
+ " -4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n",
1464
+ " -2.88047381e-02, -2.81904116e-02, 1.52729014e-02,\n",
1465
+ " -1.55570190e-02, 1.34619148e-02, 2.34364290e-02,\n",
1466
+ " 3.10326237e-02, -4.70464528e-02, -2.43550166e-02,\n",
1467
+ " -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n",
1468
+ " -5.30204549e-03, 5.52049950e-02, 4.50828709e-02,\n",
1469
+ " -7.30262510e-03, 5.56289777e-02, -9.46066808e-03,\n",
1470
+ " -3.37345451e-02, -1.87659152e-02, 3.57284099e-02,\n",
1471
+ " 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n",
1472
+ " 2.96422077e-04, 4.22447585e-02, 4.97253910e-02,\n",
1473
+ " 6.03130311e-02, 1.32281650e-02, 2.35939436e-02,\n",
1474
+ " -1.59284715e-02, 4.46444489e-02, -1.68315917e-02,\n",
1475
+ " 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n",
1476
+ " 8.99049267e-03, 4.74606343e-02, 6.70041004e-03,\n",
1477
+ " -1.15184486e-03, 2.69540539e-03, -2.77549177e-02,\n",
1478
+ " -1.33260442e-02, 2.60788556e-02, 4.35438640e-02,\n",
1479
+ " -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n",
1480
+ " 2.93240137e-02, 1.82274636e-03, -1.40310880e-02,\n",
1481
+ " -1.91633645e-02, 1.18790809e-02, -4.65121269e-02,\n",
1482
+ " -4.19883654e-02, -2.69681774e-02, -3.23035605e-02,\n",
1483
+ " -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n",
1484
+ " -2.55833156e-02, -5.73152229e-02, 3.30126472e-02,\n",
1485
+ " -7.90146552e-03, -1.08651863e-02, 1.10474667e-02,\n",
1486
+ " 3.03509296e-03, 1.55274626e-02, 1.05599947e-02,\n",
1487
+ " -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n",
1488
+ " 3.77239436e-02, 9.44003314e-02, -4.80610691e-02,\n",
1489
+ " 4.73537892e-02, 3.40655483e-02, 7.88806472e-03,\n",
1490
+ " -2.84915343e-02, 7.96849206e-02, 1.57442074e-02,\n",
1491
+ " -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n",
1492
+ " -1.72730908e-01, -8.72075930e-02, 2.86346450e-02,\n",
1493
+ " 2.16962174e-02, -4.80199270e-02, 6.49317261e-03,\n",
1494
+ " 1.67240556e-02, -2.56227311e-02, 2.19670162e-02,\n",
1495
+ " -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n",
1496
+ " -2.89566331e-02, 1.19498251e-02, -2.33849231e-02,\n",
1497
+ " -2.69133616e-02, -1.46602485e-02, 1.18886270e-02,\n",
1498
+ " 1.64973717e-02, -3.90495770e-02, -3.45575088e-03,\n",
1499
+ " 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n",
1500
+ " 2.10017413e-02, 2.74998210e-02, 3.03551817e-04,\n",
1501
+ " -1.15796946e-01, -4.66962112e-03, -4.80118394e-02,\n",
1502
+ " -3.55160870e-02, -4.72528581e-03, -4.29739058e-02,\n",
1503
+ " -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n",
1504
+ " 1.98413953e-02, -7.27679394e-03, 2.27117930e-02,\n",
1505
+ " -2.59338003e-02, 4.31442596e-02, 1.07885078e-02,\n",
1506
+ " -2.47129947e-02, -4.14506458e-02, 4.40958813e-02,\n",
1507
+ " 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n",
1508
+ " 1.13289580e-02, -5.57265691e-02, 1.71151303e-03,\n",
1509
+ " -1.24145029e-02, -3.57853901e-03, -4.86295968e-02,\n",
1510
+ " -5.14956787e-02, 4.79425713e-02, -3.24050151e-02,\n",
1511
+ " 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n",
1512
+ " 8.20766483e-03, -6.27530292e-02, -1.30661400e-02,\n",
1513
+ " -3.52081768e-02, 4.83807474e-02, 9.81860235e-03,\n",
1514
+ " 1.14539362e-01, -1.88471414e-02, 6.07751869e-02,\n",
1515
+ " -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n",
1516
+ " 2.64345529e-03, 3.07400171e-02, -4.31060083e-02,\n",
1517
+ " -6.19985871e-02, 5.50477020e-03, 1.62547994e-02,\n",
1518
+ " -8.26352183e-03, 7.56437238e-03, -4.79784003e-03,\n",
1519
+ " 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n",
1520
+ " 1.41595434e-02, 5.31185642e-02, 6.78585656e-03,\n",
1521
+ " 6.56357184e-02, -5.06135784e-02, -3.05179805e-02,\n",
1522
+ " 7.06539825e-02, -3.55644710e-02, -4.92612133e-03,\n",
1523
+ " 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n",
1524
+ " -1.86746120e-02, 2.49281265e-02, -4.92450967e-03,\n",
1525
+ " 1.66887734e-02, 4.62210961e-02, 4.07794118e-02,\n",
1526
+ " 2.52511259e-02, -2.83305068e-02, -2.78001893e-02,\n",
1527
+ " -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n",
1528
+ " 1.09969089e-02, 1.69700030e-02, -8.59475043e-03,\n",
1529
+ " 4.70476560e-02, 3.64770554e-02, 2.09835749e-02,\n",
1530
+ " 1.01236468e-02, 2.75151283e-02, 4.33402918e-02,\n",
1531
+ " -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n",
1532
+ " -6.10819347e-02, -2.86280159e-02, 4.68054451e-02,\n",
1533
+ " 1.29892454e-02, -1.71940885e-02, -2.52429228e-02,\n",
1534
+ " 3.86423096e-02, -1.35919163e-02, -5.27431667e-02,\n",
1535
+ " 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n",
1536
+ " 3.23252901e-02, 5.03172688e-02, -4.45654802e-02,\n",
1537
+ " 2.90075876e-02, -1.35373492e-02, 6.78209821e-03,\n",
1538
+ " -5.89249916e-02, 4.28890549e-02, -2.36034058e-02,\n",
1539
+ " -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n",
1540
+ " 1.45543357e-02, 1.07806427e-02, -6.06855676e-02,\n",
1541
+ " -4.95252907e-02, 1.02004781e-02, 4.60227691e-02,\n",
1542
+ " -1.08090881e-02, 4.42408510e-02, 4.15152796e-02,\n",
1543
+ " 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n",
1544
+ " -2.70090066e-02, 2.68773828e-02, -1.97812133e-02,\n",
1545
+ " 2.25932393e-02, -1.33560598e-02, -1.50896851e-02,\n",
1546
+ " -3.14053567e-03, 1.54051669e-02, 1.86488125e-02,\n",
1547
+ " -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n",
1548
+ " -2.37891261e-04, 1.84722953e-02, 3.60381305e-02,\n",
1549
+ " -5.85213909e-03, 4.44293395e-02, -1.11264118e-03,\n",
1550
+ " -4.79441285e-02, 3.46464328e-02, -2.53370814e-02,\n",
1551
+ " -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n",
1552
+ " -4.38152434e-04, 4.08602282e-02, -2.29470823e-02,\n",
1553
+ " -1.89938806e-02, -1.52037974e-04, 1.05516789e-02,\n",
1554
+ " 2.08601039e-02, -6.98119551e-02, 3.66246551e-02,\n",
1555
+ " -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n",
1556
+ " 6.51817098e-02, 4.29646857e-02, 2.56071109e-02,\n",
1557
+ " -3.28080021e-02, 1.20534413e-02, 3.56224040e-03,\n",
1558
+ " -1.01593453e-02, -1.96505673e-04, 4.33485657e-02,\n",
1559
+ " -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n",
1560
+ " -1.40319867e-02, -3.63940969e-02, -3.09983976e-02,\n",
1561
+ " -4.19548260e-33, 7.11604580e-02, 4.78382297e-02,\n",
1562
+ " 1.89297704e-03, -1.60731785e-02, 2.53787991e-02,\n",
1563
+ " -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n",
1564
+ " 1.68679946e-03, 1.92391127e-02, -2.20667192e-04,\n",
1565
+ " 1.32907527e-02, 5.99487219e-03, 2.75156219e-02,\n",
1566
+ " -5.06000873e-03, -3.58465910e-02, 8.20948277e-03,\n",
1567
+ " -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n",
1568
+ " -1.09853260e-01, -3.66037302e-02, 3.55480015e-02,\n",
1569
+ " 4.23291475e-02, 1.48312682e-02, 5.68749309e-02,\n",
1570
+ " 3.57767567e-02, 1.40728084e-02, -4.00471613e-02,\n",
1571
+ " 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n",
1572
+ " 1.24238459e-02, 1.20237898e-02, -7.69484974e-03,\n",
1573
+ " -3.30727436e-02, -1.45808076e-02, 3.43246050e-02,\n",
1574
+ " 3.21143419e-02, -4.96741422e-02, -5.27968369e-02,\n",
1575
+ " 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n",
1576
+ " 2.77636852e-02, 5.90689071e-02, -2.61273161e-02,\n",
1577
+ " -6.95008039e-02, -3.15576978e-02, -5.62214339e-03,\n",
1578
+ " -7.93884136e-03, -3.62196900e-02, -8.26047733e-03,\n",
1579
+ " 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n",
1580
+ " -2.52235290e-02, -3.88054736e-02, -2.00710595e-02,\n",
1581
+ " 1.50789914e-03, -5.51338419e-02, -8.35673045e-03,\n",
1582
+ " -1.61523875e-02, -8.79513845e-02, -5.28004877e-02,\n",
1583
+ " -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n",
1584
+ " 4.44932319e-02, 8.69598426e-03, -1.14432694e-02,\n",
1585
+ " 4.47212979e-02, 2.70624813e-02, -3.86100151e-02,\n",
1586
+ " -3.07358261e-02, 2.75634117e-02, 1.48464069e-02,\n",
1587
+ " -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n",
1588
+ " 8.05836394e-02, -1.69498641e-02, 4.44465503e-02,\n",
1589
+ " -2.09145956e-02, -3.37407738e-02, 3.85780074e-02,\n",
1590
+ " -7.44559616e-02, 1.17512364e-02, 1.01964204e-02,\n",
1591
+ " -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n",
1592
+ " 2.54592765e-02, -1.46158040e-02, 5.46646416e-02,\n",
1593
+ " 1.43051194e-03, 2.99116820e-02, 2.24273186e-02,\n",
1594
+ " -5.79927117e-03, -1.33864526e-02, -2.52460372e-02,\n",
1595
+ " -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n",
1596
+ " 3.38429734e-02, -2.11539529e-02, 7.17787817e-02,\n",
1597
+ " -7.78904185e-02, -4.04084288e-02, 4.90567498e-02,\n",
1598
+ " -2.61603445e-02, 1.97753590e-02, 4.97209951e-02,\n",
1599
+ " -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n",
1600
+ " 2.68440694e-02, 3.29160057e-02, -8.24410375e-03,\n",
1601
+ " -1.33646047e-02, -6.22822754e-02, -1.13362661e-02,\n",
1602
+ " -3.79339382e-02, -6.56360280e-05, -1.08087100e-02,\n",
1603
+ " 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n",
1604
+ " -2.54666172e-02, -3.05371322e-02, -1.53249800e-02,\n",
1605
+ " -9.87035502e-03, 1.95337094e-07, -1.76476724e-02,\n",
1606
+ " 5.71432859e-02, -2.49180794e-02, 5.85253723e-02,\n",
1607
+ " 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n",
1608
+ " 4.07801419e-02, 4.13940698e-02, 2.55707726e-02,\n",
1609
+ " 2.18985360e-02, -3.04434425e-03, -3.77355106e-02,\n",
1610
+ " -6.24866784e-02, -1.17468778e-02, -4.82194684e-02,\n",
1611
+ " -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n",
1612
+ " -2.48471629e-02, 8.05181568e-04, -4.85844910e-03,\n",
1613
+ " -5.16015477e-03, 7.53483502e-03, -9.46175400e-03,\n",
1614
+ " -2.39896346e-02, -3.14654633e-02, 1.50111094e-02,\n",
1615
+ " -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n",
1616
+ " 3.08971256e-02, 1.72299352e-02, 5.93419448e-02,\n",
1617
+ " -5.74274361e-02, -8.16087723e-02, -4.80572283e-02,\n",
1618
+ " -2.68838424e-02, -1.96331330e-02, -9.15831141e-03,\n",
1619
+ " 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n",
1620
+ " 9.21937004e-02, 1.37132118e-02, -1.19096776e-02,\n",
1621
+ " -4.09874134e-02, 3.37628126e-02, -4.64820908e-03,\n",
1622
+ " -2.50304434e-02, 6.25852346e-02, -1.24449311e-02,\n",
1623
+ " 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n",
1624
+ " 5.08641489e-02, 2.53822445e-03, 5.25634140e-02,\n",
1625
+ " 1.14882430e-02, 5.01894541e-02, -3.55215147e-02,\n",
1626
+ " -3.31749097e-02, -3.02003417e-03, -5.36288768e-02,\n",
1627
+ " -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n",
1628
+ " 9.56887701e-35, 2.55127084e-02, -1.44770980e-04,\n",
1629
+ " 1.96710341e-02, -1.33620016e-02, -1.51910949e-02,\n",
1630
+ " -3.28495577e-02, -1.52465852e-03, -2.65272055e-02,\n",
1631
+ " -4.35708016e-02, -1.75950192e-02, -2.20594816e-02]], dtype=float32)"
1632
+ ]
1633
+ },
1634
+ "execution_count": 109,
1635
+ "metadata": {},
1636
+ "output_type": "execute_result"
1637
+ }
1638
+ ],
1639
+ "source": [
1640
+ "import numpy as np\n",
1641
+ "svec= np.array(vec).reshape(1,-1)\n",
1642
+ "print(svec.shape)\n",
1643
+ "svec"
1644
+ ]
1645
+ },
1646
+ {
1647
+ "cell_type": "code",
1648
+ "execution_count": 110,
1649
+ "id": "fef30d70-6958-4259-abb6-09f8c1870a2b",
1650
+ "metadata": {},
1651
+ "outputs": [
1652
+ {
1653
+ "name": "stdout",
1654
+ "output_type": "stream",
1655
+ "text": [
1656
+ "[[0.7731663 0.79433584]] [[330 331]]\n"
1657
+ ]
1658
+ }
1659
+ ],
1660
+ "source": [
1661
+ "distances, I= index.search(svec, k=2)\n",
1662
+ "print(distances, I)"
1663
+ ]
1664
+ },
1665
+ {
1666
+ "cell_type": "code",
1667
+ "execution_count": 111,
1668
+ "id": "eb00598c-9799-4697-b2a3-356bb5aae0f1",
1669
+ "metadata": {},
1670
+ "outputs": [
1671
+ {
1672
+ "data": {
1673
+ "text/html": [
1674
+ "<div>\n",
1675
+ "<style scoped>\n",
1676
+ " .dataframe tbody tr th:only-of-type {\n",
1677
+ " vertical-align: middle;\n",
1678
+ " }\n",
1679
+ "\n",
1680
+ " .dataframe tbody tr th {\n",
1681
+ " vertical-align: top;\n",
1682
+ " }\n",
1683
+ "\n",
1684
+ " .dataframe thead th {\n",
1685
+ " text-align: right;\n",
1686
+ " }\n",
1687
+ "</style>\n",
1688
+ "<table border=\"1\" class=\"dataframe\">\n",
1689
+ " <thead>\n",
1690
+ " <tr style=\"text-align: right;\">\n",
1691
+ " <th></th>\n",
1692
+ " <th>desease_condition</th>\n",
1693
+ " <th>text</th>\n",
1694
+ " </tr>\n",
1695
+ " </thead>\n",
1696
+ " <tbody>\n",
1697
+ " <tr>\n",
1698
+ " <th>330</th>\n",
1699
+ " <td>['alzheimer disease', 'dementia', 'brain disea...</td>\n",
1700
+ " <td>nct_id: NCT02164643\\nsummary: A Multicenter na...</td>\n",
1701
+ " </tr>\n",
1702
+ " <tr>\n",
1703
+ " <th>331</th>\n",
1704
+ " <td>['alzheimer disease', 'dementia', 'brain disea...</td>\n",
1705
+ " <td>nct_id: NCT02164643\\nsummary: A Multicenter na...</td>\n",
1706
+ " </tr>\n",
1707
+ " </tbody>\n",
1708
+ "</table>\n",
1709
+ "</div>"
1710
+ ],
1711
+ "text/plain": [
1712
+ " desease_condition \\\n",
1713
+ "330 ['alzheimer disease', 'dementia', 'brain disea... \n",
1714
+ "331 ['alzheimer disease', 'dementia', 'brain disea... \n",
1715
+ "\n",
1716
+ " text \n",
1717
+ "330 nct_id: NCT02164643\\nsummary: A Multicenter na... \n",
1718
+ "331 nct_id: NCT02164643\\nsummary: A Multicenter na... "
1719
+ ]
1720
+ },
1721
+ "execution_count": 111,
1722
+ "metadata": {},
1723
+ "output_type": "execute_result"
1724
+ }
1725
+ ],
1726
+ "source": [
1727
+ "df2= df.iloc[I[0]]\n",
1728
+ "df2"
1729
+ ]
1730
+ },
1731
+ {
1732
+ "cell_type": "code",
1733
+ "execution_count": 113,
1734
+ "id": "af5bf8e2-43b6-47af-affa-5111789371ad",
1735
+ "metadata": {},
1736
+ "outputs": [
1737
+ {
1738
+ "data": {
1739
+ "text/plain": [
1740
+ "'nct_id: NCT02164643\\nsummary: A Multicenter national longitudinal cohort study including at least 800 individuals consecutively recruited from French Research Memory Centers and followed-up over 24 month and included in Memento.\\nintervention_type: Drug\\nintervention_name: Florbetapir (18F)\\nintervention_description: nan\\nkeywords: [\"Alzheimer\\'s disease\", \\'Mild Cognitive Impairment\\']'"
1741
+ ]
1742
+ },
1743
+ "execution_count": 113,
1744
+ "metadata": {},
1745
+ "output_type": "execute_result"
1746
+ }
1747
+ ],
1748
+ "source": [
1749
+ "df2.iloc[1].text"
1750
+ ]
1751
+ },
1752
+ {
1753
+ "cell_type": "code",
1754
+ "execution_count": null,
1755
+ "id": "f3899f81-e120-475c-97ed-080cb7f46510",
1756
+ "metadata": {},
1757
+ "outputs": [],
1758
+ "source": []
1759
+ }
1760
+ ],
1761
+ "metadata": {
1762
+ "kernelspec": {
1763
+ "display_name": "Python 3 (ipykernel)",
1764
+ "language": "python",
1765
+ "name": "python3"
1766
+ },
1767
+ "language_info": {
1768
+ "codemirror_mode": {
1769
+ "name": "ipython",
1770
+ "version": 3
1771
+ },
1772
+ "file_extension": ".py",
1773
+ "mimetype": "text/x-python",
1774
+ "name": "python",
1775
+ "nbconvert_exporter": "python",
1776
+ "pygments_lexer": "ipython3",
1777
+ "version": "3.11.9"
1778
+ }
1779
+ },
1780
+ "nbformat": 4,
1781
+ "nbformat_minor": 5
1782
+ }
database.ipynb CHANGED
@@ -9,7 +9,7 @@
9
  },
10
  {
11
  "cell_type": "code",
12
- "execution_count": null,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
@@ -19,7 +19,7 @@
19
  },
20
  {
21
  "cell_type": "code",
22
- "execution_count": null,
23
  "metadata": {},
24
  "outputs": [],
25
  "source": [
@@ -40,7 +40,7 @@
40
  "outputs": [],
41
  "source": [
42
  "# Load knowledge graph\n",
43
- "entity_embeddings = pd.read_csv('./data/entity_embeddings.csv', index_col=0)\n",
44
  "entity_embeddings[\"embedding\"] = entity_embeddings[\"embedding\"].apply(\n",
45
  " lambda x: x[1:-1])\n",
46
  "\n",
@@ -103,7 +103,7 @@
103
  "source": [
104
  "# Load clinical trials\n",
105
  "\n",
106
- "relation_embeddings = pd.read_csv('./data/relation_embeddings.csv', index_col=0)\n",
107
  "relation_embeddings[\"embedding\"] = relation_embeddings[\"embedding\"].apply(\n",
108
  " lambda x: x[1:-1])\n",
109
  "\n",
@@ -126,7 +126,7 @@
126
  " with conn.begin():\n",
127
  " for index, row in relation_embeddings.iterrows():\n",
128
  " sql = text(\"\"\"\n",
129
- " INSERT INTO Test.ClinicalTrials \n",
130
  " (embedding, label, uri) \n",
131
  " VALUES (TO_VECTOR(:embedding), :label, :uri)\n",
132
  " \"\"\")\n",
@@ -134,7 +134,7 @@
134
  " 'embedding': str(row['embedding']),\n",
135
  " 'label': row['label'], \n",
136
  " 'uri': row['uri']\n",
137
- " })\n"
138
  ]
139
  },
140
  {
 
9
  },
10
  {
11
  "cell_type": "code",
12
+ "execution_count": 8,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
 
19
  },
20
  {
21
  "cell_type": "code",
22
+ "execution_count": 9,
23
  "metadata": {},
24
  "outputs": [],
25
  "source": [
 
40
  "outputs": [],
41
  "source": [
42
  "# Load knowledge graph\n",
43
+ "entity_embeddings = pd.read_csv('./entity_embeddings.csv', index_col=0)\n",
44
  "entity_embeddings[\"embedding\"] = entity_embeddings[\"embedding\"].apply(\n",
45
  " lambda x: x[1:-1])\n",
46
  "\n",
 
103
  "source": [
104
  "# Load clinical trials\n",
105
  "\n",
106
+ "relation_embeddings = pd.read_csv('./relation_embeddings.csv', index_col=0)\n",
107
  "relation_embeddings[\"embedding\"] = relation_embeddings[\"embedding\"].apply(\n",
108
  " lambda x: x[1:-1])\n",
109
  "\n",
 
126
  " with conn.begin():\n",
127
  " for index, row in relation_embeddings.iterrows():\n",
128
  " sql = text(\"\"\"\n",
129
+ " INSERT INTO Test.RelationEmbeddings \n",
130
  " (embedding, label, uri) \n",
131
  " VALUES (TO_VECTOR(:embedding), :label, :uri)\n",
132
  " \"\"\")\n",
 
134
  " 'embedding': str(row['embedding']),\n",
135
  " 'label': row['label'], \n",
136
  " 'uri': row['uri']\n",
137
+ " })"
138
  ]
139
  },
140
  {
graph_visualization.mlapp CHANGED
Binary files a/graph_visualization.mlapp and b/graph_visualization.mlapp differ