{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6", "metadata": {}, "outputs": [], "source": [ "import json, os\n", "\n", "from llama_index import SimpleDirectoryReader\n", "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.schema import MetadataMode" ] }, { "cell_type": "code", "execution_count": null, "id": "2c535ad7-7846-4bef-8ba8-33e182490c3d", "metadata": {}, "outputs": [], "source": [ "from llama_index.finetuning import (\n", " generate_qa_embedding_pairs,\n", " EmbeddingQAFinetuneDataset,\n", ")\n", "from llama_index.finetuning import SentenceTransformersFinetuneEngine" ] }, { "cell_type": "code", "execution_count": null, "id": "25f0c7a3-c52f-4417-aec8-4b6cfbf7a1b5", "metadata": {}, "outputs": [], "source": [ "from llama_index.embeddings import OpenAIEmbedding\n", "from llama_index import ServiceContext, VectorStoreIndex\n", "from llama_index.schema import TextNode\n", "from tqdm.notebook import tqdm\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": null, "id": "62f4d7f0-748a-405e-b5f1-6520fd02bedc", "metadata": {}, "outputs": [], "source": [ "from sentence_transformers.evaluation import InformationRetrievalEvaluator\n", "from sentence_transformers import SentenceTransformer\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": null, "id": "12527049-a5cb-423c-8de5-099aee970c85", "metadata": {}, "outputs": [], "source": [ "from llama_index.llms import OpenAI" ] }, { "cell_type": "code", "execution_count": null, "id": "abde5e6c-3474-460c-9fac-4f3352c38b53", "metadata": {}, "outputs": [], "source": [ "import llama_index\n", "print(llama_index.__version__)" ] }, { "cell_type": "code", "execution_count": null, "id": "7dc65d7b-3cdb-4513-b09f-f7406ad59b35", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6", "metadata": {}, "outputs": [], "source": [ "TRAIN_FILES = [\"../raw_documents/HI_Knowledge_Base.pdf\"]\n", "VAL_FILES = [\"../raw_documents/HI Chapter Summary Version 1.3.pdf\"]\n", "\n", "TRAIN_CORPUS_FPATH = \"../data/train_corpus.json\"\n", "VAL_CORPUS_FPATH = \"../data/val_corpus.json\"" ] }, { "cell_type": "code", "execution_count": null, "id": "663cd20e-c16e-4dda-924e-5f60eb25a772", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "26f614c8-eb45-4cc1-b067-2c7299587982", "metadata": {}, "outputs": [], "source": [ "def load_corpus(files, verbose=False):\n", " if verbose:\n", " print(f\"Loading files {files}\")\n", "\n", " reader = SimpleDirectoryReader(input_files=files)\n", " docs = reader.load_data()\n", " if verbose:\n", " print(f\"Loaded {len(docs)} docs\")\n", "\n", " parser = SentenceSplitter()\n", " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n", "\n", " if verbose:\n", " print(f\"Parsed {len(nodes)} nodes\")\n", "\n", " return nodes" ] }, { "cell_type": "code", "execution_count": null, "id": "a6ba52e5-4d7f-4c30-8979-8d84a1bc3ca4", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48", "metadata": {}, "outputs": [], "source": [ "if not os.path.exists(TRAIN_CORPUS_FPATH) or \\\n", " not os.path.exists(VAL_CORPUS_FPATH):\n", "\n", " train_nodes = load_corpus(TRAIN_FILES, verbose=True)\n", " val_nodes = load_corpus(VAL_FILES, verbose=True)\n", " \n", " train_dataset = generate_qa_embedding_pairs(\n", " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=train_nodes\n", " )\n", " val_dataset = generate_qa_embedding_pairs(\n", " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=val_nodes\n", " )\n", " \n", " train_dataset.save_json(TRAIN_CORPUS_FPATH)\n", " val_dataset.save_json(VAL_CORPUS_FPATH)\n", " \n", "else:\n", " train_dataset = EmbeddingQAFinetuneDataset.from_json(TRAIN_CORPUS_FPATH)\n", " val_dataset = EmbeddingQAFinetuneDataset.from_json(VAL_CORPUS_FPATH)" ] }, { "cell_type": "code", "execution_count": null, "id": "c3399443-5936-4dfe-b0ec-821d222e734d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8", "metadata": {}, "outputs": [], "source": [ "finetune_engine = SentenceTransformersFinetuneEngine(\n", " train_dataset,\n", " model_id=\"BAAI/bge-small-en-v1.5\",\n", " model_output_path=\"../models/fine-tuned-embeddings\",\n", " batch_size=5,\n", " val_dataset=val_dataset\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a6498d0b-da9a-4f7f-8c85-c9bf4d772c72", "metadata": {}, "outputs": [], "source": [ "finetune_engine.finetune()" ] }, { "cell_type": "code", "execution_count": null, "id": "e057b405-aa0e-4e78-91e0-9bf40f01c1a9", "metadata": {}, "outputs": [], "source": [ "embed_model = finetune_engine.get_finetuned_model()" ] }, { "cell_type": "code", "execution_count": null, "id": "72d9f97a-0902-4e65-8459-b34613e419f6", "metadata": {}, "outputs": [], "source": [ "embed_model" ] }, { "cell_type": "code", "execution_count": null, "id": "c4f4058c-edbb-43c4-bebe-8c36d410e819", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "97ebae28-80ef-4f35-92ce-a370776e3b22", "metadata": {}, "outputs": [], "source": [ "fine_tuned_embed_model = SentenceTransformer(\"../models/fine-tuned-embeddings\")" ] }, { "cell_type": "code", "execution_count": null, "id": "dad7589f-4855-4432-b710-01aff9c134ee", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ac4a1a5b-974d-452e-8507-0950c962f9b2", "metadata": {}, "outputs": [], "source": [ "def evaluate(\n", " dataset,\n", " embed_model,\n", " top_k=5,\n", " verbose=False,\n", "):\n", " corpus = dataset.corpus\n", " queries = dataset.queries\n", " relevant_docs = dataset.relevant_docs\n", "\n", " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n", " nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]\n", " index = VectorStoreIndex(\n", " nodes, service_context=service_context, show_progress=True\n", " )\n", " retriever = index.as_retriever(similarity_top_k=top_k)\n", "\n", " eval_results = []\n", " for query_id, query in tqdm(queries.items()):\n", " retrieved_nodes = retriever.retrieve(query)\n", " retrieved_ids = [node.node.node_id for node in retrieved_nodes]\n", " expected_id = relevant_docs[query_id][0]\n", " is_hit = expected_id in retrieved_ids # assume 1 relevant doc\n", "\n", " eval_result = {\n", " \"is_hit\": is_hit,\n", " \"retrieved\": retrieved_ids,\n", " \"expected\": expected_id,\n", " \"query\": query_id,\n", " }\n", " eval_results.append(eval_result)\n", " return eval_results" ] }, { "cell_type": "code", "execution_count": null, "id": "a53cf893-ce9f-4d9d-ad4a-e9e17fb058d3", "metadata": {}, "outputs": [], "source": [ "def evaluate_st(\n", " dataset,\n", " model_id,\n", " name,\n", "):\n", " corpus = dataset.corpus\n", " queries = dataset.queries\n", " relevant_docs = dataset.relevant_docs\n", "\n", " evaluator = InformationRetrievalEvaluator(\n", " queries, corpus, relevant_docs, name=name\n", " )\n", " model = SentenceTransformer(model_id)\n", " output_path = \"../results/\"\n", " Path(output_path).mkdir(exist_ok=True, parents=True)\n", " return evaluator(model, output_path=output_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "703f9350-f7ab-43cc-abdf-055323ef67dd", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "57d66621-49e6-4a8a-9ef2-83b2b33e33d7", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "b43ad08e-e96d-412b-9a88-14fe3af85b3d", "metadata": {}, "source": [ "### Using OpenAI Ada embedding" ] }, { "cell_type": "code", "execution_count": null, "id": "91f057aa-4b59-48ea-b3d5-23012a4d487f", "metadata": {}, "outputs": [], "source": [ "ada = OpenAIEmbedding()\n", "ada_val_results = evaluate(val_dataset, ada)" ] }, { "cell_type": "code", "execution_count": null, "id": "5d2f59c6-75d3-4970-bac3-dfe0eef00efe", "metadata": {}, "outputs": [], "source": [ "df_ada = pd.DataFrame(ada_val_results)" ] }, { "cell_type": "code", "execution_count": null, "id": "7a697cd8-6f39-4d5b-84f4-f08cf58adc4a", "metadata": {}, "outputs": [], "source": [ "df_ada[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "3f7186fb-f392-4531-8959-25161e3905e4", "metadata": {}, "outputs": [], "source": [ "hit_rate_ada = df_ada[\"is_hit\"].mean()\n", "hit_rate_ada, len(df_ada)" ] }, { "cell_type": "code", "execution_count": null, "id": "d044399a-e55b-40b7-a09d-6fb838383bfa", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "66746f3e-638a-432c-a38d-7cb99d2093f7", "metadata": {}, "source": [ "### Using BAAI bge-small model without fine-tuning" ] }, { "cell_type": "code", "execution_count": null, "id": "b2905831-0eb9-4ea7-a0b9-5db286b0965e", "metadata": {}, "outputs": [], "source": [ "bge = \"local:BAAI/bge-small-en-v1.5\"\n", "bge_val_results = evaluate(val_dataset, bge)" ] }, { "cell_type": "code", "execution_count": null, "id": "4e66270d-d3f6-429e-9e48-e8062866aa02", "metadata": {}, "outputs": [], "source": [ "df_bge = pd.DataFrame(bge_val_results)" ] }, { "cell_type": "code", "execution_count": null, "id": "698c1eb7-eba4-4383-98aa-931fc4ad56a4", "metadata": {}, "outputs": [], "source": [ "df_bge[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "9b1cb546-4605-4c48-bf4e-df812db97f13", "metadata": {}, "outputs": [], "source": [ "hit_rate_bge = df_bge[\"is_hit\"].mean()\n", "hit_rate_bge, len(df_bge)" ] }, { "cell_type": "code", "execution_count": null, "id": "7dd69ad1-2153-4df0-93f7-807fc289d3fd", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "1b12ca3d-6ca2-41f6-9ddb-b12b9354ca83", "metadata": {}, "outputs": [], "source": [ "evaluate_st(val_dataset, \"BAAI/bge-small-en-v1.5\", name=\"bge\")" ] }, { "cell_type": "code", "execution_count": null, "id": "6023382b-0ff5-4d60-aeac-ad523153f943", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "adf35a2a-3bb7-4251-9521-f35346a7c6e6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "b3d290c2-784f-4c41-a258-e11d2c5117e7", "metadata": {}, "source": [ "### Using BAAI bge-small model with `fine-tuning`" ] }, { "cell_type": "code", "execution_count": null, "id": "bd42b288-1f1f-41aa-9fd4-1ae4b1df462b", "metadata": {}, "outputs": [], "source": [ "finetuned = \"local:../models/fine-tuned-embeddings\"\n", "val_results_finetuned = evaluate(val_dataset, finetuned)" ] }, { "cell_type": "code", "execution_count": null, "id": "b1d7112d-b1b8-47db-8a4b-6c024ef99dd6", "metadata": {}, "outputs": [], "source": [ "df_finetuned = pd.DataFrame(val_results_finetuned)" ] }, { "cell_type": "code", "execution_count": null, "id": "62a4dd29-0631-4c5b-88e1-be43d48e1043", "metadata": {}, "outputs": [], "source": [ "hit_rate_finetuned = df_finetuned[\"is_hit\"].mean()\n", "hit_rate_finetuned" ] }, { "cell_type": "code", "execution_count": null, "id": "4332594b-c861-40fb-a58b-ba36717d0519", "metadata": {}, "outputs": [], "source": [ "evaluate_st(val_dataset, \"../models/fine-tuned-embeddings\", name=\"finetuned\")" ] }, { "cell_type": "code", "execution_count": null, "id": "b0003812-84a2-4ebd-9372-07bf874a486b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "ae7eb6ff-181b-42c8-975c-ca3320158698", "metadata": {}, "source": [ "### Summary" ] }, { "cell_type": "code", "execution_count": null, "id": "3ca46cff-b186-463a-847d-a86c310268ec", "metadata": {}, "outputs": [], "source": [ "df_ada[\"model\"] = \"ada\"\n", "df_bge[\"model\"] = \"bge\"\n", "df_finetuned[\"model\"] = \"fine_tuned\"" ] }, { "cell_type": "code", "execution_count": null, "id": "d1d3053e-2395-48a0-af59-fd27180e1e7b", "metadata": {}, "outputs": [], "source": [ "df_all = pd.concat([df_ada, df_bge, df_finetuned])\n", "df_all.groupby(\"model\").mean(\"is_hit\")" ] }, { "cell_type": "code", "execution_count": null, "id": "72575c28-a221-4967-8f04-9579dcefa8f8", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "032cac38-c856-4aeb-9bbb-6d70ed53c614", "metadata": {}, "outputs": [], "source": [ "df_st_bge = pd.read_csv(\n", " \"../results/Information-Retrieval_evaluation_bge_results.csv\"\n", ")\n", "df_st_finetuned = pd.read_csv(\n", " \"../results/Information-Retrieval_evaluation_finetuned_results.csv\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a509f239-8b28-4d0a-9101-c8de91c7943b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "d2975262-c486-4a9a-a61f-ea535203a0f3", "metadata": {}, "outputs": [], "source": [ "df_st_bge[\"model\"] = \"bge\"\n", "df_st_finetuned[\"model\"] = \"fine_tuned\"\n", "df_st_all = pd.concat([df_st_bge, df_st_finetuned])\n", "df_st_all = df_st_all.set_index(\"model\")\n", "df_st_all" ] }, { "cell_type": "code", "execution_count": null, "id": "6ed2321b-6618-4a2b-9b1c-028425e91b84", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }