{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6", "metadata": {}, "outputs": [], "source": [ "import json, os\n", "\n", "from llama_index.core import SimpleDirectoryReader\n", "from llama_index.core.node_parser import SentenceSplitter\n", "from llama_index.core.schema import MetadataMode" ] }, { "cell_type": "code", "execution_count": 2, "id": "139da55d-f0c3-4b76-b47f-e18ee552eb30", "metadata": {}, "outputs": [], "source": [ "from llama_index.finetuning.embeddings.common import (\n", " EmbeddingQAFinetuneDataset,\n", " generate_qa_embedding_pairs,\n", ")\n", "from llama_index.finetuning import SentenceTransformersFinetuneEngine" ] }, { "cell_type": "code", "execution_count": 3, "id": "1dfb1acc-606b-4106-baf7-87ed487b5d9c", "metadata": {}, "outputs": [], "source": [ "from llama_index.embeddings.openai.base import OpenAIEmbedding" ] }, { "cell_type": "code", "execution_count": 4, "id": "fa06c66a-ab07-46a6-bc53-f6157017883c", "metadata": {}, "outputs": [], "source": [ "from llama_index.core import ServiceContext, VectorStoreIndex" ] }, { "cell_type": "code", "execution_count": 5, "id": "c9928491-520a-441a-8c44-1fc21cfa5def", "metadata": {}, "outputs": [], "source": [ "from llama_index.core.schema import TextNode" ] }, { "cell_type": "code", "execution_count": 6, "id": "25f0c7a3-c52f-4417-aec8-4b6cfbf7a1b5", "metadata": {}, "outputs": [], "source": [ "from tqdm.notebook import tqdm\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 7, "id": "62f4d7f0-748a-405e-b5f1-6520fd02bedc", "metadata": {}, "outputs": [], "source": [ "from sentence_transformers.evaluation import InformationRetrievalEvaluator\n", "from sentence_transformers import SentenceTransformer\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 8, "id": "12527049-a5cb-423c-8de5-099aee970c85", "metadata": {}, "outputs": [], "source": [ "from llama_index.llms.openai import OpenAI" ] }, { "cell_type": "code", "execution_count": null, "id": "7dc65d7b-3cdb-4513-b09f-f7406ad59b35", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 9, "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6", "metadata": {}, "outputs": [], "source": [ "TRAIN_FILES = [\"../raw_documents/HI_Knowledge_Base.pdf\",\n", " \"../raw_documents/HI Chapter Summary Version 1.3.pdf\"]\n", "VAL_FILES = [\"../raw_documents/qna.txt\",\n", " \"../raw_documents/conversation_examples.txt\",\n", " \"../raw_documents/answers.txt\"]\n", "\n", "### based on all docs\n", "TRAIN_CORPUS_FPATH = \"../data/train_corpus_advanced.json\"\n", "\n", "### based on ../raw_documents/HI Chapter Summary Version 1.3.pdf\n", "VAL_CORPUS_FPATH = \"../data/val_corpus.json\"" ] }, { "cell_type": "code", "execution_count": null, "id": "663cd20e-c16e-4dda-924e-5f60eb25a772", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 10, "id": "26f614c8-eb45-4cc1-b067-2c7299587982", "metadata": {}, "outputs": [], "source": [ "def load_corpus(files, verbose=False):\n", " if verbose:\n", " print(f\"Loading files {files}\")\n", "\n", " reader = SimpleDirectoryReader(input_files=files)\n", " docs = reader.load_data()\n", " if verbose:\n", " print(f\"Loaded {len(docs)} docs\")\n", "\n", " parser = SentenceSplitter()\n", " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n", "\n", " if verbose:\n", " print(f\"Parsed {len(nodes)} nodes\")\n", "\n", " return nodes" ] }, { "cell_type": "code", "execution_count": null, "id": "a6ba52e5-4d7f-4c30-8979-8d84a1bc3ca4", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 11, "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "load qa embedding training pairs from saved corpus file..\n", "load qa embedding validation pairs from saved corpus file..\n" ] } ], "source": [ "if not os.path.exists(TRAIN_CORPUS_FPATH):\n", " train_nodes = load_corpus(TRAIN_FILES, verbose=True)\n", " print(\"generating qa embedding pairs for training data..\")\n", " train_dataset = generate_qa_embedding_pairs(\n", " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=train_nodes\n", " )\n", " train_dataset.save_json(TRAIN_CORPUS_FPATH)\n", "else:\n", " print(\"load qa embedding training pairs from saved corpus file..\")\n", " train_dataset = EmbeddingQAFinetuneDataset.from_json(TRAIN_CORPUS_FPATH)\n", "\n", "if not os.path.exists(VAL_CORPUS_FPATH):\n", " val_nodes = load_corpus(VAL_FILES, verbose=True)\n", " print(\"generating qa embedding pairs for validation data..\")\n", " val_dataset = generate_qa_embedding_pairs(\n", " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=val_nodes\n", " )\n", " val_dataset.save_json(VAL_CORPUS_FPATH)\n", "else:\n", " print(\"load qa embedding validation pairs from saved corpus file..\")\n", " val_dataset = EmbeddingQAFinetuneDataset.from_json(VAL_CORPUS_FPATH)" ] }, { "cell_type": "code", "execution_count": null, "id": "c3399443-5936-4dfe-b0ec-821d222e734d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 12, "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "19241142d8534d139252ffe078559bb7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0%| | 0.00/94.8k [00:00, tokenizer_name='../models/fine-tuned-embeddings-advanced', max_length=512, pooling=, normalize=True, query_instruction=None, text_instruction=None, cache_folder=None)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embed_model" ] }, { "cell_type": "code", "execution_count": null, "id": "c4f4058c-edbb-43c4-bebe-8c36d410e819", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 16, "id": "97ebae28-80ef-4f35-92ce-a370776e3b22", "metadata": {}, "outputs": [], "source": [ "fine_tuned_embed_model = SentenceTransformer(\"../models/fine-tuned-embeddings-advanced\")" ] }, { "cell_type": "code", "execution_count": null, "id": "dad7589f-4855-4432-b710-01aff9c134ee", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 17, "id": "ac4a1a5b-974d-452e-8507-0950c962f9b2", "metadata": {}, "outputs": [], "source": [ "def evaluate(\n", " dataset,\n", " embed_model,\n", " top_k=5,\n", " verbose=False,\n", "):\n", " corpus = dataset.corpus\n", " queries = dataset.queries\n", " relevant_docs = dataset.relevant_docs\n", "\n", " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n", " nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]\n", " index = VectorStoreIndex(\n", " nodes, service_context=service_context, show_progress=True\n", " )\n", " retriever = index.as_retriever(similarity_top_k=top_k)\n", "\n", " eval_results = []\n", " for query_id, query in tqdm(queries.items()):\n", " retrieved_nodes = retriever.retrieve(query)\n", " retrieved_ids = [node.node.node_id for node in retrieved_nodes]\n", " expected_id = relevant_docs[query_id][0]\n", " is_hit = expected_id in retrieved_ids # assume 1 relevant doc\n", "\n", " eval_result = {\n", " \"is_hit\": is_hit,\n", " \"retrieved\": retrieved_ids,\n", " \"expected\": expected_id,\n", " \"query\": query_id,\n", " }\n", " eval_results.append(eval_result)\n", " return eval_results" ] }, { "cell_type": "code", "execution_count": 18, "id": "a53cf893-ce9f-4d9d-ad4a-e9e17fb058d3", "metadata": {}, "outputs": [], "source": [ "def evaluate_st(\n", " dataset,\n", " model_id,\n", " name,\n", "):\n", " corpus = dataset.corpus\n", " queries = dataset.queries\n", " relevant_docs = dataset.relevant_docs\n", "\n", " evaluator = InformationRetrievalEvaluator(\n", " queries, corpus, relevant_docs, name=name\n", " )\n", " model = SentenceTransformer(model_id)\n", " output_path = \"../results/\"\n", " Path(output_path).mkdir(exist_ok=True, parents=True)\n", " return evaluator(model, output_path=output_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "703f9350-f7ab-43cc-abdf-055323ef67dd", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "57d66621-49e6-4a8a-9ef2-83b2b33e33d7", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "b43ad08e-e96d-412b-9a88-14fe3af85b3d", "metadata": {}, "source": [ "### Using OpenAI Ada embedding" ] }, { "cell_type": "code", "execution_count": 19, "id": "91f057aa-4b59-48ea-b3d5-23012a4d487f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/9p/zqv8rk793ts9cxxfr66p40sh0000gn/T/ipykernel_34681/2760886022.py:11: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n", " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3cd092342b1846ed9aa81f8de44eaaea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating embeddings: 0%| | 0/100 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
is_hitretrievedexpectedquery
0False[5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804...6a756f03-638d-480d-8222-1a6bf3790e3c011d84b2-0c26-4c5c-89d1-2a85498f30e0
1True[6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804...6a756f03-638d-480d-8222-1a6bf3790e3c70c5ddd7-eb86-4a41-af70-a23d2392f48d
2True[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...c83dbd8a-7e62-445e-8c12-a8ad604ff65ea8f4290a-1281-4272-aab9-bf089954a45e
3True[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...c83dbd8a-7e62-445e-8c12-a8ad604ff65ec1ef991a-1cc6-4dbf-b179-2df688c84301
4True[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...21778248-2ed9-4147-bdb0-a60337a1a5991ce25e78-c1e1-487e-9455-9418baa0b60c
\n", "" ], "text/plain": [ " is_hit retrieved \\\n", "0 False [5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804... \n", "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804... \n", "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n", "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n", "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n", "\n", " expected query \n", "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n", "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n", "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n", "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n", "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c " ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_ada[:5]" ] }, { "cell_type": "code", "execution_count": 22, "id": "3f7186fb-f392-4531-8959-25161e3905e4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.95, 200)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hit_rate_ada = df_ada[\"is_hit\"].mean()\n", "hit_rate_ada, len(df_ada)" ] }, { "cell_type": "code", "execution_count": null, "id": "d044399a-e55b-40b7-a09d-6fb838383bfa", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "66746f3e-638a-432c-a38d-7cb99d2093f7", "metadata": {}, "source": [ "### Using BAAI bge-small model without fine-tuning" ] }, { "cell_type": "code", "execution_count": 23, "id": "b2905831-0eb9-4ea7-a0b9-5db286b0965e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/9p/zqv8rk793ts9cxxfr66p40sh0000gn/T/ipykernel_34681/2760886022.py:11: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n", " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca1ac4b4b54f4169b909e5633b3eb1ad", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating embeddings: 0%| | 0/100 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
is_hitretrievedexpectedquery
0False[69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7...6a756f03-638d-480d-8222-1a6bf3790e3c011d84b2-0c26-4c5c-89d1-2a85498f30e0
1True[6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649...6a756f03-638d-480d-8222-1a6bf3790e3c70c5ddd7-eb86-4a41-af70-a23d2392f48d
2True[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...c83dbd8a-7e62-445e-8c12-a8ad604ff65ea8f4290a-1281-4272-aab9-bf089954a45e
3True[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb...c83dbd8a-7e62-445e-8c12-a8ad604ff65ec1ef991a-1cc6-4dbf-b179-2df688c84301
4True[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...21778248-2ed9-4147-bdb0-a60337a1a5991ce25e78-c1e1-487e-9455-9418baa0b60c
\n", "" ], "text/plain": [ " is_hit retrieved \\\n", "0 False [69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7... \n", "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649... \n", "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n", "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb... \n", "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n", "\n", " expected query \n", "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n", "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n", "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n", "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n", "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c " ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_bge[:5]" ] }, { "cell_type": "code", "execution_count": 26, "id": "9b1cb546-4605-4c48-bf4e-df812db97f13", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.915, 200)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hit_rate_bge = df_bge[\"is_hit\"].mean()\n", "hit_rate_bge, len(df_bge)" ] }, { "cell_type": "code", "execution_count": null, "id": "7dd69ad1-2153-4df0-93f7-807fc289d3fd", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 27, "id": "1b12ca3d-6ca2-41f6-9ddb-b12b9354ca83", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7955697668171072" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluate_st(val_dataset, \"BAAI/bge-small-en-v1.5\", name=\"bge\")" ] }, { "cell_type": "code", "execution_count": null, "id": "6023382b-0ff5-4d60-aeac-ad523153f943", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "adf35a2a-3bb7-4251-9521-f35346a7c6e6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "b3d290c2-784f-4c41-a258-e11d2c5117e7", "metadata": {}, "source": [ "### Using BAAI bge-small model with `fine-tuning`" ] }, { "cell_type": "code", "execution_count": 28, "id": "bd42b288-1f1f-41aa-9fd4-1ae4b1df462b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/9p/zqv8rk793ts9cxxfr66p40sh0000gn/T/ipykernel_34681/2760886022.py:11: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n", " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ddb31814f674c658e4b509c45104c7a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating embeddings: 0%| | 0/100 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
is_hit
model
ada0.950
bge0.915
fine_tuned0.970
\n", "" ], "text/plain": [ " is_hit\n", "model \n", "ada 0.950\n", "bge 0.915\n", "fine_tuned 0.970" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_all = pd.concat([df_ada, df_bge, df_finetuned])\n", "df_all.groupby(\"model\").mean(\"is_hit\")" ] }, { "cell_type": "code", "execution_count": null, "id": "72575c28-a221-4967-8f04-9579dcefa8f8", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 35, "id": "032cac38-c856-4aeb-9bbb-6d70ed53c614", "metadata": {}, "outputs": [], "source": [ "df_st_bge = pd.read_csv(\n", " \"../results/Information-Retrieval_evaluation_bge_results.csv\"\n", ")\n", "df_st_finetuned = pd.read_csv(\n", " \"../results/Information-Retrieval_evaluation_finetuned_results.csv\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a509f239-8b28-4d0a-9101-c8de91c7943b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 36, "id": "d2975262-c486-4a9a-a61f-ea535203a0f3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochstepscos_sim-Accuracy@1cos_sim-Accuracy@3cos_sim-Accuracy@5cos_sim-Accuracy@10cos_sim-Precision@1cos_sim-Recall@1cos_sim-Precision@3cos_sim-Recall@3...dot_score-Recall@1dot_score-Precision@3dot_score-Recall@3dot_score-Precision@5dot_score-Recall@5dot_score-Precision@10dot_score-Recall@10dot_score-MRR@10dot_score-NDCG@10dot_score-MAP@100
model
bge-1-10.7050.8650.9200.960.7050.7050.2883330.865...0.7050.2883330.8650.1840.9200.0960.960.7929350.8335950.795570
bge-1-10.7050.8650.9200.960.7050.7050.2883330.865...0.7050.2883330.8650.1840.9200.0960.960.7929350.8335950.795570
bge-1-10.7050.8650.9200.960.7050.7050.2883330.865...0.7050.2883330.8650.1840.9200.0960.960.7929350.8335950.795570
fine_tuned-1-10.7900.9000.9700.980.7900.7900.3000000.900...0.7900.3000000.9000.1940.9700.0980.980.8562640.8867380.857339
fine_tuned-1-10.7900.9000.9700.980.7900.7900.3000000.900...0.7900.3000000.9000.1940.9700.0980.980.8562640.8867380.857339
fine_tuned-1-10.7700.9100.9650.980.7700.7700.3033330.910...0.7700.3033330.9100.1930.9650.0980.980.8475420.8803880.848711
fine_tuned-1-10.8150.9450.9700.990.8150.8150.3150000.945...0.8150.3150000.9450.1940.9700.0990.990.8829350.9095630.883519
\n", "

7 rows × 32 columns

\n", "
" ], "text/plain": [ " epoch steps cos_sim-Accuracy@1 cos_sim-Accuracy@3 \\\n", "model \n", "bge -1 -1 0.705 0.865 \n", "bge -1 -1 0.705 0.865 \n", "bge -1 -1 0.705 0.865 \n", "fine_tuned -1 -1 0.790 0.900 \n", "fine_tuned -1 -1 0.790 0.900 \n", "fine_tuned -1 -1 0.770 0.910 \n", "fine_tuned -1 -1 0.815 0.945 \n", "\n", " cos_sim-Accuracy@5 cos_sim-Accuracy@10 cos_sim-Precision@1 \\\n", "model \n", "bge 0.920 0.96 0.705 \n", "bge 0.920 0.96 0.705 \n", "bge 0.920 0.96 0.705 \n", "fine_tuned 0.970 0.98 0.790 \n", "fine_tuned 0.970 0.98 0.790 \n", "fine_tuned 0.965 0.98 0.770 \n", "fine_tuned 0.970 0.99 0.815 \n", "\n", " cos_sim-Recall@1 cos_sim-Precision@3 cos_sim-Recall@3 ... \\\n", "model ... \n", "bge 0.705 0.288333 0.865 ... \n", "bge 0.705 0.288333 0.865 ... \n", "bge 0.705 0.288333 0.865 ... \n", "fine_tuned 0.790 0.300000 0.900 ... \n", "fine_tuned 0.790 0.300000 0.900 ... \n", "fine_tuned 0.770 0.303333 0.910 ... \n", "fine_tuned 0.815 0.315000 0.945 ... \n", "\n", " dot_score-Recall@1 dot_score-Precision@3 dot_score-Recall@3 \\\n", "model \n", "bge 0.705 0.288333 0.865 \n", "bge 0.705 0.288333 0.865 \n", "bge 0.705 0.288333 0.865 \n", "fine_tuned 0.790 0.300000 0.900 \n", "fine_tuned 0.790 0.300000 0.900 \n", "fine_tuned 0.770 0.303333 0.910 \n", "fine_tuned 0.815 0.315000 0.945 \n", "\n", " dot_score-Precision@5 dot_score-Recall@5 dot_score-Precision@10 \\\n", "model \n", "bge 0.184 0.920 0.096 \n", "bge 0.184 0.920 0.096 \n", "bge 0.184 0.920 0.096 \n", "fine_tuned 0.194 0.970 0.098 \n", "fine_tuned 0.194 0.970 0.098 \n", "fine_tuned 0.193 0.965 0.098 \n", "fine_tuned 0.194 0.970 0.099 \n", "\n", " dot_score-Recall@10 dot_score-MRR@10 dot_score-NDCG@10 \\\n", "model \n", "bge 0.96 0.792935 0.833595 \n", "bge 0.96 0.792935 0.833595 \n", "bge 0.96 0.792935 0.833595 \n", "fine_tuned 0.98 0.856264 0.886738 \n", "fine_tuned 0.98 0.856264 0.886738 \n", "fine_tuned 0.98 0.847542 0.880388 \n", "fine_tuned 0.99 0.882935 0.909563 \n", "\n", " dot_score-MAP@100 \n", "model \n", "bge 0.795570 \n", "bge 0.795570 \n", "bge 0.795570 \n", "fine_tuned 0.857339 \n", "fine_tuned 0.857339 \n", "fine_tuned 0.848711 \n", "fine_tuned 0.883519 \n", "\n", "[7 rows x 32 columns]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_st_bge[\"model\"] = \"bge\"\n", "df_st_finetuned[\"model\"] = \"fine_tuned\"\n", "df_st_all = pd.concat([df_st_bge, df_st_finetuned])\n", "df_st_all = df_st_all.set_index(\"model\")\n", "df_st_all" ] }, { "cell_type": "code", "execution_count": null, "id": "6ed2321b-6618-4a2b-9b1c-028425e91b84", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }