{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Import Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/swastikm/work/senetence-transformer-in-action-main/semantic_search_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import time\n", "from datasets import load_from_disk\n", "import pandas as pd\n", "from sentence_transformers import SentenceTransformer\n", "from sentence_transformers.quantization import quantize_embeddings\n", "import faiss\n", "from usearch.index import Index\n", "import numpy as np\n", "import os" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "base_path = os.getcwd()\n", "full_path = os.path.join(base_path, 'conala')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "conala_dataset = load_from_disk(full_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```Text\n", "Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.\n", "Int8 embedding is required to perform rescoring of fetched document. Rescoring is done by performing inner product with F32 embedding of Query\n", "```\n", "[Efficient Passage Retrieval with Hashing for Open-domain Question Answering](https://arxiv.org/abs/2106.00882)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.\n", "int8_view = Index.restore(os.path.join(base_path, 'conala_int8_usearch.index'), view=True)\n", "binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(os.path.join(base_path, 'conala.index'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Import Model to generate embedding" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/swastikm/work/senetence-transformer-in-action-main/semantic_search_env/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] } ], "source": [ "model = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def search(query, top_k: int = 20):\n", " # 1. Embed the query as float32\n", " query_embedding = model.encode(query)\n", "\n", " # 2. Quantize the query to ubinary. To perform actual search with faiss\n", " query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), \"ubinary\")\n", "\n", "\n", " # 3. Search the binary index \n", " index = binary_index\n", " _scores, binary_ids = index.search(query_embedding_ubinary, top_k)\n", " binary_ids = binary_ids[0]\n", "\n", "\n", " # 4. Load the corresponding int8 embeddings. To perform rescoring to calculate score of fetched documents.\n", " int8_embeddings = int8_view[binary_ids].astype(int)\n", "\n", " # 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings\n", " scores = query_embedding @ int8_embeddings.T\n", "\n", " # 6. Sort the scores and return the top_k\n", " start_time = time.time()\n", " indices = scores.argsort()[::-1][:top_k]\n", " top_k_indices = binary_ids[indices]\n", " top_k_scores = scores[indices]\n", "\n", " top_k_codes = conala_dataset[top_k_indices]\n", "\n", " return top_k_codes\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "user_prompt = input('Enter python coding query')\n", "top_k_outputs = search(user_prompt)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "probs = top_k_outputs['prob']\n", "snippets = top_k_outputs['snippet']" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "idx = np.argsort(probs)[::-1]\n", "results = np.array(snippets)[idx]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['[x[1] for x in elements]',\n", " 'map(itemgetter(1), elements)',\n", " 'zip(*elements)[1]']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "filtered_results = []\n", "for item in results:\n", " if len(filtered_results)<3:\n", " if item not in filtered_results:\n", " filtered_results.append(item)\n", "filtered_results" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "output_template = \"User Query: {user_query}\\nBelow are some examples of previous conversations.\\nQuery: {query1} Solution: {solution1}\\nQuery: {query2} Solution: {solution2}\\nYou may use the above examples for reference only. Create your own solution and provide only the solution\"" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "output_template = \"The top three most relevant code snippets from the database are:\\n\\n1. {snippet1}\\n\\n2. {snippet2}\\n\\n3. {snippet3}\"" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "output = f'{output_template.format(snippet1=filtered_results[0],snippet2=filtered_results[1],snippet3=filtered_results[2])}'" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The top three most relevant code snippets from the database are:\n", "\n", "1. [x[1] for x in elements]\n", "\n", "2. map(itemgetter(1), elements)\n", "\n", "3. zip(*elements)[1]\n" ] } ], "source": [ "print(output)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }