{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "9db3b813-22dc-4209-86d2-42e935f5f5dd", "metadata": {}, "outputs": [], "source": [ "from langchain_community.document_loaders.csv_loader import CSVLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "import pandas as pd\n", "import langchain\n", "import os\n", "import openai\n", "import ast\n", "import getpass\n", "import pickle\n", "from langchain import OpenAI\n", "from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain\n", "from langchain.document_loaders import UnstructuredURLLoader\n", "from langchain.embeddings import OpenAIEmbeddings\n", "from langchain.vectorstores import FAISS" ] }, { "cell_type": "code", "execution_count": 2, "id": "95c000ff-bf0c-4489-8643-93a238db41dc", "metadata": {}, "outputs": [], "source": [ "# Never hardcode secrets in a notebook: read the key from the environment, or prompt for it once.\n", "if not os.environ.get('OPENAI_API_KEY'):\n", "    os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API key: ')" ] }, { "cell_type": "code", "execution_count": 3, "id": "e2e03936-fcce-4287-bfe6-e31f1b69f693", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/aldan.creo/miniconda3/envs/hackupc/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `OpenAI` was deprecated in LangChain 0.0.10 and will be removed in 0.2.0. 
An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import OpenAI`.\n", " warn_deprecated(\n" ] } ], "source": [ "llm= OpenAI(\n", " temperature=0.6, max_tokens=500\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "aede3d75-2441-44f6-b3cf-2f86f050da24", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(440517, 2)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
desease_conditiontext
0['marijuana abuse', 'substance-related disorde...nct_id: NCT03055377\\nsummary: This is a 12-wee...
1['marijuana abuse', 'substance-related disorde...nct_id: NCT03055377\\nsummary: This is a 12-wee...
2['tuberculosis', 'latent tuberculosis', 'infec...nct_id: NCT03042754\\nsummary: Early diagnosis ...
3['heart failure', 'heart diseases', 'cardiovas...nct_id: NCT03035123\\nsummary: The EduStra-HF s...
4['lymphoma', 'neoplasms by histologic type', '...nct_id: NCT02272751\\nsummary: This study will ...
\n", "
" ], "text/plain": [ " desease_condition \\\n", "0 ['marijuana abuse', 'substance-related disorde... \n", "1 ['marijuana abuse', 'substance-related disorde... \n", "2 ['tuberculosis', 'latent tuberculosis', 'infec... \n", "3 ['heart failure', 'heart diseases', 'cardiovas... \n", "4 ['lymphoma', 'neoplasms by histologic type', '... \n", "\n", " text \n", "0 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n", "1 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n", "2 nct_id: NCT03042754\\nsummary: Early diagnosis ... \n", "3 nct_id: NCT03035123\\nsummary: The EduStra-HF s... \n", "4 nct_id: NCT02272751\\nsummary: This study will ... " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_trials= pd.read_csv(\"clinical_trials.csv\")\n", "print(df_trials.shape)\n", "df_trials.head()" ] }, { "cell_type": "code", "execution_count": 15, "id": "65589fc4-4f55-4b72-9935-c3b163fde0e2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
desease_condition
0['marijuana abuse', 'substance-related disorde...
1['marijuana abuse', 'substance-related disorde...
2['tuberculosis', 'latent tuberculosis', 'infec...
3['heart failure', 'heart diseases', 'cardiovas...
4['lymphoma', 'neoplasms by histologic type', '...
\n", "
" ], "text/plain": [ " desease_condition\n", "0 ['marijuana abuse', 'substance-related disorde...\n", "1 ['marijuana abuse', 'substance-related disorde...\n", "2 ['tuberculosis', 'latent tuberculosis', 'infec...\n", "3 ['heart failure', 'heart diseases', 'cardiovas...\n", "4 ['lymphoma', 'neoplasms by histologic type', '..." ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# .copy() makes this an independent frame, so rewriting its column further down does not raise SettingWithCopyWarning.\n", "df_trials_filtered = df_trials[['desease_condition']].copy()\n", "df_trials_filtered.head()" ] }, { "cell_type": "code", "execution_count": 16, "id": "88e28056-e340-416c-a1a9-4a6c29556dc7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"['marijuana abuse', 'substance-related disorders', 'chemically-induced disorders', 'mental disorders']\"" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_trials_filtered['desease_condition'].iloc[0]" ] }, { "cell_type": "code", "execution_count": 17, "id": "5bd8f876-0480-40a5-a32f-ca7ec137a70f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
desease_condition
0marijuana abuse, substance-related disorders, ...
1marijuana abuse, substance-related disorders, ...
2tuberculosis, latent tuberculosis, infections,...
3heart failure, heart diseases, cardiovascular ...
4lymphoma, neoplasms by histologic type, neopla...
......
440512obesity, overweight, overnutrition, nutrition ...
440513obesity, overweight, overnutrition, nutrition ...
440514obesity, overweight, overnutrition, nutrition ...
440515autistic disorder, autism spectrum disorder, c...
440516autistic disorder, autism spectrum disorder, c...
\n", "

440517 rows × 1 columns

\n", "
" ], "text/plain": [ " desease_condition\n", "0 marijuana abuse, substance-related disorders, ...\n", "1 marijuana abuse, substance-related disorders, ...\n", "2 tuberculosis, latent tuberculosis, infections,...\n", "3 heart failure, heart diseases, cardiovascular ...\n", "4 lymphoma, neoplasms by histologic type, neopla...\n", "... ...\n", "440512 obesity, overweight, overnutrition, nutrition ...\n", "440513 obesity, overweight, overnutrition, nutrition ...\n", "440514 obesity, overweight, overnutrition, nutrition ...\n", "440515 autistic disorder, autism spectrum disorder, c...\n", "440516 autistic disorder, autism spectrum disorder, c...\n", "\n", "[440517 rows x 1 columns]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def list_to_string(disease_list):\n", " disease_list= ast.literal_eval(disease_list)\n", " return ', '.join(disease_list)\n", "df_trials_filtered['desease_condition']= df_trials_filtered['desease_condition'].apply(list_to_string)\n", "df_trials_filtered" ] }, { "cell_type": "code", "execution_count": 18, "id": "bbbb22e6-4883-4869-8ccc-95696bc67b1b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 marijuana abuse, substance-related disorders, ...\n", "1 marijuana abuse, substance-related disorders, ...\n", "2 tuberculosis, latent tuberculosis, infections,...\n", "3 heart failure, heart diseases, cardiovascular ...\n", "4 lymphoma, neoplasms by histologic type, neopla...\n", "Name: desease_condition, dtype: object" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_trials_filtered['desease_condition'].head()" ] }, { "cell_type": "code", "execution_count": 19, "id": "7f4e7ceb-8bfd-4294-a850-8935f88b6555", "metadata": {}, "outputs": [], "source": [ "df_trials_filtered.to_csv(\"diseases.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 20, "id": "af1c5c2b-24a0-44a1-9e5d-7ee89ca4cccf", "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "440517" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loader= CSVLoader(file_path=\"./diseases.csv\", encoding=\"utf-8\")\n", "data = loader.load()\n", "len(data)" ] }, { "cell_type": "code", "execution_count": null, "id": "cab89218-41ca-4048-886d-bc2c1c9b30bc", "metadata": {}, "outputs": [], "source": [ "embeddings = OpenAIEmbeddings()\n", "vectorstore = FAISS.from_documents(data, embeddings)" ] }, { "cell_type": "code", "execution_count": null, "id": "225ade4a-d004-44cc-a5ff-22ce2bfcac32", "metadata": {}, "outputs": [], "source": [ "# Persist with FAISS's own serializer (index + docstore) instead of pickling the whole store,\n", "# which would also try to serialize the embeddings client object.\n", "# Reload later with FAISS.load_local(\"vector_index\", embeddings).\n", "vectorstore.save_local(\"vector_index\")" ] }, { "cell_type": "code", "execution_count": 98, "id": "11912a93-ad02-41cb-8bce-2750c947fa74", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(440517, 2)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
desease_conditiontext
0['marijuana abuse', 'substance-related disorde...nct_id: NCT03055377\\nsummary: This is a 12-wee...
1['marijuana abuse', 'substance-related disorde...nct_id: NCT03055377\\nsummary: This is a 12-wee...
2['tuberculosis', 'latent tuberculosis', 'infec...nct_id: NCT03042754\\nsummary: Early diagnosis ...
3['heart failure', 'heart diseases', 'cardiovas...nct_id: NCT03035123\\nsummary: The EduStra-HF s...
4['lymphoma', 'neoplasms by histologic type', '...nct_id: NCT02272751\\nsummary: This study will ...
\n", "
" ], "text/plain": [ " desease_condition \\\n", "0 ['marijuana abuse', 'substance-related disorde... \n", "1 ['marijuana abuse', 'substance-related disorde... \n", "2 ['tuberculosis', 'latent tuberculosis', 'infec... \n", "3 ['heart failure', 'heart diseases', 'cardiovas... \n", "4 ['lymphoma', 'neoplasms by histologic type', '... \n", "\n", " text \n", "0 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n", "1 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n", "2 nct_id: NCT03042754\\nsummary: Early diagnosis ... \n", "3 nct_id: NCT03035123\\nsummary: The EduStra-HF s... \n", "4 nct_id: NCT02272751\\nsummary: This study will ... " ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_trials= pd.read_csv(\"clinical_trials.csv\")\n", "print(df_trials.shape)\n", "df_trials.head()" ] }, { "cell_type": "code", "execution_count": 99, "id": "a8113876-f38e-4c7a-891e-17dc51e2bacf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"nct_id: NCT03055377\\nsummary: This is a 12-week randomized, placebo-controlled trial of N-acetylcysteine for cannabis use disorder (CUD) in youth (N=192). Participants will be randomized to double-blind NAC or PBO, yielding two equally-allocated treatment groups. All participants will receive brief weekly cannabis cessation counseling and medication management. 
The primary efficacy outcome will be the proportion of negative urine cannabinoid tests during the 12-week active treatment, compared between groups.\\nintervention_type: Drug\\nintervention_name: N-acetyl cysteine\\nintervention_description: N-acetylcysteine 1200 mg twice daily for 12 weeks (administered orally)\\nkeywords: ['cannabis', 'marijuana', 'youth', 'adolescent', 'pharmacotherapy', 'medication', 'n-acetylcysteine', 'trial']\"" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_trials.iloc[0].text" ] }, { "cell_type": "code", "execution_count": 100, "id": "c9ace03c-cf71-4605-adeb-9c6ee10e158a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1000, 2)" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df= df_trials[:1000]\n", "df.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "85a985c5-845b-464e-9527-e806b6883741", "metadata": {}, "outputs": [], "source": [ "embeddings = OpenAIEmbeddings()\n", "# Index the full trial text of the 1000-row sample; from_texts takes a list of strings plus the embedding model.\n", "vectorindex_openapi = FAISS.from_texts(df['text'].tolist(), embeddings)" ] }, { "cell_type": "code", "execution_count": 6, "id": "f2818abe-1a43-4d7d-92a7-7562812bf43d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
desease_condition
0marijuana abuse, substance-related disorders, ...
1marijuana abuse, substance-related disorders, ...
2tuberculosis, latent tuberculosis, infections,...
3heart failure, heart diseases, cardiovascular ...
4lymphoma, neoplasms by histologic type, neopla...
......
440512obesity, overweight, overnutrition, nutrition ...
440513obesity, overweight, overnutrition, nutrition ...
440514obesity, overweight, overnutrition, nutrition ...
440515autistic disorder, autism spectrum disorder, c...
440516autistic disorder, autism spectrum disorder, c...
\n", "

440517 rows × 1 columns

\n", "
" ], "text/plain": [ " desease_condition\n", "0 marijuana abuse, substance-related disorders, ...\n", "1 marijuana abuse, substance-related disorders, ...\n", "2 tuberculosis, latent tuberculosis, infections,...\n", "3 heart failure, heart diseases, cardiovascular ...\n", "4 lymphoma, neoplasms by histologic type, neopla...\n", "... ...\n", "440512 obesity, overweight, overnutrition, nutrition ...\n", "440513 obesity, overweight, overnutrition, nutrition ...\n", "440514 obesity, overweight, overnutrition, nutrition ...\n", "440515 autistic disorder, autism spectrum disorder, c...\n", "440516 autistic disorder, autism spectrum disorder, c...\n", "\n", "[440517 rows x 1 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_trials_filtered = pd.read_csv(\"diseases.csv\")\n", "df_trials_filtered" ] }, { "cell_type": "code", "execution_count": 49, "id": "c89e3cf6-a376-4029-9c04-0f5e664a2237", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(100, 1)" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# .copy() so adding the 'embeddings' column later does not raise SettingWithCopyWarning.\n", "df2 = df_trials_filtered[:100].copy()\n", "df2.shape" ] }, { "cell_type": "code", "execution_count": 7, "id": "c5012bcf-3e25-4f21-a29c-6bdbdafbb8c7", "metadata": {}, "outputs": [], "source": [ "# Use the already-imported openai module rather than `from openai import OpenAI`,\n", "# which would shadow the LangChain `OpenAI` class imported at the top of the notebook.\n", "client = openai.OpenAI()" ] }, { "cell_type": "code", "execution_count": 8, "id": "40a480bd-6754-40b6-870c-42d10ce9a960", "metadata": {}, "outputs": [], "source": [ "def get_embeddings(text, dimensions=128, model=\"text-embedding-3-small\"):\n", "    \"\"\"Return the embedding vector (a list of floats) for a single input string.\n", "\n", "    Calls the OpenAI embeddings endpoint through the module-level `client`.\n", "    `dimensions` and `model` default to the values this notebook was built with.\n", "    \"\"\"\n", "    response = client.embeddings.create(\n", "        input=text,\n", "        dimensions=dimensions,\n", "        model=model,\n", "    )\n", "    # One input string -> one entry in response.data.\n", "    return response.data[0].embedding" ] }, { "cell_type": "code", "execution_count": 9, "id": "ef6d6b62-de0b-4bc6-a6eb-847ab8e99da5", "metadata": {}, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'sentence_transformers'", "output_type": "error", "traceback": [ 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "File \u001b[0;32m:1\u001b[0m\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'sentence_transformers'" ] } ], "source": [ "%%time\n", "from sentence_transformers import SentenceTransformer\n", "\n", "encoder= SentenceTransformer(\"all-mpnet-base-v2\")\n", "vectors= encoder.encode(df2.desease_condition)\n", "vectors.shape" ] }, { "cell_type": "code", "execution_count": 66, "id": "7966d754-56d7-4555-a6c6-6a13772fb000", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: total: 62.5 ms\n", "Wall time: 24.7 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":1: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
desease_conditionembeddings
0marijuana abuse, substance-related disorders, ...[-0.05811865255236626, -0.023393018171191216, ...
1marijuana abuse, substance-related disorders, ...[-0.05811865255236626, -0.023393018171191216, ...
2tuberculosis, latent tuberculosis, infections,...[-0.03460180386900902, -0.084668830037117, 0.2...
3heart failure, heart diseases, cardiovascular ...[-0.08236236125230789, -0.1235777735710144, 0....
4lymphoma, neoplasms by histologic type, neopla...[-0.1227850392460823, 0.07155642658472061, 0.1...
\n", "
" ], "text/plain": [ " desease_condition \\\n", "0 marijuana abuse, substance-related disorders, ... \n", "1 marijuana abuse, substance-related disorders, ... \n", "2 tuberculosis, latent tuberculosis, infections,... \n", "3 heart failure, heart diseases, cardiovascular ... \n", "4 lymphoma, neoplasms by histologic type, neopla... \n", "\n", " embeddings \n", "0 [-0.05811865255236626, -0.023393018171191216, ... \n", "1 [-0.05811865255236626, -0.023393018171191216, ... \n", "2 [-0.03460180386900902, -0.084668830037117, 0.2... \n", "3 [-0.08236236125230789, -0.1235777735710144, 0.... \n", "4 [-0.1227850392460823, 0.07155642658472061, 0.1... " ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "df2['embeddings']= df2['desease_condition'].apply(get_embeddings)\n", "df2.head()" ] }, { "cell_type": "code", "execution_count": 65, "id": "2711980a-d1c0-441e-ae9a-531500b7b7cd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(128,)" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "np.array(df2.iloc[0].embeddings).shape" ] }, { "cell_type": "code", "execution_count": 23, "id": "13e40184-999d-4363-be44-18ba9fb6745a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\ariji\\Anaconda3\\envs\\python310\\lib\\site-packages\\huggingface_hub\\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Input \u001b[1;32mIn [23]\u001b[0m, in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msentence_transformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SentenceTransformer\n\u001b[0;32m 3\u001b[0m encoder\u001b[38;5;241m=\u001b[39m SentenceTransformer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mall-mpnet-base-v2\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 4\u001b[0m vectors\u001b[38;5;241m=\u001b[39m \u001b[43mencoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf_trials_filtered\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdesease_condition\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m vectors\u001b[38;5;241m.\u001b[39mshape\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\sentence_transformers\\SentenceTransformer.py:371\u001b[0m, in \u001b[0;36mSentenceTransformer.encode\u001b[1;34m(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings)\u001b[0m\n\u001b[0;32m 368\u001b[0m features\u001b[38;5;241m.\u001b[39mupdate(extra_features)\n\u001b[0;32m 370\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m--> 371\u001b[0m out_features \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 372\u001b[0m 
out_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence_embedding\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m truncate_embeddings(\n\u001b[0;32m 373\u001b[0m out_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence_embedding\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtruncate_dim\n\u001b[0;32m 374\u001b[0m )\n\u001b[0;32m 376\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_value \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_embeddings\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 138\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 139\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 140\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks 
\u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\sentence_transformers\\models\\Transformer.py:98\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, features)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m features:\n\u001b[0;32m 96\u001b[0m trans_features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_type_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m---> 98\u001b[0m output_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_model(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mtrans_features, return_dict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 99\u001b[0m output_tokens \u001b[38;5;241m=\u001b[39m output_states[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 101\u001b[0m 
features\u001b[38;5;241m.\u001b[39mupdate({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_embeddings\u001b[39m\u001b[38;5;124m\"\u001b[39m: output_tokens, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m: features[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]})\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:548\u001b[0m, in \u001b[0;36mMPNetModel.forward\u001b[1;34m(self, input_ids, 
attention_mask, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m 546\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[0;32m 547\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(input_ids\u001b[38;5;241m=\u001b[39minput_ids, position_ids\u001b[38;5;241m=\u001b[39mposition_ids, inputs_embeds\u001b[38;5;241m=\u001b[39minputs_embeds)\n\u001b[1;32m--> 548\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 549\u001b[0m \u001b[43m \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 550\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 551\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 552\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 553\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 554\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 555\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 556\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m 
encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 557\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler(sequence_output) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File 
\u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:338\u001b[0m, in \u001b[0;36mMPNetEncoder.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m 335\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states:\n\u001b[0;32m 336\u001b[0m all_hidden_states \u001b[38;5;241m=\u001b[39m all_hidden_states \u001b[38;5;241m+\u001b[39m (hidden_states,)\n\u001b[1;32m--> 338\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m layer_module(\n\u001b[0;32m 339\u001b[0m hidden_states,\n\u001b[0;32m 340\u001b[0m attention_mask,\n\u001b[0;32m 341\u001b[0m head_mask[i],\n\u001b[0;32m 342\u001b[0m position_bias,\n\u001b[0;32m 343\u001b[0m output_attentions\u001b[38;5;241m=\u001b[39moutput_attentions,\n\u001b[0;32m 344\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[0;32m 345\u001b[0m )\n\u001b[0;32m 346\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 348\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks 
\u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:308\u001b[0m, in \u001b[0;36mMPNetLayer.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, position_bias, output_attentions, **kwargs)\u001b[0m\n\u001b[0;32m 305\u001b[0m outputs \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m1\u001b[39m:] \u001b[38;5;66;03m# add self attentions if we output attention weights\u001b[39;00m\n\u001b[0;32m 307\u001b[0m intermediate_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mintermediate(attention_output)\n\u001b[1;32m--> 308\u001b[0m layer_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput\u001b[49m\u001b[43m(\u001b[49m\u001b[43mintermediate_output\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_output\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 309\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (layer_output,) \u001b[38;5;241m+\u001b[39m outputs\n\u001b[0;32m 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m 
\u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\transformers\\models\\mpnet\\modeling_mpnet.py:275\u001b[0m, in \u001b[0;36mMPNetOutput.forward\u001b[1;34m(self, hidden_states, input_tensor)\u001b[0m\n\u001b[0;32m 274\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor, input_tensor: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m--> 275\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdense\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 276\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout(hidden_states)\n\u001b[0;32m 277\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLayerNorm(hidden_states \u001b[38;5;241m+\u001b[39m input_tensor)\n", "File \u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File 
\u001b[1;32m~\\Anaconda3\\envs\\python310\\lib\\site-packages\\torch\\nn\\modules\\linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "from sentence_transformers import SentenceTransformer\n", "\n", "encoder= SentenceTransformer(\"all-mpnet-base-v2\")\n", "vectors= encoder.encode(df2.desease_condition)\n", "vectors.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "5f626382-75f1-4864-8574-bbcd183b926a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "4e064857-2480-4ac2-911a-1c5a2ffb62d6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 102, "id": "c0a719a6-ee06-4f6d-a64b-8fc74c7afbca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0.01245328 0.07695839 0.01802037 ... 0.0093326 -0.03474615\n", " -0.02757339]\n", " [ 0.01846956 0.08282936 0.01921537 ... 0.01068991 -0.03402989\n", " -0.02075185]\n", " [ 0.01822413 -0.05593034 -0.00288358 ... 0.03525289 -0.05427228\n", " -0.03371295]\n", " ...\n", " [ 0.04031958 0.0040107 0.02156032 ... 
0.01568209 -0.04320977\n", " -0.02990234]\n", " [ 0.0399131 0.00027251 0.02207735 ... 0.01440835 -0.04246744\n", " -0.02869584]\n", " [ 0.03773859 -0.00315346 0.0207725 ... 0.01205995 -0.04628598\n", " -0.02870333]]\n" ] } ], "source": [ "print(vectors)" ] }, { "cell_type": "code", "execution_count": 103, "id": "f21c08d5-596a-4fc4-ba72-2de8624e6c50", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "768" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dim= vectors.shape[1]\n", "dim " ] }, { "cell_type": "code", "execution_count": 106, "id": "8ed985f4-9402-431f-bfba-1236ba16b895", "metadata": {}, "outputs": [ { "data": { "text/plain": [ " >" ] }, "execution_count": 106, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import faiss\n", "\n", "index= faiss.IndexFlatL2(dim)\n", "index" ] }, { "cell_type": "code", "execution_count": 107, "id": "c0f1dde8-c173-45ba-957b-f4622fe6f0ee", "metadata": {}, "outputs": [], "source": [ "index.add(vectors)" ] }, { "cell_type": "code", "execution_count": 108, "id": "71dad860-1f19-4166-a309-c9ce15f24792", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(768,)\n" ] }, { "data": { "text/plain": [ "array([-8.82369652e-03, 8.50650743e-02, 2.08267733e-03, 6.77651772e-03,\n", " -2.86661759e-02, -8.71188380e-03, 6.99447095e-02, 5.04214764e-02,\n", " 3.58386151e-02, 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n", " 2.27009598e-03, 2.10810862e-02, 2.66138893e-02, 1.90623086e-02,\n", " 4.44708914e-02, 2.96202525e-02, 5.42085357e-02, -2.34859088e-03,\n", " -9.87798795e-02, -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n", " 5.31156994e-02, -1.37044629e-02, 2.92537250e-02, -2.61334293e-02,\n", " -1.21854078e-04, -2.36813519e-02, -3.81283499e-02, -1.79494768e-02,\n", " -6.29265187e-03, 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n", " -1.28973828e-04, 4.01791967e-02, 4.21229303e-02, -8.72302521e-03,\n", " 
7.44823692e-03, 7.68032745e-02, 6.50246907e-03, 3.40298638e-02,\n", " -1.80711355e-02, -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n", " -3.34868580e-02, 1.05205458e-02, 2.08975170e-02, 4.36686277e-02,\n", " 3.47612537e-02, -4.99080680e-02, 4.44446988e-02, 5.57280704e-03,\n", " -2.31200755e-02, -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n", " 4.67903316e-02, 1.91651955e-02, -5.12171052e-02, 2.46807020e-02,\n", " -5.52081019e-02, 3.50596346e-02, -7.01438356e-03, -3.36519890e-02,\n", " -1.41502097e-02, -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n", " -1.30138136e-02, 5.91567121e-02, 3.37168351e-02, 3.01467292e-02,\n", " -4.59552221e-02, 1.37365120e-03, 1.00179566e-02, 6.98126853e-04,\n", " 3.58139984e-02, 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n", " 4.75908853e-02, 5.48331346e-03, -6.41460950e-03, -1.23906611e-02,\n", " 5.82688041e-02, -1.60842277e-02, -2.95833423e-04, 6.97355811e-03,\n", " 2.48331465e-02, -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n", " 1.52637456e-02, 7.01832073e-03, 5.50601333e-02, 4.35096538e-03,\n", " 2.36319732e-02, -1.38118947e-02, -7.24233836e-02, 9.39742289e-03,\n", " -2.66901590e-02, 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n", " 3.08373477e-03, 7.12765753e-02, 7.13613164e-03, 3.62721197e-02,\n", " -4.53250594e-02, 2.54001115e-02, -2.54253373e-02, -1.23151275e-03,\n", " -1.34750446e-02, -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n", " -7.31003610e-03, 2.65329964e-02, -2.79857730e-03, 4.20840643e-02,\n", " 3.20205763e-02, -1.19518824e-02, -5.77116087e-02, 9.88688134e-03,\n", " 1.86814573e-02, -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n", " -3.33383717e-02, -5.52933291e-02, 5.64148054e-02, 4.92153503e-02,\n", " 3.33690383e-02, -3.92963700e-02, -6.91099390e-02, 3.79911740e-03,\n", " -1.74410697e-02, -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n", " 2.61192098e-02, -2.74193864e-02, 6.92490395e-03, -4.64810384e-03,\n", " -8.99862905e-04, 1.02159111e-02, -4.81114909e-02, 
1.22787328e-02,\n", " -9.32844076e-03, -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n", " -1.60810221e-02, -2.20200215e-02, 2.32051890e-02, -5.07331975e-02,\n", " -1.01248249e-02, 5.62567115e-02, -2.60966737e-03, 9.27545596e-03,\n", " 5.32410555e-02, 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n", " -2.12969314e-02, 9.82244611e-02, -2.47648880e-02, 7.06253499e-02,\n", " 8.71159416e-03, -2.73140483e-02, 5.59884915e-03, -2.14829091e-02,\n", " -6.67077005e-02, 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n", " 3.77488993e-02, -1.37352264e-02, -2.85069812e-02, 1.81708820e-02,\n", " -4.07746173e-02, -4.71230270e-03, -1.59605164e-02, -1.25815195e-03,\n", " -6.59954594e-03, -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n", " 3.07933297e-02, 1.27626080e-02, -4.34489138e-02, 9.91576444e-03,\n", " 1.25470785e-02, -8.67356583e-02, 1.26097840e-03, 3.24709825e-02,\n", " -6.92409948e-02, -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n", " -2.35359464e-02, -2.95022167e-02, 2.88009271e-02, -3.26618887e-02,\n", " 7.09307985e-03, 3.09435464e-03, -5.09097055e-02, 3.54242921e-02,\n", " 5.37336655e-02, 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n", " 2.93767708e-03, 2.27999203e-02, 1.02668423e-02, 3.35033536e-02,\n", " -8.28316063e-02, -4.17127199e-02, -1.23034064e-02, 2.38543525e-02,\n", " -3.72257493e-02, 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n", " 5.74519299e-03, 2.89844945e-02, -2.21337453e-02, 2.34603398e-02,\n", " 6.33142609e-03, -2.24104542e-02, 1.47326495e-02, 1.98041964e-02,\n", " 3.05697713e-02, -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n", " 3.88860740e-02, -3.97440195e-02, -4.70216498e-02, 1.02172708e-02,\n", " -3.37972888e-03, -8.54947045e-03, 4.81354557e-02, 4.99849804e-02,\n", " 7.11378129e-03, -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n", " 5.24284095e-02, 2.16708388e-02, -4.00698110e-02, 5.15380092e-02,\n", " -6.03203699e-02, -6.50304696e-03, -1.03860423e-02, -7.47132823e-02,\n", " 3.59848235e-03, 
-4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n", " -2.88047381e-02, -2.81904116e-02, 1.52729014e-02, -1.55570190e-02,\n", " 1.34619148e-02, 2.34364290e-02, 3.10326237e-02, -4.70464528e-02,\n", " -2.43550166e-02, -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n", " -5.30204549e-03, 5.52049950e-02, 4.50828709e-02, -7.30262510e-03,\n", " 5.56289777e-02, -9.46066808e-03, -3.37345451e-02, -1.87659152e-02,\n", " 3.57284099e-02, 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n", " 2.96422077e-04, 4.22447585e-02, 4.97253910e-02, 6.03130311e-02,\n", " 1.32281650e-02, 2.35939436e-02, -1.59284715e-02, 4.46444489e-02,\n", " -1.68315917e-02, 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n", " 8.99049267e-03, 4.74606343e-02, 6.70041004e-03, -1.15184486e-03,\n", " 2.69540539e-03, -2.77549177e-02, -1.33260442e-02, 2.60788556e-02,\n", " 4.35438640e-02, -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n", " 2.93240137e-02, 1.82274636e-03, -1.40310880e-02, -1.91633645e-02,\n", " 1.18790809e-02, -4.65121269e-02, -4.19883654e-02, -2.69681774e-02,\n", " -3.23035605e-02, -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n", " -2.55833156e-02, -5.73152229e-02, 3.30126472e-02, -7.90146552e-03,\n", " -1.08651863e-02, 1.10474667e-02, 3.03509296e-03, 1.55274626e-02,\n", " 1.05599947e-02, -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n", " 3.77239436e-02, 9.44003314e-02, -4.80610691e-02, 4.73537892e-02,\n", " 3.40655483e-02, 7.88806472e-03, -2.84915343e-02, 7.96849206e-02,\n", " 1.57442074e-02, -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n", " -1.72730908e-01, -8.72075930e-02, 2.86346450e-02, 2.16962174e-02,\n", " -4.80199270e-02, 6.49317261e-03, 1.67240556e-02, -2.56227311e-02,\n", " 2.19670162e-02, -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n", " -2.89566331e-02, 1.19498251e-02, -2.33849231e-02, -2.69133616e-02,\n", " -1.46602485e-02, 1.18886270e-02, 1.64973717e-02, -3.90495770e-02,\n", " -3.45575088e-03, 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n", " 
2.10017413e-02, 2.74998210e-02, 3.03551817e-04, -1.15796946e-01,\n", " -4.66962112e-03, -4.80118394e-02, -3.55160870e-02, -4.72528581e-03,\n", " -4.29739058e-02, -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n", " 1.98413953e-02, -7.27679394e-03, 2.27117930e-02, -2.59338003e-02,\n", " 4.31442596e-02, 1.07885078e-02, -2.47129947e-02, -4.14506458e-02,\n", " 4.40958813e-02, 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n", " 1.13289580e-02, -5.57265691e-02, 1.71151303e-03, -1.24145029e-02,\n", " -3.57853901e-03, -4.86295968e-02, -5.14956787e-02, 4.79425713e-02,\n", " -3.24050151e-02, 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n", " 8.20766483e-03, -6.27530292e-02, -1.30661400e-02, -3.52081768e-02,\n", " 4.83807474e-02, 9.81860235e-03, 1.14539362e-01, -1.88471414e-02,\n", " 6.07751869e-02, -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n", " 2.64345529e-03, 3.07400171e-02, -4.31060083e-02, -6.19985871e-02,\n", " 5.50477020e-03, 1.62547994e-02, -8.26352183e-03, 7.56437238e-03,\n", " -4.79784003e-03, 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n", " 1.41595434e-02, 5.31185642e-02, 6.78585656e-03, 6.56357184e-02,\n", " -5.06135784e-02, -3.05179805e-02, 7.06539825e-02, -3.55644710e-02,\n", " -4.92612133e-03, 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n", " -1.86746120e-02, 2.49281265e-02, -4.92450967e-03, 1.66887734e-02,\n", " 4.62210961e-02, 4.07794118e-02, 2.52511259e-02, -2.83305068e-02,\n", " -2.78001893e-02, -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n", " 1.09969089e-02, 1.69700030e-02, -8.59475043e-03, 4.70476560e-02,\n", " 3.64770554e-02, 2.09835749e-02, 1.01236468e-02, 2.75151283e-02,\n", " 4.33402918e-02, -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n", " -6.10819347e-02, -2.86280159e-02, 4.68054451e-02, 1.29892454e-02,\n", " -1.71940885e-02, -2.52429228e-02, 3.86423096e-02, -1.35919163e-02,\n", " -5.27431667e-02, 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n", " 3.23252901e-02, 5.03172688e-02, -4.45654802e-02, 
2.90075876e-02,\n", " -1.35373492e-02, 6.78209821e-03, -5.89249916e-02, 4.28890549e-02,\n", " -2.36034058e-02, -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n", " 1.45543357e-02, 1.07806427e-02, -6.06855676e-02, -4.95252907e-02,\n", " 1.02004781e-02, 4.60227691e-02, -1.08090881e-02, 4.42408510e-02,\n", " 4.15152796e-02, 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n", " -2.70090066e-02, 2.68773828e-02, -1.97812133e-02, 2.25932393e-02,\n", " -1.33560598e-02, -1.50896851e-02, -3.14053567e-03, 1.54051669e-02,\n", " 1.86488125e-02, -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n", " -2.37891261e-04, 1.84722953e-02, 3.60381305e-02, -5.85213909e-03,\n", " 4.44293395e-02, -1.11264118e-03, -4.79441285e-02, 3.46464328e-02,\n", " -2.53370814e-02, -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n", " -4.38152434e-04, 4.08602282e-02, -2.29470823e-02, -1.89938806e-02,\n", " -1.52037974e-04, 1.05516789e-02, 2.08601039e-02, -6.98119551e-02,\n", " 3.66246551e-02, -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n", " 6.51817098e-02, 4.29646857e-02, 2.56071109e-02, -3.28080021e-02,\n", " 1.20534413e-02, 3.56224040e-03, -1.01593453e-02, -1.96505673e-04,\n", " 4.33485657e-02, -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n", " -1.40319867e-02, -3.63940969e-02, -3.09983976e-02, -4.19548260e-33,\n", " 7.11604580e-02, 4.78382297e-02, 1.89297704e-03, -1.60731785e-02,\n", " 2.53787991e-02, -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n", " 1.68679946e-03, 1.92391127e-02, -2.20667192e-04, 1.32907527e-02,\n", " 5.99487219e-03, 2.75156219e-02, -5.06000873e-03, -3.58465910e-02,\n", " 8.20948277e-03, -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n", " -1.09853260e-01, -3.66037302e-02, 3.55480015e-02, 4.23291475e-02,\n", " 1.48312682e-02, 5.68749309e-02, 3.57767567e-02, 1.40728084e-02,\n", " -4.00471613e-02, 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n", " 1.24238459e-02, 1.20237898e-02, -7.69484974e-03, -3.30727436e-02,\n", " -1.45808076e-02, 3.43246050e-02, 
3.21143419e-02, -4.96741422e-02,\n", " -5.27968369e-02, 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n", " 2.77636852e-02, 5.90689071e-02, -2.61273161e-02, -6.95008039e-02,\n", " -3.15576978e-02, -5.62214339e-03, -7.93884136e-03, -3.62196900e-02,\n", " -8.26047733e-03, 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n", " -2.52235290e-02, -3.88054736e-02, -2.00710595e-02, 1.50789914e-03,\n", " -5.51338419e-02, -8.35673045e-03, -1.61523875e-02, -8.79513845e-02,\n", " -5.28004877e-02, -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n", " 4.44932319e-02, 8.69598426e-03, -1.14432694e-02, 4.47212979e-02,\n", " 2.70624813e-02, -3.86100151e-02, -3.07358261e-02, 2.75634117e-02,\n", " 1.48464069e-02, -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n", " 8.05836394e-02, -1.69498641e-02, 4.44465503e-02, -2.09145956e-02,\n", " -3.37407738e-02, 3.85780074e-02, -7.44559616e-02, 1.17512364e-02,\n", " 1.01964204e-02, -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n", " 2.54592765e-02, -1.46158040e-02, 5.46646416e-02, 1.43051194e-03,\n", " 2.99116820e-02, 2.24273186e-02, -5.79927117e-03, -1.33864526e-02,\n", " -2.52460372e-02, -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n", " 3.38429734e-02, -2.11539529e-02, 7.17787817e-02, -7.78904185e-02,\n", " -4.04084288e-02, 4.90567498e-02, -2.61603445e-02, 1.97753590e-02,\n", " 4.97209951e-02, -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n", " 2.68440694e-02, 3.29160057e-02, -8.24410375e-03, -1.33646047e-02,\n", " -6.22822754e-02, -1.13362661e-02, -3.79339382e-02, -6.56360280e-05,\n", " -1.08087100e-02, 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n", " -2.54666172e-02, -3.05371322e-02, -1.53249800e-02, -9.87035502e-03,\n", " 1.95337094e-07, -1.76476724e-02, 5.71432859e-02, -2.49180794e-02,\n", " 5.85253723e-02, 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n", " 4.07801419e-02, 4.13940698e-02, 2.55707726e-02, 2.18985360e-02,\n", " -3.04434425e-03, -3.77355106e-02, -6.24866784e-02, -1.17468778e-02,\n", " 
-4.82194684e-02, -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n", " -2.48471629e-02, 8.05181568e-04, -4.85844910e-03, -5.16015477e-03,\n", " 7.53483502e-03, -9.46175400e-03, -2.39896346e-02, -3.14654633e-02,\n", " 1.50111094e-02, -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n", " 3.08971256e-02, 1.72299352e-02, 5.93419448e-02, -5.74274361e-02,\n", " -8.16087723e-02, -4.80572283e-02, -2.68838424e-02, -1.96331330e-02,\n", " -9.15831141e-03, 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n", " 9.21937004e-02, 1.37132118e-02, -1.19096776e-02, -4.09874134e-02,\n", " 3.37628126e-02, -4.64820908e-03, -2.50304434e-02, 6.25852346e-02,\n", " -1.24449311e-02, 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n", " 5.08641489e-02, 2.53822445e-03, 5.25634140e-02, 1.14882430e-02,\n", " 5.01894541e-02, -3.55215147e-02, -3.31749097e-02, -3.02003417e-03,\n", " -5.36288768e-02, -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n", " 9.56887701e-35, 2.55127084e-02, -1.44770980e-04, 1.96710341e-02,\n", " -1.33620016e-02, -1.51910949e-02, -3.28495577e-02, -1.52465852e-03,\n", " -2.65272055e-02, -4.35708016e-02, -1.75950192e-02, -2.20594816e-02],\n", " dtype=float32)" ] }, "execution_count": 108, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search_query= \"clinical trials related to alzheimers\"\n", "vec= encoder.encode(search_query)\n", "print(vec.shape)\n", "vec" ] }, { "cell_type": "code", "execution_count": 109, "id": "613fd415-4194-45e6-b9f3-9a7707845ad5", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 768)\n" ] }, { "data": { "text/plain": [ "array([[-8.82369652e-03, 8.50650743e-02, 2.08267733e-03,\n", " 6.77651772e-03, -2.86661759e-02, -8.71188380e-03,\n", " 6.99447095e-02, 5.04214764e-02, 3.58386151e-02,\n", " 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n", " 2.27009598e-03, 2.10810862e-02, 2.66138893e-02,\n", " 1.90623086e-02, 4.44708914e-02, 2.96202525e-02,\n", " 5.42085357e-02, 
-2.34859088e-03, -9.87798795e-02,\n", " -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n", " 5.31156994e-02, -1.37044629e-02, 2.92537250e-02,\n", " -2.61334293e-02, -1.21854078e-04, -2.36813519e-02,\n", " -3.81283499e-02, -1.79494768e-02, -6.29265187e-03,\n", " 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n", " -1.28973828e-04, 4.01791967e-02, 4.21229303e-02,\n", " -8.72302521e-03, 7.44823692e-03, 7.68032745e-02,\n", " 6.50246907e-03, 3.40298638e-02, -1.80711355e-02,\n", " -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n", " -3.34868580e-02, 1.05205458e-02, 2.08975170e-02,\n", " 4.36686277e-02, 3.47612537e-02, -4.99080680e-02,\n", " 4.44446988e-02, 5.57280704e-03, -2.31200755e-02,\n", " -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n", " 4.67903316e-02, 1.91651955e-02, -5.12171052e-02,\n", " 2.46807020e-02, -5.52081019e-02, 3.50596346e-02,\n", " -7.01438356e-03, -3.36519890e-02, -1.41502097e-02,\n", " -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n", " -1.30138136e-02, 5.91567121e-02, 3.37168351e-02,\n", " 3.01467292e-02, -4.59552221e-02, 1.37365120e-03,\n", " 1.00179566e-02, 6.98126853e-04, 3.58139984e-02,\n", " 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n", " 4.75908853e-02, 5.48331346e-03, -6.41460950e-03,\n", " -1.23906611e-02, 5.82688041e-02, -1.60842277e-02,\n", " -2.95833423e-04, 6.97355811e-03, 2.48331465e-02,\n", " -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n", " 1.52637456e-02, 7.01832073e-03, 5.50601333e-02,\n", " 4.35096538e-03, 2.36319732e-02, -1.38118947e-02,\n", " -7.24233836e-02, 9.39742289e-03, -2.66901590e-02,\n", " 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n", " 3.08373477e-03, 7.12765753e-02, 7.13613164e-03,\n", " 3.62721197e-02, -4.53250594e-02, 2.54001115e-02,\n", " -2.54253373e-02, -1.23151275e-03, -1.34750446e-02,\n", " -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n", " -7.31003610e-03, 2.65329964e-02, -2.79857730e-03,\n", " 4.20840643e-02, 3.20205763e-02, -1.19518824e-02,\n", " -5.77116087e-02, 
9.88688134e-03, 1.86814573e-02,\n", " -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n", " -3.33383717e-02, -5.52933291e-02, 5.64148054e-02,\n", " 4.92153503e-02, 3.33690383e-02, -3.92963700e-02,\n", " -6.91099390e-02, 3.79911740e-03, -1.74410697e-02,\n", " -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n", " 2.61192098e-02, -2.74193864e-02, 6.92490395e-03,\n", " -4.64810384e-03, -8.99862905e-04, 1.02159111e-02,\n", " -4.81114909e-02, 1.22787328e-02, -9.32844076e-03,\n", " -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n", " -1.60810221e-02, -2.20200215e-02, 2.32051890e-02,\n", " -5.07331975e-02, -1.01248249e-02, 5.62567115e-02,\n", " -2.60966737e-03, 9.27545596e-03, 5.32410555e-02,\n", " 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n", " -2.12969314e-02, 9.82244611e-02, -2.47648880e-02,\n", " 7.06253499e-02, 8.71159416e-03, -2.73140483e-02,\n", " 5.59884915e-03, -2.14829091e-02, -6.67077005e-02,\n", " 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n", " 3.77488993e-02, -1.37352264e-02, -2.85069812e-02,\n", " 1.81708820e-02, -4.07746173e-02, -4.71230270e-03,\n", " -1.59605164e-02, -1.25815195e-03, -6.59954594e-03,\n", " -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n", " 3.07933297e-02, 1.27626080e-02, -4.34489138e-02,\n", " 9.91576444e-03, 1.25470785e-02, -8.67356583e-02,\n", " 1.26097840e-03, 3.24709825e-02, -6.92409948e-02,\n", " -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n", " -2.35359464e-02, -2.95022167e-02, 2.88009271e-02,\n", " -3.26618887e-02, 7.09307985e-03, 3.09435464e-03,\n", " -5.09097055e-02, 3.54242921e-02, 5.37336655e-02,\n", " 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n", " 2.93767708e-03, 2.27999203e-02, 1.02668423e-02,\n", " 3.35033536e-02, -8.28316063e-02, -4.17127199e-02,\n", " -1.23034064e-02, 2.38543525e-02, -3.72257493e-02,\n", " 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n", " 5.74519299e-03, 2.89844945e-02, -2.21337453e-02,\n", " 2.34603398e-02, 6.33142609e-03, -2.24104542e-02,\n", " 1.47326495e-02, 
1.98041964e-02, 3.05697713e-02,\n", " -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n", " 3.88860740e-02, -3.97440195e-02, -4.70216498e-02,\n", " 1.02172708e-02, -3.37972888e-03, -8.54947045e-03,\n", " 4.81354557e-02, 4.99849804e-02, 7.11378129e-03,\n", " -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n", " 5.24284095e-02, 2.16708388e-02, -4.00698110e-02,\n", " 5.15380092e-02, -6.03203699e-02, -6.50304696e-03,\n", " -1.03860423e-02, -7.47132823e-02, 3.59848235e-03,\n", " -4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n", " -2.88047381e-02, -2.81904116e-02, 1.52729014e-02,\n", " -1.55570190e-02, 1.34619148e-02, 2.34364290e-02,\n", " 3.10326237e-02, -4.70464528e-02, -2.43550166e-02,\n", " -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n", " -5.30204549e-03, 5.52049950e-02, 4.50828709e-02,\n", " -7.30262510e-03, 5.56289777e-02, -9.46066808e-03,\n", " -3.37345451e-02, -1.87659152e-02, 3.57284099e-02,\n", " 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n", " 2.96422077e-04, 4.22447585e-02, 4.97253910e-02,\n", " 6.03130311e-02, 1.32281650e-02, 2.35939436e-02,\n", " -1.59284715e-02, 4.46444489e-02, -1.68315917e-02,\n", " 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n", " 8.99049267e-03, 4.74606343e-02, 6.70041004e-03,\n", " -1.15184486e-03, 2.69540539e-03, -2.77549177e-02,\n", " -1.33260442e-02, 2.60788556e-02, 4.35438640e-02,\n", " -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n", " 2.93240137e-02, 1.82274636e-03, -1.40310880e-02,\n", " -1.91633645e-02, 1.18790809e-02, -4.65121269e-02,\n", " -4.19883654e-02, -2.69681774e-02, -3.23035605e-02,\n", " -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n", " -2.55833156e-02, -5.73152229e-02, 3.30126472e-02,\n", " -7.90146552e-03, -1.08651863e-02, 1.10474667e-02,\n", " 3.03509296e-03, 1.55274626e-02, 1.05599947e-02,\n", " -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n", " 3.77239436e-02, 9.44003314e-02, -4.80610691e-02,\n", " 4.73537892e-02, 3.40655483e-02, 7.88806472e-03,\n", " -2.84915343e-02, 
7.96849206e-02, 1.57442074e-02,\n", " -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n", " -1.72730908e-01, -8.72075930e-02, 2.86346450e-02,\n", " 2.16962174e-02, -4.80199270e-02, 6.49317261e-03,\n", " 1.67240556e-02, -2.56227311e-02, 2.19670162e-02,\n", " -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n", " -2.89566331e-02, 1.19498251e-02, -2.33849231e-02,\n", " -2.69133616e-02, -1.46602485e-02, 1.18886270e-02,\n", " 1.64973717e-02, -3.90495770e-02, -3.45575088e-03,\n", " 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n", " 2.10017413e-02, 2.74998210e-02, 3.03551817e-04,\n", " -1.15796946e-01, -4.66962112e-03, -4.80118394e-02,\n", " -3.55160870e-02, -4.72528581e-03, -4.29739058e-02,\n", " -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n", " 1.98413953e-02, -7.27679394e-03, 2.27117930e-02,\n", " -2.59338003e-02, 4.31442596e-02, 1.07885078e-02,\n", " -2.47129947e-02, -4.14506458e-02, 4.40958813e-02,\n", " 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n", " 1.13289580e-02, -5.57265691e-02, 1.71151303e-03,\n", " -1.24145029e-02, -3.57853901e-03, -4.86295968e-02,\n", " -5.14956787e-02, 4.79425713e-02, -3.24050151e-02,\n", " 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n", " 8.20766483e-03, -6.27530292e-02, -1.30661400e-02,\n", " -3.52081768e-02, 4.83807474e-02, 9.81860235e-03,\n", " 1.14539362e-01, -1.88471414e-02, 6.07751869e-02,\n", " -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n", " 2.64345529e-03, 3.07400171e-02, -4.31060083e-02,\n", " -6.19985871e-02, 5.50477020e-03, 1.62547994e-02,\n", " -8.26352183e-03, 7.56437238e-03, -4.79784003e-03,\n", " 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n", " 1.41595434e-02, 5.31185642e-02, 6.78585656e-03,\n", " 6.56357184e-02, -5.06135784e-02, -3.05179805e-02,\n", " 7.06539825e-02, -3.55644710e-02, -4.92612133e-03,\n", " 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n", " -1.86746120e-02, 2.49281265e-02, -4.92450967e-03,\n", " 1.66887734e-02, 4.62210961e-02, 4.07794118e-02,\n", " 2.52511259e-02, 
-2.83305068e-02, -2.78001893e-02,\n", " -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n", " 1.09969089e-02, 1.69700030e-02, -8.59475043e-03,\n", " 4.70476560e-02, 3.64770554e-02, 2.09835749e-02,\n", " 1.01236468e-02, 2.75151283e-02, 4.33402918e-02,\n", " -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n", " -6.10819347e-02, -2.86280159e-02, 4.68054451e-02,\n", " 1.29892454e-02, -1.71940885e-02, -2.52429228e-02,\n", " 3.86423096e-02, -1.35919163e-02, -5.27431667e-02,\n", " 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n", " 3.23252901e-02, 5.03172688e-02, -4.45654802e-02,\n", " 2.90075876e-02, -1.35373492e-02, 6.78209821e-03,\n", " -5.89249916e-02, 4.28890549e-02, -2.36034058e-02,\n", " -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n", " 1.45543357e-02, 1.07806427e-02, -6.06855676e-02,\n", " -4.95252907e-02, 1.02004781e-02, 4.60227691e-02,\n", " -1.08090881e-02, 4.42408510e-02, 4.15152796e-02,\n", " 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n", " -2.70090066e-02, 2.68773828e-02, -1.97812133e-02,\n", " 2.25932393e-02, -1.33560598e-02, -1.50896851e-02,\n", " -3.14053567e-03, 1.54051669e-02, 1.86488125e-02,\n", " -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n", " -2.37891261e-04, 1.84722953e-02, 3.60381305e-02,\n", " -5.85213909e-03, 4.44293395e-02, -1.11264118e-03,\n", " -4.79441285e-02, 3.46464328e-02, -2.53370814e-02,\n", " -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n", " -4.38152434e-04, 4.08602282e-02, -2.29470823e-02,\n", " -1.89938806e-02, -1.52037974e-04, 1.05516789e-02,\n", " 2.08601039e-02, -6.98119551e-02, 3.66246551e-02,\n", " -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n", " 6.51817098e-02, 4.29646857e-02, 2.56071109e-02,\n", " -3.28080021e-02, 1.20534413e-02, 3.56224040e-03,\n", " -1.01593453e-02, -1.96505673e-04, 4.33485657e-02,\n", " -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n", " -1.40319867e-02, -3.63940969e-02, -3.09983976e-02,\n", " -4.19548260e-33, 7.11604580e-02, 4.78382297e-02,\n", " 1.89297704e-03, 
-1.60731785e-02, 2.53787991e-02,\n", " -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n", " 1.68679946e-03, 1.92391127e-02, -2.20667192e-04,\n", " 1.32907527e-02, 5.99487219e-03, 2.75156219e-02,\n", " -5.06000873e-03, -3.58465910e-02, 8.20948277e-03,\n", " -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n", " -1.09853260e-01, -3.66037302e-02, 3.55480015e-02,\n", " 4.23291475e-02, 1.48312682e-02, 5.68749309e-02,\n", " 3.57767567e-02, 1.40728084e-02, -4.00471613e-02,\n", " 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n", " 1.24238459e-02, 1.20237898e-02, -7.69484974e-03,\n", " -3.30727436e-02, -1.45808076e-02, 3.43246050e-02,\n", " 3.21143419e-02, -4.96741422e-02, -5.27968369e-02,\n", " 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n", " 2.77636852e-02, 5.90689071e-02, -2.61273161e-02,\n", " -6.95008039e-02, -3.15576978e-02, -5.62214339e-03,\n", " -7.93884136e-03, -3.62196900e-02, -8.26047733e-03,\n", " 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n", " -2.52235290e-02, -3.88054736e-02, -2.00710595e-02,\n", " 1.50789914e-03, -5.51338419e-02, -8.35673045e-03,\n", " -1.61523875e-02, -8.79513845e-02, -5.28004877e-02,\n", " -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n", " 4.44932319e-02, 8.69598426e-03, -1.14432694e-02,\n", " 4.47212979e-02, 2.70624813e-02, -3.86100151e-02,\n", " -3.07358261e-02, 2.75634117e-02, 1.48464069e-02,\n", " -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n", " 8.05836394e-02, -1.69498641e-02, 4.44465503e-02,\n", " -2.09145956e-02, -3.37407738e-02, 3.85780074e-02,\n", " -7.44559616e-02, 1.17512364e-02, 1.01964204e-02,\n", " -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n", " 2.54592765e-02, -1.46158040e-02, 5.46646416e-02,\n", " 1.43051194e-03, 2.99116820e-02, 2.24273186e-02,\n", " -5.79927117e-03, -1.33864526e-02, -2.52460372e-02,\n", " -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n", " 3.38429734e-02, -2.11539529e-02, 7.17787817e-02,\n", " -7.78904185e-02, -4.04084288e-02, 4.90567498e-02,\n", " 
-2.61603445e-02, 1.97753590e-02, 4.97209951e-02,\n", " -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n", " 2.68440694e-02, 3.29160057e-02, -8.24410375e-03,\n", " -1.33646047e-02, -6.22822754e-02, -1.13362661e-02,\n", " -3.79339382e-02, -6.56360280e-05, -1.08087100e-02,\n", " 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n", " -2.54666172e-02, -3.05371322e-02, -1.53249800e-02,\n", " -9.87035502e-03, 1.95337094e-07, -1.76476724e-02,\n", " 5.71432859e-02, -2.49180794e-02, 5.85253723e-02,\n", " 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n", " 4.07801419e-02, 4.13940698e-02, 2.55707726e-02,\n", " 2.18985360e-02, -3.04434425e-03, -3.77355106e-02,\n", " -6.24866784e-02, -1.17468778e-02, -4.82194684e-02,\n", " -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n", " -2.48471629e-02, 8.05181568e-04, -4.85844910e-03,\n", " -5.16015477e-03, 7.53483502e-03, -9.46175400e-03,\n", " -2.39896346e-02, -3.14654633e-02, 1.50111094e-02,\n", " -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n", " 3.08971256e-02, 1.72299352e-02, 5.93419448e-02,\n", " -5.74274361e-02, -8.16087723e-02, -4.80572283e-02,\n", " -2.68838424e-02, -1.96331330e-02, -9.15831141e-03,\n", " 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n", " 9.21937004e-02, 1.37132118e-02, -1.19096776e-02,\n", " -4.09874134e-02, 3.37628126e-02, -4.64820908e-03,\n", " -2.50304434e-02, 6.25852346e-02, -1.24449311e-02,\n", " 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n", " 5.08641489e-02, 2.53822445e-03, 5.25634140e-02,\n", " 1.14882430e-02, 5.01894541e-02, -3.55215147e-02,\n", " -3.31749097e-02, -3.02003417e-03, -5.36288768e-02,\n", " -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n", " 9.56887701e-35, 2.55127084e-02, -1.44770980e-04,\n", " 1.96710341e-02, -1.33620016e-02, -1.51910949e-02,\n", " -3.28495577e-02, -1.52465852e-03, -2.65272055e-02,\n", " -4.35708016e-02, -1.75950192e-02, -2.20594816e-02]], dtype=float32)" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "import numpy as np\n", "svec= np.array(vec).reshape(1,-1)\n", "print(svec.shape)\n", "svec" ] }, { "cell_type": "code", "execution_count": 110, "id": "fef30d70-6958-4259-abb6-09f8c1870a2b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.7731663 0.79433584]] [[330 331]]\n" ] } ], "source": [ "distances, I= index.search(svec, k=2)\n", "print(distances, I)" ] }, { "cell_type": "code", "execution_count": 111, "id": "eb00598c-9799-4697-b2a3-356bb5aae0f1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
desease_conditiontext
330['alzheimer disease', 'dementia', 'brain disea...nct_id: NCT02164643\\nsummary: A Multicenter na...
331['alzheimer disease', 'dementia', 'brain disea...nct_id: NCT02164643\\nsummary: A Multicenter na...
\n", "
" ], "text/plain": [ " desease_condition \\\n", "330 ['alzheimer disease', 'dementia', 'brain disea... \n", "331 ['alzheimer disease', 'dementia', 'brain disea... \n", "\n", " text \n", "330 nct_id: NCT02164643\\nsummary: A Multicenter na... \n", "331 nct_id: NCT02164643\\nsummary: A Multicenter na... " ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df2= df.iloc[I[0]]\n", "df2" ] }, { "cell_type": "code", "execution_count": 113, "id": "af5bf8e2-43b6-47af-affa-5111789371ad", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'nct_id: NCT02164643\\nsummary: A Multicenter national longitudinal cohort study including at least 800 individuals consecutively recruited from French Research Memory Centers and followed-up over 24 month and included in Memento.\\nintervention_type: Drug\\nintervention_name: Florbetapir (18F)\\nintervention_description: nan\\nkeywords: [\"Alzheimer\\'s disease\", \\'Mild Cognitive Impairment\\']'" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df2.iloc[1].text" ] }, { "cell_type": "code", "execution_count": null, "id": "f3899f81-e120-475c-97ed-080cb7f46510", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }