{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9db3b813-22dc-4209-86d2-42e935f5f5dd",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.document_loaders.csv_loader import CSVLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"import pandas as pd\n",
"import langchain\n",
"import os\n",
"import openai\n",
"import ast\n",
"from langchain import OpenAI\n",
"from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.document_loaders import UnstructuredURLLoader\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.vectorstores import FAISS"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fb48dccd-37e5-484a-a2fc-c482839b9ed9",
"metadata": {},
"outputs": [],
"source": [
"# loader = CSVLoader(file_path=\"trials/brief_summaries.csv\")\n",
"# data = loader.load()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "cdcc3107-e6a8-47ca-bd89-e381fbcf9b9e",
"metadata": {},
"outputs": [],
"source": [
"# df= pd.read_csv(\"trials/brief_summaries.txt\", delimiter=\"|\")\n",
"# df.shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c6a94c14-b197-4eff-9941-1ca52069cd5c",
"metadata": {},
"outputs": [],
"source": [
"# df.head(20)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "95c000ff-bf0c-4489-8643-93a238db41dc",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Never hardcode credentials in a notebook: the key that was previously written\n",
"# here was committed in plain text and must be revoked/rotated.\n",
"# Use OPENAI_API_KEY from the environment if set, otherwise prompt without echo.\n",
"if not os.environ.get(\"OPENAI_API_KEY\"):\n",
"    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e2e03936-fcce-4287-bfe6-e31f1b69f693",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/aldan.creo/miniconda3/envs/hackupc/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `OpenAI` was deprecated in LangChain 0.0.10 and will be removed in 0.2.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import OpenAI`.\n",
" warn_deprecated(\n"
]
}
],
"source": [
"# LangChain completion-model wrapper: temperature 0.6, completions capped at 500 tokens.\n",
"# NOTE: this class is deprecated (see warning); langchain_openai.OpenAI replaces it.\n",
"llm = OpenAI(max_tokens=500, temperature=0.6)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "aede3d75-2441-44f6-b3cf-2f86f050da24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(440517, 2)\n"
]
},
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" desease_condition | \n",
" text | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" ['marijuana abuse', 'substance-related disorde... | \n",
" nct_id: NCT03055377\\nsummary: This is a 12-wee... | \n",
"
\n",
" \n",
" 1 | \n",
" ['marijuana abuse', 'substance-related disorde... | \n",
" nct_id: NCT03055377\\nsummary: This is a 12-wee... | \n",
"
\n",
" \n",
" 2 | \n",
" ['tuberculosis', 'latent tuberculosis', 'infec... | \n",
" nct_id: NCT03042754\\nsummary: Early diagnosis ... | \n",
"
\n",
" \n",
" 3 | \n",
" ['heart failure', 'heart diseases', 'cardiovas... | \n",
" nct_id: NCT03035123\\nsummary: The EduStra-HF s... | \n",
"
\n",
" \n",
" 4 | \n",
" ['lymphoma', 'neoplasms by histologic type', '... | \n",
" nct_id: NCT02272751\\nsummary: This study will ... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" desease_condition \\\n",
"0 ['marijuana abuse', 'substance-related disorde... \n",
"1 ['marijuana abuse', 'substance-related disorde... \n",
"2 ['tuberculosis', 'latent tuberculosis', 'infec... \n",
"3 ['heart failure', 'heart diseases', 'cardiovas... \n",
"4 ['lymphoma', 'neoplasms by histologic type', '... \n",
"\n",
" text \n",
"0 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
"1 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
"2 nct_id: NCT03042754\\nsummary: Early diagnosis ... \n",
"3 nct_id: NCT03035123\\nsummary: The EduStra-HF s... \n",
"4 nct_id: NCT02272751\\nsummary: This study will ... "
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the clinical-trials table, then show its size and first rows.\n",
"df_trials = pd.read_csv(filepath_or_buffer=\"clinical_trials.csv\")\n",
"print(df_trials.shape)\n",
"df_trials.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "65589fc4-4f55-4b72-9935-c3b163fde0e2",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" desease_condition | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" ['marijuana abuse', 'substance-related disorde... | \n",
"
\n",
" \n",
" 1 | \n",
" ['marijuana abuse', 'substance-related disorde... | \n",
"
\n",
" \n",
" 2 | \n",
" ['tuberculosis', 'latent tuberculosis', 'infec... | \n",
"
\n",
" \n",
" 3 | \n",
" ['heart failure', 'heart diseases', 'cardiovas... | \n",
"
\n",
" \n",
" 4 | \n",
" ['lymphoma', 'neoplasms by histologic type', '... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" desease_condition\n",
"0 ['marijuana abuse', 'substance-related disorde...\n",
"1 ['marijuana abuse', 'substance-related disorde...\n",
"2 ['tuberculosis', 'latent tuberculosis', 'infec...\n",
"3 ['heart failure', 'heart diseases', 'cardiovas...\n",
"4 ['lymphoma', 'neoplasms by histologic type', '..."
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Keep only the condition column. `.copy()` makes this an independent frame, so\n",
"# rewriting the column in the following cells does not trigger pandas'\n",
"# SettingWithCopyWarning (assignment into a slice of `df_trials`).\n",
"df_trials_filtered = df_trials[[\"desease_condition\"]].copy()\n",
"df_trials_filtered.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "88e28056-e340-416c-a1a9-4a6c29556dc7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"['marijuana abuse', 'substance-related disorders', 'chemically-induced disorders', 'mental disorders']\""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect one raw value: the condition list is stored as the text of a Python list.\n",
"df_trials_filtered[\"desease_condition\"].iat[0]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "5bd8f876-0480-40a5-a32f-ca7ec137a70f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\ariji\\AppData\\Local\\Temp\\ipykernel_22340\\16068817.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" df_trials_filtered['desease_condition']= df_trials_filtered['desease_condition'].apply(list_to_string)\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" desease_condition | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" marijuana abuse, substance-related disorders, ... | \n",
"
\n",
" \n",
" 1 | \n",
" marijuana abuse, substance-related disorders, ... | \n",
"
\n",
" \n",
" 2 | \n",
" tuberculosis, latent tuberculosis, infections,... | \n",
"
\n",
" \n",
" 3 | \n",
" heart failure, heart diseases, cardiovascular ... | \n",
"
\n",
" \n",
" 4 | \n",
" lymphoma, neoplasms by histologic type, neopla... | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 440512 | \n",
" obesity, overweight, overnutrition, nutrition ... | \n",
"
\n",
" \n",
" 440513 | \n",
" obesity, overweight, overnutrition, nutrition ... | \n",
"
\n",
" \n",
" 440514 | \n",
" obesity, overweight, overnutrition, nutrition ... | \n",
"
\n",
" \n",
" 440515 | \n",
" autistic disorder, autism spectrum disorder, c... | \n",
"
\n",
" \n",
" 440516 | \n",
" autistic disorder, autism spectrum disorder, c... | \n",
"
\n",
" \n",
"
\n",
"
440517 rows × 1 columns
\n",
"
"
],
"text/plain": [
" desease_condition\n",
"0 marijuana abuse, substance-related disorders, ...\n",
"1 marijuana abuse, substance-related disorders, ...\n",
"2 tuberculosis, latent tuberculosis, infections,...\n",
"3 heart failure, heart diseases, cardiovascular ...\n",
"4 lymphoma, neoplasms by histologic type, neopla...\n",
"... ...\n",
"440512 obesity, overweight, overnutrition, nutrition ...\n",
"440513 obesity, overweight, overnutrition, nutrition ...\n",
"440514 obesity, overweight, overnutrition, nutrition ...\n",
"440515 autistic disorder, autism spectrum disorder, c...\n",
"440516 autistic disorder, autism spectrum disorder, c...\n",
"\n",
"[440517 rows x 1 columns]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def list_to_string(disease_list):\n",
"    \"\"\"Convert a stringified list such as \"['a', 'b']\" into \"a, b\".\n",
"\n",
"    Values are parsed with ast.literal_eval, which only accepts Python\n",
"    literals (no code execution). Non-strings such as NaN become \"\", and\n",
"    text that does not start with '[' (e.g. a value already converted by an\n",
"    earlier run of this cell) is returned unchanged, so re-running is safe.\n",
"    \"\"\"\n",
"    if not isinstance(disease_list, str):\n",
"        return \"\"\n",
"    text = disease_list.strip()\n",
"    if not text.startswith(\"[\"):\n",
"        return disease_list\n",
"    return \", \".join(str(item) for item in ast.literal_eval(text))\n",
"\n",
"\n",
"# assign() builds a new frame rather than writing into a possible slice, which\n",
"# avoids the SettingWithCopyWarning emitted by in-place column assignment.\n",
"df_trials_filtered = df_trials_filtered.assign(\n",
"    desease_condition=df_trials_filtered[\"desease_condition\"].apply(list_to_string)\n",
")\n",
"df_trials_filtered"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "bbbb22e6-4883-4869-8ccc-95696bc67b1b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 marijuana abuse, substance-related disorders, ...\n",
"1 marijuana abuse, substance-related disorders, ...\n",
"2 tuberculosis, latent tuberculosis, infections,...\n",
"3 heart failure, heart diseases, cardiovascular ...\n",
"4 lymphoma, neoplasms by histologic type, neopla...\n",
"Name: desease_condition, dtype: object"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Confirm the lists were flattened into comma-separated text.\n",
"df_trials_filtered.desease_condition.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "7f4e7ceb-8bfd-4294-a850-8935f88b6555",
"metadata": {},
"outputs": [],
"source": [
"# Persist the cleaned conditions. UTF-8 is pandas' default, stated explicitly to\n",
"# match the UTF-8 CSVLoader that reads this file back.\n",
"df_trials_filtered.to_csv(\"diseases.csv\", encoding=\"utf-8\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "af1c5c2b-24a0-44a1-9e5d-7ee89ca4cccf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"440517"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the cleaned CSV as LangChain documents (one per CSV row) and count them.\n",
"loader = CSVLoader(encoding=\"utf-8\", file_path=\"./diseases.csv\")\n",
"data = loader.load()\n",
"len(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cab89218-41ca-4048-886d-bc2c1c9b30bc",
"metadata": {},
"outputs": [],
"source": [
"# Embed every loaded document with OpenAI and index the vectors in FAISS.\n",
"# NOTE(review): this sends every row (440,517 in the recorded run; the visible\n",
"# rows suggest many repeats) to the paid embeddings API -- consider\n",
"# deduplicating or subsetting first.\n",
"embeddings = OpenAIEmbeddings()\n",
"vectorstore = FAISS.from_documents(documents=data, embedding=embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "225ade4a-d004-44cc-a5ff-22ce2bfcac32",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Persist the FAISS store so it can be reused without re-embedding.\n",
"# Only unpickle this file from a trusted source: loading a pickle can run\n",
"# arbitrary code. FAISS.save_local()/load_local() is LangChain's native alternative.\n",
"file_path = \"vector_index.pkl\"\n",
"with open(file_path, \"wb\") as f:\n",
"    pickle.dump(vectorstore, f)"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "11912a93-ad02-41cb-8bce-2750c947fa74",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(440517, 2)\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" desease_condition | \n",
" text | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" ['marijuana abuse', 'substance-related disorde... | \n",
" nct_id: NCT03055377\\nsummary: This is a 12-wee... | \n",
"
\n",
" \n",
" 1 | \n",
" ['marijuana abuse', 'substance-related disorde... | \n",
" nct_id: NCT03055377\\nsummary: This is a 12-wee... | \n",
"
\n",
" \n",
" 2 | \n",
" ['tuberculosis', 'latent tuberculosis', 'infec... | \n",
" nct_id: NCT03042754\\nsummary: Early diagnosis ... | \n",
"
\n",
" \n",
" 3 | \n",
" ['heart failure', 'heart diseases', 'cardiovas... | \n",
" nct_id: NCT03035123\\nsummary: The EduStra-HF s... | \n",
"
\n",
" \n",
" 4 | \n",
" ['lymphoma', 'neoplasms by histologic type', '... | \n",
" nct_id: NCT02272751\\nsummary: This study will ... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" desease_condition \\\n",
"0 ['marijuana abuse', 'substance-related disorde... \n",
"1 ['marijuana abuse', 'substance-related disorde... \n",
"2 ['tuberculosis', 'latent tuberculosis', 'infec... \n",
"3 ['heart failure', 'heart diseases', 'cardiovas... \n",
"4 ['lymphoma', 'neoplasms by histologic type', '... \n",
"\n",
" text \n",
"0 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
"1 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
"2 nct_id: NCT03042754\\nsummary: Early diagnosis ... \n",
"3 nct_id: NCT03035123\\nsummary: The EduStra-HF s... \n",
"4 nct_id: NCT02272751\\nsummary: This study will ... "
]
},
"execution_count": 98,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Re-read the raw table from disk so this section starts from a fresh frame.\n",
"df_trials = pd.read_csv(filepath_or_buffer=\"clinical_trials.csv\")\n",
"print(df_trials.shape)\n",
"df_trials.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "a8113876-f38e-4c7a-891e-17dc51e2bacf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"nct_id: NCT03055377\\nsummary: This is a 12-week randomized, placebo-controlled trial of N-acetylcysteine for cannabis use disorder (CUD) in youth (N=192). Participants will be randomized to double-blind NAC or PBO, yielding two equally-allocated treatment groups. All participants will receive brief weekly cannabis cessation counseling and medication management. The primary efficacy outcome will be the proportion of negative urine cannabinoid tests during the 12-week active treatment, compared between groups.\\nintervention_type: Drug\\nintervention_name: N-acetyl cysteine\\nintervention_description: N-acetylcysteine 1200 mg twice daily for 12 weeks (administered orally)\\nkeywords: ['cannabis', 'marijuana', 'youth', 'adolescent', 'pharmacotherapy', 'medication', 'n-acetylcysteine', 'trial']\""
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_trials.iloc[0].text"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "c9ace03c-cf71-4605-adeb-9c6ee10e158a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 2)"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Take the first 1,000 rows as a small working sample.\n",
"df = df_trials.iloc[:1000]\n",
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85a985c5-845b-464e-9527-e806b6883741",
"metadata": {},
"outputs": [],
"source": [
"# Build a FAISS index over the trial texts of the 1,000-row sample `df`.\n",
"# Each text is sent to the OpenAI embeddings API.\n",
"embeddings = OpenAIEmbeddings()\n",
"vectorindex_openapi = FAISS.from_texts(\n",
"    texts=df[\"text\"].tolist(),\n",
"    embedding=embeddings,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25b31e55-2961-474d-92d8-5963f2c6bf84",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "918c078c-46fe-4d7b-9748-88c52a5b004a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "36e83202-97ad-425d-95ae-075a1e26a34e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d705875-5dd7-4c71-8d94-99c101020ac0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "50c15b4d-9f65-4385-9aa6-3f782bf57775",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f2818abe-1a43-4d7d-92a7-7562812bf43d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" desease_condition | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" marijuana abuse, substance-related disorders, ... | \n",
"
\n",
" \n",
" 1 | \n",
" marijuana abuse, substance-related disorders, ... | \n",
"
\n",
" \n",
" 2 | \n",
" tuberculosis, latent tuberculosis, infections,... | \n",
"
\n",
" \n",
" 3 | \n",
" heart failure, heart diseases, cardiovascular ... | \n",
"
\n",
" \n",
" 4 | \n",
" lymphoma, neoplasms by histologic type, neopla... | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 440512 | \n",
" obesity, overweight, overnutrition, nutrition ... | \n",
"
\n",
" \n",
" 440513 | \n",
" obesity, overweight, overnutrition, nutrition ... | \n",
"
\n",
" \n",
" 440514 | \n",
" obesity, overweight, overnutrition, nutrition ... | \n",
"
\n",
" \n",
" 440515 | \n",
" autistic disorder, autism spectrum disorder, c... | \n",
"
\n",
" \n",
" 440516 | \n",
" autistic disorder, autism spectrum disorder, c... | \n",
"
\n",
" \n",
"
\n",
"
440517 rows × 1 columns
\n",
"
"
],
"text/plain": [
" desease_condition\n",
"0 marijuana abuse, substance-related disorders, ...\n",
"1 marijuana abuse, substance-related disorders, ...\n",
"2 tuberculosis, latent tuberculosis, infections,...\n",
"3 heart failure, heart diseases, cardiovascular ...\n",
"4 lymphoma, neoplasms by histologic type, neopla...\n",
"... ...\n",
"440512 obesity, overweight, overnutrition, nutrition ...\n",
"440513 obesity, overweight, overnutrition, nutrition ...\n",
"440514 obesity, overweight, overnutrition, nutrition ...\n",
"440515 autistic disorder, autism spectrum disorder, c...\n",
"440516 autistic disorder, autism spectrum disorder, c...\n",
"\n",
"[440517 rows x 1 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Reload the cleaned conditions from disk so the embedding experiments below\n",
"# can start here.\n",
"df_trials_filtered = pd.read_csv(filepath_or_buffer=\"diseases.csv\")\n",
"df_trials_filtered"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "c89e3cf6-a376-4029-9c04-0f5e664a2237",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(100, 1)"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First 100 rows for the embedding experiments. `.copy()` gives an independent\n",
"# frame, so adding an `embeddings` column later does not trigger pandas'\n",
"# SettingWithCopyWarning.\n",
"df2 = df_trials_filtered.iloc[:100].copy()\n",
"df2.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c5012bcf-3e25-4f21-a29c-6bdbdafbb8c7",
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"# Raw OpenAI SDK client for direct embedding calls. It is reached through the\n",
"# `openai` module rather than `from openai import OpenAI`, which would shadow the\n",
"# LangChain `OpenAI` class used to build `llm` earlier in the notebook.\n",
"client = openai.OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "40a480bd-6754-40b6-870c-42d10ce9a960",
"metadata": {},
"outputs": [],
"source": [
"def get_embeddings(text, model=\"text-embedding-3-small\", dimensions=128):\n",
"    \"\"\"Return the embedding of `text` from the OpenAI embeddings API.\n",
"\n",
"    Uses the module-level `client`; each call makes one API request.\n",
"\n",
"    Args:\n",
"        text: The string to embed.\n",
"        model: Embedding model name.\n",
"        dimensions: Requested vector length (128 by default).\n",
"\n",
"    Returns:\n",
"        The embedding vector (a list of floats) for the single input.\n",
"    \"\"\"\n",
"    response = client.embeddings.create(\n",
"        input=text,\n",
"        dimensions=dimensions,\n",
"        model=model,\n",
"    )\n",
"    return response.data[0].embedding"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ef6d6b62-de0b-4bc6-a6eb-847ab8e99da5",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'sentence_transformers'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m:1\u001b[0m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'sentence_transformers'"
]
}
],
"source": [
"%%time\n",
"# Local alternative to the OpenAI embeddings. Requires the optional\n",
"# `sentence-transformers` package, which was missing in the recorded run.\n",
"from sentence_transformers import SentenceTransformer\n",
"\n",
"encoder = SentenceTransformer(\"all-mpnet-base-v2\")\n",
"# encode() expects a list of strings. Pass a plain list rather than a Series,\n",
"# because integer lookups on a Series are label-based, not positional.\n",
"vectors = encoder.encode(df2[\"desease_condition\"].tolist())\n",
"vectors.shape"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "7966d754-56d7-4555-a6c6-6a13772fb000",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: total: 62.5 ms\n",
"Wall time: 24.7 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
":1: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" desease_condition | \n",
" embeddings | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" marijuana abuse, substance-related disorders, ... | \n",
" [-0.05811865255236626, -0.023393018171191216, ... | \n",
"
\n",
" \n",
" 1 | \n",
" marijuana abuse, substance-related disorders, ... | \n",
" [-0.05811865255236626, -0.023393018171191216, ... | \n",
"
\n",
" \n",
" 2 | \n",
" tuberculosis, latent tuberculosis, infections,... | \n",
" [-0.03460180386900902, -0.084668830037117, 0.2... | \n",
"
\n",
" \n",
" 3 | \n",
" heart failure, heart diseases, cardiovascular ... | \n",
" [-0.08236236125230789, -0.1235777735710144, 0.... | \n",
"
\n",
" \n",
" 4 | \n",
" lymphoma, neoplasms by histologic type, neopla... | \n",
" [-0.1227850392460823, 0.07155642658472061, 0.1... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" desease_condition \\\n",
"0 marijuana abuse, substance-related disorders, ... \n",
"1 marijuana abuse, substance-related disorders, ... \n",
"2 tuberculosis, latent tuberculosis, infections,... \n",
"3 heart failure, heart diseases, cardiovascular ... \n",
"4 lymphoma, neoplasms by histologic type, neopla... \n",
"\n",
" embeddings \n",
"0 [-0.05811865255236626, -0.023393018171191216, ... \n",
"1 [-0.05811865255236626, -0.023393018171191216, ... \n",
"2 [-0.03460180386900902, -0.084668830037117, 0.2... \n",
"3 [-0.08236236125230789, -0.1235777735710144, 0.... \n",
"4 [-0.1227850392460823, 0.07155642658472061, 0.1... "
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# One embeddings API request per row (about 25 s for 100 rows in the recorded run).\n",
"# assign() returns a new frame, avoiding the SettingWithCopyWarning that an\n",
"# in-place column assignment on a slice triggers.\n",
"df2 = df2.assign(embeddings=df2[\"desease_condition\"].apply(get_embeddings))\n",
"df2.head()"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "2711980a-d1c0-441e-ae9a-531500b7b7cd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(128,)"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"# Sanity check: a stored vector should have the dimensionality requested from the API.\n",
"np.asarray(df2[\"embeddings\"].iloc[0]).shape"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "13e40184-999d-4363-be44-18ba9fb6745a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\ariji\\Anaconda3\\envs\\python310\\lib\\site-packages\\huggingface_hub\\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[1;32mIn [23]\u001b[0m, in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msentence_transformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SentenceTransformer\n\u001b[0;32m 3\u001b[0m encoder\u001b[38;5;241m=\u001b[39m SentenceTransformer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mall-mpnet-base-v2\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 4\u001b[0m vectors\u001b[38;5;241m=\u001b[39m \u001b[43mencoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf_trials_filtered\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdesease_condition\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m vectors\u001b[38;5;241m.\u001b[39mshape\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\sentence_transformers\\SentenceTransformer.py:371\u001b[0m, in \u001b[0;36mSentenceTransformer.encode\u001b[1;34m(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings)\u001b[0m\n\u001b[0;32m 368\u001b[0m features\u001b[38;5;241m.\u001b[39mupdate(extra_features)\n\u001b[0;32m 370\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m--> 371\u001b[0m out_features \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 372\u001b[0m out_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence_embedding\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m truncate_embeddings(\n\u001b[0;32m 373\u001b[0m out_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence_embedding\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtruncate_dim\n\u001b[0;32m 374\u001b[0m )\n\u001b[0;32m 376\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_value \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_embeddings\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 138\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 139\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 140\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\sentence_transformers\\models\\Transformer.py:98\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, features)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m features:\n\u001b[0;32m 96\u001b[0m trans_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m---> 98\u001b[0m output_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_model(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mtrans_features, return_dict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 99\u001b[0m output_tokens \u001b[38;5;241m=\u001b[39m output_states[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 101\u001b[0m features\u001b[38;5;241m.\u001b[39mupdate({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_embeddings\u001b[39m\u001b[38;5;124m\"\u001b[39m: output_tokens, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m: features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]})\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:548\u001b[0m, in \u001b[0;36mMPNetModel.forward\u001b[1;34m(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m 546\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[0;32m 547\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(input_ids\u001b[38;5;241m=\u001b[39minput_ids, position_ids\u001b[38;5;241m=\u001b[39mposition_ids, inputs_embeds\u001b[38;5;241m=\u001b[39minputs_embeds)\n\u001b[1;32m--> 548\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 549\u001b[0m \u001b[43m \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 550\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 551\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 552\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 553\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 554\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 555\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 556\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 557\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler(sequence_output) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:338\u001b[0m, in \u001b[0;36mMPNetEncoder.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m 335\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states:\n\u001b[0;32m 336\u001b[0m all_hidden_states \u001b[38;5;241m=\u001b[39m all_hidden_states \u001b[38;5;241m+\u001b[39m (hidden_states,)\n\u001b[1;32m--> 338\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m layer_module(\n\u001b[0;32m 339\u001b[0m hidden_states,\n\u001b[0;32m 340\u001b[0m attention_mask,\n\u001b[0;32m 341\u001b[0m head_mask[i],\n\u001b[0;32m 342\u001b[0m position_bias,\n\u001b[0;32m 343\u001b[0m output_attentions\u001b[38;5;241m=\u001b[39moutput_attentions,\n\u001b[0;32m 344\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[0;32m 345\u001b[0m )\n\u001b[0;32m 346\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 348\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:308\u001b[0m, in \u001b[0;36mMPNetLayer.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, position_bias, output_attentions, **kwargs)\u001b[0m\n\u001b[0;32m 305\u001b[0m outputs \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m1\u001b[39m:] \u001b[38;5;66;03m# add self attentions if we output attention weights\u001b[39;00m\n\u001b[0;32m 307\u001b[0m intermediate_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mintermediate(attention_output)\n\u001b[1;32m--> 308\u001b[0m layer_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput\u001b[49m\u001b[43m(\u001b[49m\u001b[43mintermediate_output\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_output\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 309\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (layer_output,) \u001b[38;5;241m+\u001b[39m outputs\n\u001b[0;32m 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:275\u001b[0m, in \u001b[0;36mMPNetOutput.forward\u001b[1;34m(self, hidden_states, input_tensor)\u001b[0m\n\u001b[0;32m 274\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor, input_tensor: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m--> 275\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdense\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 276\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout(hidden_states)\n\u001b[0;32m 277\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLayerNorm(hidden_states \u001b[38;5;241m+\u001b[39m input_tensor)\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"encoder= SentenceTransformer(\"all-mpnet-base-v2\")\n",
"vectors= encoder.encode(df2.desease_condition)\n",
"vectors.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f626382-75f1-4864-8574-bbcd183b926a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e064857-2480-4ac2-911a-1c5a2ffb62d6",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 102,
"id": "c0a719a6-ee06-4f6d-a64b-8fc74c7afbca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.01245328 0.07695839 0.01802037 ... 0.0093326 -0.03474615\n",
" -0.02757339]\n",
" [ 0.01846956 0.08282936 0.01921537 ... 0.01068991 -0.03402989\n",
" -0.02075185]\n",
" [ 0.01822413 -0.05593034 -0.00288358 ... 0.03525289 -0.05427228\n",
" -0.03371295]\n",
" ...\n",
" [ 0.04031958 0.0040107 0.02156032 ... 0.01568209 -0.04320977\n",
" -0.02990234]\n",
" [ 0.0399131 0.00027251 0.02207735 ... 0.01440835 -0.04246744\n",
" -0.02869584]\n",
" [ 0.03773859 -0.00315346 0.0207725 ... 0.01205995 -0.04628598\n",
" -0.02870333]]\n"
]
}
],
"source": [
"print(vectors)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "f21c08d5-596a-4fc4-ba72-2de8624e6c50",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"768"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dim= vectors.shape[1]\n",
"dim "
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "8ed985f4-9402-431f-bfba-1236ba16b895",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" >"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import faiss\n",
"\n",
"index= faiss.IndexFlatL2(dim)\n",
"index"
]
},
{
"cell_type": "code",
"execution_count": 107,
"id": "c0f1dde8-c173-45ba-957b-f4622fe6f0ee",
"metadata": {},
"outputs": [],
"source": [
"index.add(vectors)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "71dad860-1f19-4166-a309-c9ce15f24792",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(768,)\n"
]
},
{
"data": {
"text/plain": [
"array([-8.82369652e-03, 8.50650743e-02, 2.08267733e-03, 6.77651772e-03,\n",
" -2.86661759e-02, -8.71188380e-03, 6.99447095e-02, 5.04214764e-02,\n",
" 3.58386151e-02, 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n",
" 2.27009598e-03, 2.10810862e-02, 2.66138893e-02, 1.90623086e-02,\n",
" 4.44708914e-02, 2.96202525e-02, 5.42085357e-02, -2.34859088e-03,\n",
" -9.87798795e-02, -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n",
" 5.31156994e-02, -1.37044629e-02, 2.92537250e-02, -2.61334293e-02,\n",
" -1.21854078e-04, -2.36813519e-02, -3.81283499e-02, -1.79494768e-02,\n",
" -6.29265187e-03, 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n",
" -1.28973828e-04, 4.01791967e-02, 4.21229303e-02, -8.72302521e-03,\n",
" 7.44823692e-03, 7.68032745e-02, 6.50246907e-03, 3.40298638e-02,\n",
" -1.80711355e-02, -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n",
" -3.34868580e-02, 1.05205458e-02, 2.08975170e-02, 4.36686277e-02,\n",
" 3.47612537e-02, -4.99080680e-02, 4.44446988e-02, 5.57280704e-03,\n",
" -2.31200755e-02, -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n",
" 4.67903316e-02, 1.91651955e-02, -5.12171052e-02, 2.46807020e-02,\n",
" -5.52081019e-02, 3.50596346e-02, -7.01438356e-03, -3.36519890e-02,\n",
" -1.41502097e-02, -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n",
" -1.30138136e-02, 5.91567121e-02, 3.37168351e-02, 3.01467292e-02,\n",
" -4.59552221e-02, 1.37365120e-03, 1.00179566e-02, 6.98126853e-04,\n",
" 3.58139984e-02, 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n",
" 4.75908853e-02, 5.48331346e-03, -6.41460950e-03, -1.23906611e-02,\n",
" 5.82688041e-02, -1.60842277e-02, -2.95833423e-04, 6.97355811e-03,\n",
" 2.48331465e-02, -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n",
" 1.52637456e-02, 7.01832073e-03, 5.50601333e-02, 4.35096538e-03,\n",
" 2.36319732e-02, -1.38118947e-02, -7.24233836e-02, 9.39742289e-03,\n",
" -2.66901590e-02, 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n",
" 3.08373477e-03, 7.12765753e-02, 7.13613164e-03, 3.62721197e-02,\n",
" -4.53250594e-02, 2.54001115e-02, -2.54253373e-02, -1.23151275e-03,\n",
" -1.34750446e-02, -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n",
" -7.31003610e-03, 2.65329964e-02, -2.79857730e-03, 4.20840643e-02,\n",
" 3.20205763e-02, -1.19518824e-02, -5.77116087e-02, 9.88688134e-03,\n",
" 1.86814573e-02, -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n",
" -3.33383717e-02, -5.52933291e-02, 5.64148054e-02, 4.92153503e-02,\n",
" 3.33690383e-02, -3.92963700e-02, -6.91099390e-02, 3.79911740e-03,\n",
" -1.74410697e-02, -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n",
" 2.61192098e-02, -2.74193864e-02, 6.92490395e-03, -4.64810384e-03,\n",
" -8.99862905e-04, 1.02159111e-02, -4.81114909e-02, 1.22787328e-02,\n",
" -9.32844076e-03, -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n",
" -1.60810221e-02, -2.20200215e-02, 2.32051890e-02, -5.07331975e-02,\n",
" -1.01248249e-02, 5.62567115e-02, -2.60966737e-03, 9.27545596e-03,\n",
" 5.32410555e-02, 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n",
" -2.12969314e-02, 9.82244611e-02, -2.47648880e-02, 7.06253499e-02,\n",
" 8.71159416e-03, -2.73140483e-02, 5.59884915e-03, -2.14829091e-02,\n",
" -6.67077005e-02, 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n",
" 3.77488993e-02, -1.37352264e-02, -2.85069812e-02, 1.81708820e-02,\n",
" -4.07746173e-02, -4.71230270e-03, -1.59605164e-02, -1.25815195e-03,\n",
" -6.59954594e-03, -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n",
" 3.07933297e-02, 1.27626080e-02, -4.34489138e-02, 9.91576444e-03,\n",
" 1.25470785e-02, -8.67356583e-02, 1.26097840e-03, 3.24709825e-02,\n",
" -6.92409948e-02, -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n",
" -2.35359464e-02, -2.95022167e-02, 2.88009271e-02, -3.26618887e-02,\n",
" 7.09307985e-03, 3.09435464e-03, -5.09097055e-02, 3.54242921e-02,\n",
" 5.37336655e-02, 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n",
" 2.93767708e-03, 2.27999203e-02, 1.02668423e-02, 3.35033536e-02,\n",
" -8.28316063e-02, -4.17127199e-02, -1.23034064e-02, 2.38543525e-02,\n",
" -3.72257493e-02, 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n",
" 5.74519299e-03, 2.89844945e-02, -2.21337453e-02, 2.34603398e-02,\n",
" 6.33142609e-03, -2.24104542e-02, 1.47326495e-02, 1.98041964e-02,\n",
" 3.05697713e-02, -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n",
" 3.88860740e-02, -3.97440195e-02, -4.70216498e-02, 1.02172708e-02,\n",
" -3.37972888e-03, -8.54947045e-03, 4.81354557e-02, 4.99849804e-02,\n",
" 7.11378129e-03, -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n",
" 5.24284095e-02, 2.16708388e-02, -4.00698110e-02, 5.15380092e-02,\n",
" -6.03203699e-02, -6.50304696e-03, -1.03860423e-02, -7.47132823e-02,\n",
" 3.59848235e-03, -4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n",
" -2.88047381e-02, -2.81904116e-02, 1.52729014e-02, -1.55570190e-02,\n",
" 1.34619148e-02, 2.34364290e-02, 3.10326237e-02, -4.70464528e-02,\n",
" -2.43550166e-02, -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n",
" -5.30204549e-03, 5.52049950e-02, 4.50828709e-02, -7.30262510e-03,\n",
" 5.56289777e-02, -9.46066808e-03, -3.37345451e-02, -1.87659152e-02,\n",
" 3.57284099e-02, 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n",
" 2.96422077e-04, 4.22447585e-02, 4.97253910e-02, 6.03130311e-02,\n",
" 1.32281650e-02, 2.35939436e-02, -1.59284715e-02, 4.46444489e-02,\n",
" -1.68315917e-02, 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n",
" 8.99049267e-03, 4.74606343e-02, 6.70041004e-03, -1.15184486e-03,\n",
" 2.69540539e-03, -2.77549177e-02, -1.33260442e-02, 2.60788556e-02,\n",
" 4.35438640e-02, -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n",
" 2.93240137e-02, 1.82274636e-03, -1.40310880e-02, -1.91633645e-02,\n",
" 1.18790809e-02, -4.65121269e-02, -4.19883654e-02, -2.69681774e-02,\n",
" -3.23035605e-02, -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n",
" -2.55833156e-02, -5.73152229e-02, 3.30126472e-02, -7.90146552e-03,\n",
" -1.08651863e-02, 1.10474667e-02, 3.03509296e-03, 1.55274626e-02,\n",
" 1.05599947e-02, -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n",
" 3.77239436e-02, 9.44003314e-02, -4.80610691e-02, 4.73537892e-02,\n",
" 3.40655483e-02, 7.88806472e-03, -2.84915343e-02, 7.96849206e-02,\n",
" 1.57442074e-02, -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n",
" -1.72730908e-01, -8.72075930e-02, 2.86346450e-02, 2.16962174e-02,\n",
" -4.80199270e-02, 6.49317261e-03, 1.67240556e-02, -2.56227311e-02,\n",
" 2.19670162e-02, -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n",
" -2.89566331e-02, 1.19498251e-02, -2.33849231e-02, -2.69133616e-02,\n",
" -1.46602485e-02, 1.18886270e-02, 1.64973717e-02, -3.90495770e-02,\n",
" -3.45575088e-03, 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n",
" 2.10017413e-02, 2.74998210e-02, 3.03551817e-04, -1.15796946e-01,\n",
" -4.66962112e-03, -4.80118394e-02, -3.55160870e-02, -4.72528581e-03,\n",
" -4.29739058e-02, -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n",
" 1.98413953e-02, -7.27679394e-03, 2.27117930e-02, -2.59338003e-02,\n",
" 4.31442596e-02, 1.07885078e-02, -2.47129947e-02, -4.14506458e-02,\n",
" 4.40958813e-02, 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n",
" 1.13289580e-02, -5.57265691e-02, 1.71151303e-03, -1.24145029e-02,\n",
" -3.57853901e-03, -4.86295968e-02, -5.14956787e-02, 4.79425713e-02,\n",
" -3.24050151e-02, 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n",
" 8.20766483e-03, -6.27530292e-02, -1.30661400e-02, -3.52081768e-02,\n",
" 4.83807474e-02, 9.81860235e-03, 1.14539362e-01, -1.88471414e-02,\n",
" 6.07751869e-02, -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n",
" 2.64345529e-03, 3.07400171e-02, -4.31060083e-02, -6.19985871e-02,\n",
" 5.50477020e-03, 1.62547994e-02, -8.26352183e-03, 7.56437238e-03,\n",
" -4.79784003e-03, 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n",
" 1.41595434e-02, 5.31185642e-02, 6.78585656e-03, 6.56357184e-02,\n",
" -5.06135784e-02, -3.05179805e-02, 7.06539825e-02, -3.55644710e-02,\n",
" -4.92612133e-03, 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n",
" -1.86746120e-02, 2.49281265e-02, -4.92450967e-03, 1.66887734e-02,\n",
" 4.62210961e-02, 4.07794118e-02, 2.52511259e-02, -2.83305068e-02,\n",
" -2.78001893e-02, -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n",
" 1.09969089e-02, 1.69700030e-02, -8.59475043e-03, 4.70476560e-02,\n",
" 3.64770554e-02, 2.09835749e-02, 1.01236468e-02, 2.75151283e-02,\n",
" 4.33402918e-02, -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n",
" -6.10819347e-02, -2.86280159e-02, 4.68054451e-02, 1.29892454e-02,\n",
" -1.71940885e-02, -2.52429228e-02, 3.86423096e-02, -1.35919163e-02,\n",
" -5.27431667e-02, 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n",
" 3.23252901e-02, 5.03172688e-02, -4.45654802e-02, 2.90075876e-02,\n",
" -1.35373492e-02, 6.78209821e-03, -5.89249916e-02, 4.28890549e-02,\n",
" -2.36034058e-02, -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n",
" 1.45543357e-02, 1.07806427e-02, -6.06855676e-02, -4.95252907e-02,\n",
" 1.02004781e-02, 4.60227691e-02, -1.08090881e-02, 4.42408510e-02,\n",
" 4.15152796e-02, 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n",
" -2.70090066e-02, 2.68773828e-02, -1.97812133e-02, 2.25932393e-02,\n",
" -1.33560598e-02, -1.50896851e-02, -3.14053567e-03, 1.54051669e-02,\n",
" 1.86488125e-02, -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n",
" -2.37891261e-04, 1.84722953e-02, 3.60381305e-02, -5.85213909e-03,\n",
" 4.44293395e-02, -1.11264118e-03, -4.79441285e-02, 3.46464328e-02,\n",
" -2.53370814e-02, -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n",
" -4.38152434e-04, 4.08602282e-02, -2.29470823e-02, -1.89938806e-02,\n",
" -1.52037974e-04, 1.05516789e-02, 2.08601039e-02, -6.98119551e-02,\n",
" 3.66246551e-02, -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n",
" 6.51817098e-02, 4.29646857e-02, 2.56071109e-02, -3.28080021e-02,\n",
" 1.20534413e-02, 3.56224040e-03, -1.01593453e-02, -1.96505673e-04,\n",
" 4.33485657e-02, -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n",
" -1.40319867e-02, -3.63940969e-02, -3.09983976e-02, -4.19548260e-33,\n",
" 7.11604580e-02, 4.78382297e-02, 1.89297704e-03, -1.60731785e-02,\n",
" 2.53787991e-02, -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n",
" 1.68679946e-03, 1.92391127e-02, -2.20667192e-04, 1.32907527e-02,\n",
" 5.99487219e-03, 2.75156219e-02, -5.06000873e-03, -3.58465910e-02,\n",
" 8.20948277e-03, -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n",
" -1.09853260e-01, -3.66037302e-02, 3.55480015e-02, 4.23291475e-02,\n",
" 1.48312682e-02, 5.68749309e-02, 3.57767567e-02, 1.40728084e-02,\n",
" -4.00471613e-02, 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n",
" 1.24238459e-02, 1.20237898e-02, -7.69484974e-03, -3.30727436e-02,\n",
" -1.45808076e-02, 3.43246050e-02, 3.21143419e-02, -4.96741422e-02,\n",
" -5.27968369e-02, 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n",
" 2.77636852e-02, 5.90689071e-02, -2.61273161e-02, -6.95008039e-02,\n",
" -3.15576978e-02, -5.62214339e-03, -7.93884136e-03, -3.62196900e-02,\n",
" -8.26047733e-03, 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n",
" -2.52235290e-02, -3.88054736e-02, -2.00710595e-02, 1.50789914e-03,\n",
" -5.51338419e-02, -8.35673045e-03, -1.61523875e-02, -8.79513845e-02,\n",
" -5.28004877e-02, -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n",
" 4.44932319e-02, 8.69598426e-03, -1.14432694e-02, 4.47212979e-02,\n",
" 2.70624813e-02, -3.86100151e-02, -3.07358261e-02, 2.75634117e-02,\n",
" 1.48464069e-02, -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n",
" 8.05836394e-02, -1.69498641e-02, 4.44465503e-02, -2.09145956e-02,\n",
" -3.37407738e-02, 3.85780074e-02, -7.44559616e-02, 1.17512364e-02,\n",
" 1.01964204e-02, -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n",
" 2.54592765e-02, -1.46158040e-02, 5.46646416e-02, 1.43051194e-03,\n",
" 2.99116820e-02, 2.24273186e-02, -5.79927117e-03, -1.33864526e-02,\n",
" -2.52460372e-02, -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n",
" 3.38429734e-02, -2.11539529e-02, 7.17787817e-02, -7.78904185e-02,\n",
" -4.04084288e-02, 4.90567498e-02, -2.61603445e-02, 1.97753590e-02,\n",
" 4.97209951e-02, -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n",
" 2.68440694e-02, 3.29160057e-02, -8.24410375e-03, -1.33646047e-02,\n",
" -6.22822754e-02, -1.13362661e-02, -3.79339382e-02, -6.56360280e-05,\n",
" -1.08087100e-02, 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n",
" -2.54666172e-02, -3.05371322e-02, -1.53249800e-02, -9.87035502e-03,\n",
" 1.95337094e-07, -1.76476724e-02, 5.71432859e-02, -2.49180794e-02,\n",
" 5.85253723e-02, 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n",
" 4.07801419e-02, 4.13940698e-02, 2.55707726e-02, 2.18985360e-02,\n",
" -3.04434425e-03, -3.77355106e-02, -6.24866784e-02, -1.17468778e-02,\n",
" -4.82194684e-02, -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n",
" -2.48471629e-02, 8.05181568e-04, -4.85844910e-03, -5.16015477e-03,\n",
" 7.53483502e-03, -9.46175400e-03, -2.39896346e-02, -3.14654633e-02,\n",
" 1.50111094e-02, -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n",
" 3.08971256e-02, 1.72299352e-02, 5.93419448e-02, -5.74274361e-02,\n",
" -8.16087723e-02, -4.80572283e-02, -2.68838424e-02, -1.96331330e-02,\n",
" -9.15831141e-03, 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n",
" 9.21937004e-02, 1.37132118e-02, -1.19096776e-02, -4.09874134e-02,\n",
" 3.37628126e-02, -4.64820908e-03, -2.50304434e-02, 6.25852346e-02,\n",
" -1.24449311e-02, 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n",
" 5.08641489e-02, 2.53822445e-03, 5.25634140e-02, 1.14882430e-02,\n",
" 5.01894541e-02, -3.55215147e-02, -3.31749097e-02, -3.02003417e-03,\n",
" -5.36288768e-02, -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n",
" 9.56887701e-35, 2.55127084e-02, -1.44770980e-04, 1.96710341e-02,\n",
" -1.33620016e-02, -1.51910949e-02, -3.28495577e-02, -1.52465852e-03,\n",
" -2.65272055e-02, -4.35708016e-02, -1.75950192e-02, -2.20594816e-02],\n",
" dtype=float32)"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"search_query= \"clinical trials related to alzheimers\"\n",
"vec= encoder.encode(search_query)\n",
"print(vec.shape)\n",
"vec"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "613fd415-4194-45e6-b9f3-9a7707845ad5",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 768)\n"
]
},
{
"data": {
"text/plain": [
"array([[-8.82369652e-03, 8.50650743e-02, 2.08267733e-03,\n",
" 6.77651772e-03, -2.86661759e-02, -8.71188380e-03,\n",
" 6.99447095e-02, 5.04214764e-02, 3.58386151e-02,\n",
" 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n",
" 2.27009598e-03, 2.10810862e-02, 2.66138893e-02,\n",
" 1.90623086e-02, 4.44708914e-02, 2.96202525e-02,\n",
" 5.42085357e-02, -2.34859088e-03, -9.87798795e-02,\n",
" -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n",
" 5.31156994e-02, -1.37044629e-02, 2.92537250e-02,\n",
" -2.61334293e-02, -1.21854078e-04, -2.36813519e-02,\n",
" -3.81283499e-02, -1.79494768e-02, -6.29265187e-03,\n",
" 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n",
" -1.28973828e-04, 4.01791967e-02, 4.21229303e-02,\n",
" -8.72302521e-03, 7.44823692e-03, 7.68032745e-02,\n",
" 6.50246907e-03, 3.40298638e-02, -1.80711355e-02,\n",
" -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n",
" -3.34868580e-02, 1.05205458e-02, 2.08975170e-02,\n",
" 4.36686277e-02, 3.47612537e-02, -4.99080680e-02,\n",
" 4.44446988e-02, 5.57280704e-03, -2.31200755e-02,\n",
" -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n",
" 4.67903316e-02, 1.91651955e-02, -5.12171052e-02,\n",
" 2.46807020e-02, -5.52081019e-02, 3.50596346e-02,\n",
" -7.01438356e-03, -3.36519890e-02, -1.41502097e-02,\n",
" -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n",
" -1.30138136e-02, 5.91567121e-02, 3.37168351e-02,\n",
" 3.01467292e-02, -4.59552221e-02, 1.37365120e-03,\n",
" 1.00179566e-02, 6.98126853e-04, 3.58139984e-02,\n",
" 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n",
" 4.75908853e-02, 5.48331346e-03, -6.41460950e-03,\n",
" -1.23906611e-02, 5.82688041e-02, -1.60842277e-02,\n",
" -2.95833423e-04, 6.97355811e-03, 2.48331465e-02,\n",
" -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n",
" 1.52637456e-02, 7.01832073e-03, 5.50601333e-02,\n",
" 4.35096538e-03, 2.36319732e-02, -1.38118947e-02,\n",
" -7.24233836e-02, 9.39742289e-03, -2.66901590e-02,\n",
" 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n",
" 3.08373477e-03, 7.12765753e-02, 7.13613164e-03,\n",
" 3.62721197e-02, -4.53250594e-02, 2.54001115e-02,\n",
" -2.54253373e-02, -1.23151275e-03, -1.34750446e-02,\n",
" -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n",
" -7.31003610e-03, 2.65329964e-02, -2.79857730e-03,\n",
" 4.20840643e-02, 3.20205763e-02, -1.19518824e-02,\n",
" -5.77116087e-02, 9.88688134e-03, 1.86814573e-02,\n",
" -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n",
" -3.33383717e-02, -5.52933291e-02, 5.64148054e-02,\n",
" 4.92153503e-02, 3.33690383e-02, -3.92963700e-02,\n",
" -6.91099390e-02, 3.79911740e-03, -1.74410697e-02,\n",
" -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n",
" 2.61192098e-02, -2.74193864e-02, 6.92490395e-03,\n",
" -4.64810384e-03, -8.99862905e-04, 1.02159111e-02,\n",
" -4.81114909e-02, 1.22787328e-02, -9.32844076e-03,\n",
" -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n",
" -1.60810221e-02, -2.20200215e-02, 2.32051890e-02,\n",
" -5.07331975e-02, -1.01248249e-02, 5.62567115e-02,\n",
" -2.60966737e-03, 9.27545596e-03, 5.32410555e-02,\n",
" 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n",
" -2.12969314e-02, 9.82244611e-02, -2.47648880e-02,\n",
" 7.06253499e-02, 8.71159416e-03, -2.73140483e-02,\n",
" 5.59884915e-03, -2.14829091e-02, -6.67077005e-02,\n",
" 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n",
" 3.77488993e-02, -1.37352264e-02, -2.85069812e-02,\n",
" 1.81708820e-02, -4.07746173e-02, -4.71230270e-03,\n",
" -1.59605164e-02, -1.25815195e-03, -6.59954594e-03,\n",
" -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n",
" 3.07933297e-02, 1.27626080e-02, -4.34489138e-02,\n",
" 9.91576444e-03, 1.25470785e-02, -8.67356583e-02,\n",
" 1.26097840e-03, 3.24709825e-02, -6.92409948e-02,\n",
" -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n",
" -2.35359464e-02, -2.95022167e-02, 2.88009271e-02,\n",
" -3.26618887e-02, 7.09307985e-03, 3.09435464e-03,\n",
" -5.09097055e-02, 3.54242921e-02, 5.37336655e-02,\n",
" 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n",
" 2.93767708e-03, 2.27999203e-02, 1.02668423e-02,\n",
" 3.35033536e-02, -8.28316063e-02, -4.17127199e-02,\n",
" -1.23034064e-02, 2.38543525e-02, -3.72257493e-02,\n",
" 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n",
" 5.74519299e-03, 2.89844945e-02, -2.21337453e-02,\n",
" 2.34603398e-02, 6.33142609e-03, -2.24104542e-02,\n",
" 1.47326495e-02, 1.98041964e-02, 3.05697713e-02,\n",
" -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n",
" 3.88860740e-02, -3.97440195e-02, -4.70216498e-02,\n",
" 1.02172708e-02, -3.37972888e-03, -8.54947045e-03,\n",
" 4.81354557e-02, 4.99849804e-02, 7.11378129e-03,\n",
" -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n",
" 5.24284095e-02, 2.16708388e-02, -4.00698110e-02,\n",
" 5.15380092e-02, -6.03203699e-02, -6.50304696e-03,\n",
" -1.03860423e-02, -7.47132823e-02, 3.59848235e-03,\n",
" -4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n",
" -2.88047381e-02, -2.81904116e-02, 1.52729014e-02,\n",
" -1.55570190e-02, 1.34619148e-02, 2.34364290e-02,\n",
" 3.10326237e-02, -4.70464528e-02, -2.43550166e-02,\n",
" -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n",
" -5.30204549e-03, 5.52049950e-02, 4.50828709e-02,\n",
" -7.30262510e-03, 5.56289777e-02, -9.46066808e-03,\n",
" -3.37345451e-02, -1.87659152e-02, 3.57284099e-02,\n",
" 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n",
" 2.96422077e-04, 4.22447585e-02, 4.97253910e-02,\n",
" 6.03130311e-02, 1.32281650e-02, 2.35939436e-02,\n",
" -1.59284715e-02, 4.46444489e-02, -1.68315917e-02,\n",
" 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n",
" 8.99049267e-03, 4.74606343e-02, 6.70041004e-03,\n",
" -1.15184486e-03, 2.69540539e-03, -2.77549177e-02,\n",
" -1.33260442e-02, 2.60788556e-02, 4.35438640e-02,\n",
" -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n",
" 2.93240137e-02, 1.82274636e-03, -1.40310880e-02,\n",
" -1.91633645e-02, 1.18790809e-02, -4.65121269e-02,\n",
" -4.19883654e-02, -2.69681774e-02, -3.23035605e-02,\n",
" -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n",
" -2.55833156e-02, -5.73152229e-02, 3.30126472e-02,\n",
" -7.90146552e-03, -1.08651863e-02, 1.10474667e-02,\n",
" 3.03509296e-03, 1.55274626e-02, 1.05599947e-02,\n",
" -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n",
" 3.77239436e-02, 9.44003314e-02, -4.80610691e-02,\n",
" 4.73537892e-02, 3.40655483e-02, 7.88806472e-03,\n",
" -2.84915343e-02, 7.96849206e-02, 1.57442074e-02,\n",
" -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n",
" -1.72730908e-01, -8.72075930e-02, 2.86346450e-02,\n",
" 2.16962174e-02, -4.80199270e-02, 6.49317261e-03,\n",
" 1.67240556e-02, -2.56227311e-02, 2.19670162e-02,\n",
" -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n",
" -2.89566331e-02, 1.19498251e-02, -2.33849231e-02,\n",
" -2.69133616e-02, -1.46602485e-02, 1.18886270e-02,\n",
" 1.64973717e-02, -3.90495770e-02, -3.45575088e-03,\n",
" 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n",
" 2.10017413e-02, 2.74998210e-02, 3.03551817e-04,\n",
" -1.15796946e-01, -4.66962112e-03, -4.80118394e-02,\n",
" -3.55160870e-02, -4.72528581e-03, -4.29739058e-02,\n",
" -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n",
" 1.98413953e-02, -7.27679394e-03, 2.27117930e-02,\n",
" -2.59338003e-02, 4.31442596e-02, 1.07885078e-02,\n",
" -2.47129947e-02, -4.14506458e-02, 4.40958813e-02,\n",
" 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n",
" 1.13289580e-02, -5.57265691e-02, 1.71151303e-03,\n",
" -1.24145029e-02, -3.57853901e-03, -4.86295968e-02,\n",
" -5.14956787e-02, 4.79425713e-02, -3.24050151e-02,\n",
" 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n",
" 8.20766483e-03, -6.27530292e-02, -1.30661400e-02,\n",
" -3.52081768e-02, 4.83807474e-02, 9.81860235e-03,\n",
" 1.14539362e-01, -1.88471414e-02, 6.07751869e-02,\n",
" -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n",
" 2.64345529e-03, 3.07400171e-02, -4.31060083e-02,\n",
" -6.19985871e-02, 5.50477020e-03, 1.62547994e-02,\n",
" -8.26352183e-03, 7.56437238e-03, -4.79784003e-03,\n",
" 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n",
" 1.41595434e-02, 5.31185642e-02, 6.78585656e-03,\n",
" 6.56357184e-02, -5.06135784e-02, -3.05179805e-02,\n",
" 7.06539825e-02, -3.55644710e-02, -4.92612133e-03,\n",
" 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n",
" -1.86746120e-02, 2.49281265e-02, -4.92450967e-03,\n",
" 1.66887734e-02, 4.62210961e-02, 4.07794118e-02,\n",
" 2.52511259e-02, -2.83305068e-02, -2.78001893e-02,\n",
" -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n",
" 1.09969089e-02, 1.69700030e-02, -8.59475043e-03,\n",
" 4.70476560e-02, 3.64770554e-02, 2.09835749e-02,\n",
" 1.01236468e-02, 2.75151283e-02, 4.33402918e-02,\n",
" -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n",
" -6.10819347e-02, -2.86280159e-02, 4.68054451e-02,\n",
" 1.29892454e-02, -1.71940885e-02, -2.52429228e-02,\n",
" 3.86423096e-02, -1.35919163e-02, -5.27431667e-02,\n",
" 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n",
" 3.23252901e-02, 5.03172688e-02, -4.45654802e-02,\n",
" 2.90075876e-02, -1.35373492e-02, 6.78209821e-03,\n",
" -5.89249916e-02, 4.28890549e-02, -2.36034058e-02,\n",
" -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n",
" 1.45543357e-02, 1.07806427e-02, -6.06855676e-02,\n",
" -4.95252907e-02, 1.02004781e-02, 4.60227691e-02,\n",
" -1.08090881e-02, 4.42408510e-02, 4.15152796e-02,\n",
" 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n",
" -2.70090066e-02, 2.68773828e-02, -1.97812133e-02,\n",
" 2.25932393e-02, -1.33560598e-02, -1.50896851e-02,\n",
" -3.14053567e-03, 1.54051669e-02, 1.86488125e-02,\n",
" -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n",
" -2.37891261e-04, 1.84722953e-02, 3.60381305e-02,\n",
" -5.85213909e-03, 4.44293395e-02, -1.11264118e-03,\n",
" -4.79441285e-02, 3.46464328e-02, -2.53370814e-02,\n",
" -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n",
" -4.38152434e-04, 4.08602282e-02, -2.29470823e-02,\n",
" -1.89938806e-02, -1.52037974e-04, 1.05516789e-02,\n",
" 2.08601039e-02, -6.98119551e-02, 3.66246551e-02,\n",
" -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n",
" 6.51817098e-02, 4.29646857e-02, 2.56071109e-02,\n",
" -3.28080021e-02, 1.20534413e-02, 3.56224040e-03,\n",
" -1.01593453e-02, -1.96505673e-04, 4.33485657e-02,\n",
" -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n",
" -1.40319867e-02, -3.63940969e-02, -3.09983976e-02,\n",
" -4.19548260e-33, 7.11604580e-02, 4.78382297e-02,\n",
" 1.89297704e-03, -1.60731785e-02, 2.53787991e-02,\n",
" -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n",
" 1.68679946e-03, 1.92391127e-02, -2.20667192e-04,\n",
" 1.32907527e-02, 5.99487219e-03, 2.75156219e-02,\n",
" -5.06000873e-03, -3.58465910e-02, 8.20948277e-03,\n",
" -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n",
" -1.09853260e-01, -3.66037302e-02, 3.55480015e-02,\n",
" 4.23291475e-02, 1.48312682e-02, 5.68749309e-02,\n",
" 3.57767567e-02, 1.40728084e-02, -4.00471613e-02,\n",
" 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n",
" 1.24238459e-02, 1.20237898e-02, -7.69484974e-03,\n",
" -3.30727436e-02, -1.45808076e-02, 3.43246050e-02,\n",
" 3.21143419e-02, -4.96741422e-02, -5.27968369e-02,\n",
" 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n",
" 2.77636852e-02, 5.90689071e-02, -2.61273161e-02,\n",
" -6.95008039e-02, -3.15576978e-02, -5.62214339e-03,\n",
" -7.93884136e-03, -3.62196900e-02, -8.26047733e-03,\n",
" 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n",
" -2.52235290e-02, -3.88054736e-02, -2.00710595e-02,\n",
" 1.50789914e-03, -5.51338419e-02, -8.35673045e-03,\n",
" -1.61523875e-02, -8.79513845e-02, -5.28004877e-02,\n",
" -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n",
" 4.44932319e-02, 8.69598426e-03, -1.14432694e-02,\n",
" 4.47212979e-02, 2.70624813e-02, -3.86100151e-02,\n",
" -3.07358261e-02, 2.75634117e-02, 1.48464069e-02,\n",
" -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n",
" 8.05836394e-02, -1.69498641e-02, 4.44465503e-02,\n",
" -2.09145956e-02, -3.37407738e-02, 3.85780074e-02,\n",
" -7.44559616e-02, 1.17512364e-02, 1.01964204e-02,\n",
" -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n",
" 2.54592765e-02, -1.46158040e-02, 5.46646416e-02,\n",
" 1.43051194e-03, 2.99116820e-02, 2.24273186e-02,\n",
" -5.79927117e-03, -1.33864526e-02, -2.52460372e-02,\n",
" -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n",
" 3.38429734e-02, -2.11539529e-02, 7.17787817e-02,\n",
" -7.78904185e-02, -4.04084288e-02, 4.90567498e-02,\n",
" -2.61603445e-02, 1.97753590e-02, 4.97209951e-02,\n",
" -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n",
" 2.68440694e-02, 3.29160057e-02, -8.24410375e-03,\n",
" -1.33646047e-02, -6.22822754e-02, -1.13362661e-02,\n",
" -3.79339382e-02, -6.56360280e-05, -1.08087100e-02,\n",
" 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n",
" -2.54666172e-02, -3.05371322e-02, -1.53249800e-02,\n",
" -9.87035502e-03, 1.95337094e-07, -1.76476724e-02,\n",
" 5.71432859e-02, -2.49180794e-02, 5.85253723e-02,\n",
" 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n",
" 4.07801419e-02, 4.13940698e-02, 2.55707726e-02,\n",
" 2.18985360e-02, -3.04434425e-03, -3.77355106e-02,\n",
" -6.24866784e-02, -1.17468778e-02, -4.82194684e-02,\n",
" -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n",
" -2.48471629e-02, 8.05181568e-04, -4.85844910e-03,\n",
" -5.16015477e-03, 7.53483502e-03, -9.46175400e-03,\n",
" -2.39896346e-02, -3.14654633e-02, 1.50111094e-02,\n",
" -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n",
" 3.08971256e-02, 1.72299352e-02, 5.93419448e-02,\n",
" -5.74274361e-02, -8.16087723e-02, -4.80572283e-02,\n",
" -2.68838424e-02, -1.96331330e-02, -9.15831141e-03,\n",
" 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n",
" 9.21937004e-02, 1.37132118e-02, -1.19096776e-02,\n",
" -4.09874134e-02, 3.37628126e-02, -4.64820908e-03,\n",
" -2.50304434e-02, 6.25852346e-02, -1.24449311e-02,\n",
" 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n",
" 5.08641489e-02, 2.53822445e-03, 5.25634140e-02,\n",
" 1.14882430e-02, 5.01894541e-02, -3.55215147e-02,\n",
" -3.31749097e-02, -3.02003417e-03, -5.36288768e-02,\n",
" -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n",
" 9.56887701e-35, 2.55127084e-02, -1.44770980e-04,\n",
" 1.96710341e-02, -1.33620016e-02, -1.51910949e-02,\n",
" -3.28495577e-02, -1.52465852e-03, -2.65272055e-02,\n",
" -4.35708016e-02, -1.75950192e-02, -2.20594816e-02]], dtype=float32)"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"# Build the query matrix for FAISS: one row, shape (1, len(vec)).\n",
"# dtype is forced to float32 because FAISS `search` expects float32 input.\n",
"svec = np.array(vec, dtype=np.float32).reshape(1, -1)\n",
"# Display only the shape instead of dumping the whole embedding into the notebook.\n",
"svec.shape"
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "fef30d70-6958-4259-abb6-09f8c1870a2b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.7731663 0.79433584]] [[330 331]]\n"
]
}
],
"source": [
"# Query the index for the 2 nearest neighbours of the single query row.\n",
"# `distances`: per-hit scores (metric depends on how `index` was built).\n",
"# `I`: per-hit row ids, used as positional rows of `df` further on.\n",
"distances, I = index.search(svec, k=2)\n",
"print(distances, I)"
]
},
{
"cell_type": "code",
"execution_count": 111,
"id": "eb00598c-9799-4697-b2a3-356bb5aae0f1",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" \n",
" | \n",
" desease_condition | \n",
" text | \n",
" \n",
" \n",
" \n",
" \n",
" 330 | \n",
" ['alzheimer disease', 'dementia', 'brain disea... | \n",
" nct_id: NCT02164643\\nsummary: A Multicenter na... | \n",
" \n",
" \n",
" 331 | \n",
" ['alzheimer disease', 'dementia', 'brain disea... | \n",
" nct_id: NCT02164643\\nsummary: A Multicenter na... | \n",
" \n",
" \n",
" \n",
" "
],
"text/plain": [
" desease_condition \\\n",
"330 ['alzheimer disease', 'dementia', 'brain disea... \n",
"331 ['alzheimer disease', 'dementia', 'brain disea... \n",
"\n",
" text \n",
"330 nct_id: NCT02164643\\nsummary: A Multicenter na... \n",
"331 nct_id: NCT02164643\\nsummary: A Multicenter na... "
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Map the FAISS hits back to their rows in `df`.\n",
"# FAISS pads `I` with -1 when it finds fewer than k neighbours; drop those ids so\n",
"# `iloc` does not silently wrap around to the last row of `df`.\n",
"hit_rows = I[0]\n",
"df2 = df.iloc[hit_rows[hit_rows >= 0]]\n",
"df2"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "af5bf8e2-43b6-47af-affa-5111789371ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'nct_id: NCT02164643\\nsummary: A Multicenter national longitudinal cohort study including at least 800 individuals consecutively recruited from French Research Memory Centers and followed-up over 24 month and included in Memento.\\nintervention_type: Drug\\nintervention_name: Florbetapir (18F)\\nintervention_description: nan\\nkeywords: [\"Alzheimer\\'s disease\", \\'Mild Cognitive Impairment\\']'"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Full text of the second-ranked hit.\n",
"df2['text'].iloc[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3899f81-e120-475c-97ed-080cb7f46510",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|