lingyit1108 commited on
Commit
e29216a
Β·
1 Parent(s): be36d9d

tweak the vision_api prompt, create configuration files, minor tweak to main script

Browse files
Files changed (28) hide show
  1. config/model_config.yml +17 -0
  2. config/model_config_advanced.yml +17 -0
  3. models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/data_level0.bin +3 -0
  4. raw_documents/overview_background.txt β†’ models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/header.bin +2 -2
  5. models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/length.bin +3 -0
  6. models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/link_lists.bin +0 -0
  7. models/chroma_db_advanced/chroma.sqlite3 +3 -0
  8. models/fine-tuned-embeddings-advanced/1_Pooling/config.json +3 -0
  9. models/fine-tuned-embeddings-advanced/README.md +3 -0
  10. models/fine-tuned-embeddings-advanced/config.json +3 -0
  11. models/fine-tuned-embeddings-advanced/config_sentence_transformers.json +3 -0
  12. models/fine-tuned-embeddings-advanced/eval/Information-Retrieval_evaluation_results.csv +3 -0
  13. models/fine-tuned-embeddings-advanced/model.safetensors +3 -0
  14. models/fine-tuned-embeddings-advanced/modules.json +3 -0
  15. models/fine-tuned-embeddings-advanced/sentence_bert_config.json +3 -0
  16. models/fine-tuned-embeddings-advanced/special_tokens_map.json +3 -0
  17. models/fine-tuned-embeddings-advanced/tokenizer.json +3 -0
  18. models/fine-tuned-embeddings-advanced/tokenizer_config.json +3 -0
  19. models/fine-tuned-embeddings-advanced/vocab.txt +3 -0
  20. notebooks/001_fine-tuning-embedding-model-advanced.ipynb +1470 -0
  21. notebooks/002_persisted-embedding-model-advanced.ipynb +507 -0
  22. notebooks/002_persisted-embedding-model.ipynb +20 -4
  23. raw_documents/answers.txt +3 -0
  24. raw_documents/conversation_examples.txt +3 -0
  25. raw_documents/qna.txt +2 -2
  26. requirements.txt +24 -11
  27. streamlit_app.py +15 -11
  28. vision_api.py +9 -1
config/model_config.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data:
2
+ source:
3
+ - './raw_documents/qna.txt'
4
+ - './raw_documents/HI Chapter Summary Version 1.3.pdf'
5
+ - './raw_documents/conversation_examples.txt'
6
+ - './raw_documents/HI_Knowledge_Base.pdf'
7
+ - './raw_documents/answers.txt'
8
+
9
+ embeddings:
10
+ embedding_base_model: 'BAAI/bge-small-en-v1.5'
11
+ fine_tuned_embedding_model: 'local:models/fine-tuned-embeddings'
12
+
13
+ vector_store:
14
+ persisted_path: './models/chroma_db'
15
+
16
+ questionaire_data:
17
+ db_path: './database/mock_qna.sqlite'
config/model_config_advanced.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data:
2
+ source:
3
+ - './raw_documents/qna.txt'
4
+ - './raw_documents/HI Chapter Summary Version 1.3.pdf'
5
+ - './raw_documents/conversation_examples.txt'
6
+ - './raw_documents/HI_Knowledge_Base.pdf'
7
+ - './raw_documents/answers.txt'
8
+
9
+ embeddings:
10
+ embedding_base_model: 'BAAI/bge-small-en-v1.5'
11
+ fine_tuned_embedding_model: 'local:models/fine-tuned-embeddings-advanced'
12
+
13
+ vector_store:
14
+ persisted_path: './models/chroma_db_advanced'
15
+
16
+ questionaire_data:
17
+ db_path: './database/mock_qna_advanced.sqlite'
models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eec38a208011f4f233e59d2618152fa02e42d91757412778a5db814fe80bf2f
3
+ size 1676000
raw_documents/overview_background.txt β†’ models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/header.bin RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f4a5e6e0a28727dd6eab4bc18bf5ffcf897a4dbed61a854fa52629d2698f0925
3
- size 5970
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e87a1dc8bcae6f2c4bea6d5dd5005454d4dace8637dae29bff3c037ea771411e
3
+ size 100
models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc19b1997119425765295aeab72d76faa6927d4f83985d328c26f20468d6cc76
3
+ size 4000
models/chroma_db_advanced/a88943fe-4428-425d-8b9c-7bb8665a0c79/link_lists.bin ADDED
File without changes
models/chroma_db_advanced/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51aba6bb0bf5e5851de1e4e6cf53215b874c11b7194b3b765a2edfbc59ce9313
3
+ size 15937536
models/fine-tuned-embeddings-advanced/1_Pooling/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfd7e0a022036d0ffa0f998824a918247d5a7473d968cdc92e318fd04098e682
3
+ size 270
models/fine-tuned-embeddings-advanced/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af2a3dc885fad9e063851f6d7d61c8451bd064d9be25a3086a6f4be73e3d66ec
3
+ size 2544
models/fine-tuned-embeddings-advanced/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d1406e6b622e1d931c5535df1578231e0b315bf77ac55d547f36faed55b99ef
3
+ size 706
models/fine-tuned-embeddings-advanced/config_sentence_transformers.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:940d5f50db195fa6e5e6a4f122c095f77880de259d74b14a65779ed48bdd7c56
3
+ size 124
models/fine-tuned-embeddings-advanced/eval/Information-Retrieval_evaluation_results.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6120b99457f04ca31972429df8bcdc01ea1f1789df3f3a7b90859440d23cdedf
3
+ size 4140
models/fine-tuned-embeddings-advanced/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8eed74129b591608f8b74c53a800ae0035e63d623618cb64e26638124beb54f6
3
+ size 133462128
models/fine-tuned-embeddings-advanced/modules.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84e40c8e006c9b1d6c122e02cba9b02458120b5fb0c87b746c41e0207cf642cf
3
+ size 349
models/fine-tuned-embeddings-advanced/sentence_bert_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84e39fda68ccbff05bfa723ae9c0e70e23e2ec373b76e0f8c6e71af72a693cbf
3
+ size 52
models/fine-tuned-embeddings-advanced/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a
3
+ size 695
models/fine-tuned-embeddings-advanced/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91f1def9b9391fdabe028cd3f3fcc4efd34e5d1f08c3bf2de513ebb5911a1854
3
+ size 711649
models/fine-tuned-embeddings-advanced/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b29c7bfc889e53b36d9dd3e686dd4300f6525110eaa98c76a5dafceb2029f53
3
+ size 1242
models/fine-tuned-embeddings-advanced/vocab.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3
3
+ size 231508
notebooks/001_fine-tuning-embedding-model-advanced.ipynb ADDED
@@ -0,0 +1,1470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import json, os\n",
11
+ "\n",
12
+ "from llama_index.core import SimpleDirectoryReader\n",
13
+ "from llama_index.core.node_parser import SentenceSplitter\n",
14
+ "from llama_index.core.schema import MetadataMode"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "id": "139da55d-f0c3-4b76-b47f-e18ee552eb30",
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "from llama_index.finetuning.embeddings.common import (\n",
25
+ " EmbeddingQAFinetuneDataset,\n",
26
+ " generate_qa_embedding_pairs,\n",
27
+ ")\n",
28
+ "from llama_index.finetuning import SentenceTransformersFinetuneEngine"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 3,
34
+ "id": "1dfb1acc-606b-4106-baf7-87ed487b5d9c",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "from llama_index.embeddings.openai.base import OpenAIEmbedding"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 4,
44
+ "id": "fa06c66a-ab07-46a6-bc53-f6157017883c",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "from llama_index.core import ServiceContext, VectorStoreIndex"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 5,
54
+ "id": "c9928491-520a-441a-8c44-1fc21cfa5def",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "from llama_index.core.schema import TextNode"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 6,
64
+ "id": "25f0c7a3-c52f-4417-aec8-4b6cfbf7a1b5",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "from tqdm.notebook import tqdm\n",
69
+ "import pandas as pd"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 7,
75
+ "id": "62f4d7f0-748a-405e-b5f1-6520fd02bedc",
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "from sentence_transformers.evaluation import InformationRetrievalEvaluator\n",
80
+ "from sentence_transformers import SentenceTransformer\n",
81
+ "from pathlib import Path"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 8,
87
+ "id": "12527049-a5cb-423c-8de5-099aee970c85",
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "from llama_index.llms.openai import OpenAI"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "id": "7dc65d7b-3cdb-4513-b09f-f7406ad59b35",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": []
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 9,
105
+ "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6",
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "TRAIN_FILES = [\"../raw_documents/HI_Knowledge_Base.pdf\",\n",
110
+ " \"../raw_documents/HI Chapter Summary Version 1.3.pdf\"]\n",
111
+ "VAL_FILES = [\"../raw_documents/qna.txt\",\n",
112
+ " \"../raw_documents/conversation_examples.txt\",\n",
113
+ " \"../raw_documents/answers.txt\"]\n",
114
+ "\n",
115
+ "### based on all docs\n",
116
+ "TRAIN_CORPUS_FPATH = \"../data/train_corpus_advanced.json\"\n",
117
+ "\n",
118
+ "### based on ../raw_documents/HI Chapter Summary Version 1.3.pdf\n",
119
+ "VAL_CORPUS_FPATH = \"../data/val_corpus.json\""
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "id": "663cd20e-c16e-4dda-924e-5f60eb25a772",
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": []
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 10,
133
+ "id": "26f614c8-eb45-4cc1-b067-2c7299587982",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "def load_corpus(files, verbose=False):\n",
138
+ " if verbose:\n",
139
+ " print(f\"Loading files {files}\")\n",
140
+ "\n",
141
+ " reader = SimpleDirectoryReader(input_files=files)\n",
142
+ " docs = reader.load_data()\n",
143
+ " if verbose:\n",
144
+ " print(f\"Loaded {len(docs)} docs\")\n",
145
+ "\n",
146
+ " parser = SentenceSplitter()\n",
147
+ " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
148
+ "\n",
149
+ " if verbose:\n",
150
+ " print(f\"Parsed {len(nodes)} nodes\")\n",
151
+ "\n",
152
+ " return nodes"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "id": "a6ba52e5-4d7f-4c30-8979-8d84a1bc3ca4",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": []
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": 11,
166
+ "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48",
167
+ "metadata": {},
168
+ "outputs": [
169
+ {
170
+ "name": "stdout",
171
+ "output_type": "stream",
172
+ "text": [
173
+ "load qa embedding training pairs from saved corpus file..\n",
174
+ "load qa embedding validation pairs from saved corpus file..\n"
175
+ ]
176
+ }
177
+ ],
178
+ "source": [
179
+ "if not os.path.exists(TRAIN_CORPUS_FPATH):\n",
180
+ " train_nodes = load_corpus(TRAIN_FILES, verbose=True)\n",
181
+ " print(\"generating qa embedding pairs for training data..\")\n",
182
+ " train_dataset = generate_qa_embedding_pairs(\n",
183
+ " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=train_nodes\n",
184
+ " )\n",
185
+ " train_dataset.save_json(TRAIN_CORPUS_FPATH)\n",
186
+ "else:\n",
187
+ " print(\"load qa embedding training pairs from saved corpus file..\")\n",
188
+ " train_dataset = EmbeddingQAFinetuneDataset.from_json(TRAIN_CORPUS_FPATH)\n",
189
+ "\n",
190
+ "if not os.path.exists(VAL_CORPUS_FPATH):\n",
191
+ " val_nodes = load_corpus(VAL_FILES, verbose=True)\n",
192
+ " print(\"generating qa embedding pairs for validation data..\")\n",
193
+ " val_dataset = generate_qa_embedding_pairs(\n",
194
+ " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=val_nodes\n",
195
+ " )\n",
196
+ " val_dataset.save_json(VAL_CORPUS_FPATH)\n",
197
+ "else:\n",
198
+ " print(\"load qa embedding validation pairs from saved corpus file..\")\n",
199
+ " val_dataset = EmbeddingQAFinetuneDataset.from_json(VAL_CORPUS_FPATH)"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "id": "c3399443-5936-4dfe-b0ec-821d222e734d",
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": []
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 12,
213
+ "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8",
214
+ "metadata": {},
215
+ "outputs": [
216
+ {
217
+ "data": {
218
+ "application/vnd.jupyter.widget-view+json": {
219
+ "model_id": "19241142d8534d139252ffe078559bb7",
220
+ "version_major": 2,
221
+ "version_minor": 0
222
+ },
223
+ "text/plain": [
224
+ "README.md: 0%| | 0.00/94.8k [00:00<?, ?B/s]"
225
+ ]
226
+ },
227
+ "metadata": {},
228
+ "output_type": "display_data"
229
+ }
230
+ ],
231
+ "source": [
232
+ "finetune_engine = SentenceTransformersFinetuneEngine(\n",
233
+ " train_dataset,\n",
234
+ " model_id=\"BAAI/bge-small-en-v1.5\",\n",
235
+ " model_output_path=\"../models/fine-tuned-embeddings-advanced\",\n",
236
+ " batch_size=5,\n",
237
+ " val_dataset=val_dataset\n",
238
+ ")"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 13,
244
+ "id": "a6498d0b-da9a-4f7f-8c85-c9bf4d772c72",
245
+ "metadata": {},
246
+ "outputs": [
247
+ {
248
+ "data": {
249
+ "application/vnd.jupyter.widget-view+json": {
250
+ "model_id": "2c10018eda384f49a220c4fa66738fe1",
251
+ "version_major": 2,
252
+ "version_minor": 0
253
+ },
254
+ "text/plain": [
255
+ "Epoch: 0%| | 0/2 [00:00<?, ?it/s]"
256
+ ]
257
+ },
258
+ "metadata": {},
259
+ "output_type": "display_data"
260
+ },
261
+ {
262
+ "data": {
263
+ "application/vnd.jupyter.widget-view+json": {
264
+ "model_id": "5f4e5628b306450eab01e3af1ebdaf28",
265
+ "version_major": 2,
266
+ "version_minor": 0
267
+ },
268
+ "text/plain": [
269
+ "Iteration: 0%| | 0/268 [00:00<?, ?it/s]"
270
+ ]
271
+ },
272
+ "metadata": {},
273
+ "output_type": "display_data"
274
+ },
275
+ {
276
+ "data": {
277
+ "application/vnd.jupyter.widget-view+json": {
278
+ "model_id": "bce2bb08b15548f8afd8fd878f2009a4",
279
+ "version_major": 2,
280
+ "version_minor": 0
281
+ },
282
+ "text/plain": [
283
+ "Iteration: 0%| | 0/268 [00:00<?, ?it/s]"
284
+ ]
285
+ },
286
+ "metadata": {},
287
+ "output_type": "display_data"
288
+ }
289
+ ],
290
+ "source": [
291
+ "finetune_engine.finetune()"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": 14,
297
+ "id": "e057b405-aa0e-4e78-91e0-9bf40f01c1a9",
298
+ "metadata": {},
299
+ "outputs": [],
300
+ "source": [
301
+ "embed_model = finetune_engine.get_finetuned_model()"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": 15,
307
+ "id": "72d9f97a-0902-4e65-8459-b34613e419f6",
308
+ "metadata": {},
309
+ "outputs": [
310
+ {
311
+ "data": {
312
+ "text/plain": [
313
+ "HuggingFaceEmbedding(model_name='../models/fine-tuned-embeddings-advanced', embed_batch_size=10, callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x29f61adf0>, tokenizer_name='../models/fine-tuned-embeddings-advanced', max_length=512, pooling=<Pooling.CLS: 'cls'>, normalize=True, query_instruction=None, text_instruction=None, cache_folder=None)"
314
+ ]
315
+ },
316
+ "execution_count": 15,
317
+ "metadata": {},
318
+ "output_type": "execute_result"
319
+ }
320
+ ],
321
+ "source": [
322
+ "embed_model"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": null,
328
+ "id": "c4f4058c-edbb-43c4-bebe-8c36d410e819",
329
+ "metadata": {},
330
+ "outputs": [],
331
+ "source": []
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 16,
336
+ "id": "97ebae28-80ef-4f35-92ce-a370776e3b22",
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "fine_tuned_embed_model = SentenceTransformer(\"../models/fine-tuned-embeddings-advanced\")"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": null,
346
+ "id": "dad7589f-4855-4432-b710-01aff9c134ee",
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": []
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": 17,
354
+ "id": "ac4a1a5b-974d-452e-8507-0950c962f9b2",
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "def evaluate(\n",
359
+ " dataset,\n",
360
+ " embed_model,\n",
361
+ " top_k=5,\n",
362
+ " verbose=False,\n",
363
+ "):\n",
364
+ " corpus = dataset.corpus\n",
365
+ " queries = dataset.queries\n",
366
+ " relevant_docs = dataset.relevant_docs\n",
367
+ "\n",
368
+ " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n",
369
+ " nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]\n",
370
+ " index = VectorStoreIndex(\n",
371
+ " nodes, service_context=service_context, show_progress=True\n",
372
+ " )\n",
373
+ " retriever = index.as_retriever(similarity_top_k=top_k)\n",
374
+ "\n",
375
+ " eval_results = []\n",
376
+ " for query_id, query in tqdm(queries.items()):\n",
377
+ " retrieved_nodes = retriever.retrieve(query)\n",
378
+ " retrieved_ids = [node.node.node_id for node in retrieved_nodes]\n",
379
+ " expected_id = relevant_docs[query_id][0]\n",
380
+ " is_hit = expected_id in retrieved_ids # assume 1 relevant doc\n",
381
+ "\n",
382
+ " eval_result = {\n",
383
+ " \"is_hit\": is_hit,\n",
384
+ " \"retrieved\": retrieved_ids,\n",
385
+ " \"expected\": expected_id,\n",
386
+ " \"query\": query_id,\n",
387
+ " }\n",
388
+ " eval_results.append(eval_result)\n",
389
+ " return eval_results"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 18,
395
+ "id": "a53cf893-ce9f-4d9d-ad4a-e9e17fb058d3",
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": [
399
+ "def evaluate_st(\n",
400
+ " dataset,\n",
401
+ " model_id,\n",
402
+ " name,\n",
403
+ "):\n",
404
+ " corpus = dataset.corpus\n",
405
+ " queries = dataset.queries\n",
406
+ " relevant_docs = dataset.relevant_docs\n",
407
+ "\n",
408
+ " evaluator = InformationRetrievalEvaluator(\n",
409
+ " queries, corpus, relevant_docs, name=name\n",
410
+ " )\n",
411
+ " model = SentenceTransformer(model_id)\n",
412
+ " output_path = \"../results/\"\n",
413
+ " Path(output_path).mkdir(exist_ok=True, parents=True)\n",
414
+ " return evaluator(model, output_path=output_path)"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": null,
420
+ "id": "703f9350-f7ab-43cc-abdf-055323ef67dd",
421
+ "metadata": {},
422
+ "outputs": [],
423
+ "source": []
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "57d66621-49e6-4a8a-9ef2-83b2b33e33d7",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": []
432
+ },
433
+ {
434
+ "cell_type": "markdown",
435
+ "id": "b43ad08e-e96d-412b-9a88-14fe3af85b3d",
436
+ "metadata": {},
437
+ "source": [
438
+ "### Using OpenAI Ada embedding"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": 19,
444
+ "id": "91f057aa-4b59-48ea-b3d5-23012a4d487f",
445
+ "metadata": {},
446
+ "outputs": [
447
+ {
448
+ "name": "stderr",
449
+ "output_type": "stream",
450
+ "text": [
451
+ "/var/folders/9p/zqv8rk793ts9cxxfr66p40sh0000gn/T/ipykernel_34681/2760886022.py:11: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n",
452
+ " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n"
453
+ ]
454
+ },
455
+ {
456
+ "data": {
457
+ "application/vnd.jupyter.widget-view+json": {
458
+ "model_id": "3cd092342b1846ed9aa81f8de44eaaea",
459
+ "version_major": 2,
460
+ "version_minor": 0
461
+ },
462
+ "text/plain": [
463
+ "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
464
+ ]
465
+ },
466
+ "metadata": {},
467
+ "output_type": "display_data"
468
+ },
469
+ {
470
+ "name": "stderr",
471
+ "output_type": "stream",
472
+ "text": [
473
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
474
+ "To disable this warning, you can either:\n",
475
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
476
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
477
+ ]
478
+ },
479
+ {
480
+ "data": {
481
+ "application/vnd.jupyter.widget-view+json": {
482
+ "model_id": "00a72686c4bc4e518e8c7f56124247ab",
483
+ "version_major": 2,
484
+ "version_minor": 0
485
+ },
486
+ "text/plain": [
487
+ " 0%| | 0/200 [00:00<?, ?it/s]"
488
+ ]
489
+ },
490
+ "metadata": {},
491
+ "output_type": "display_data"
492
+ }
493
+ ],
494
+ "source": [
495
+ "ada = OpenAIEmbedding()\n",
496
+ "ada_val_results = evaluate(val_dataset, ada)"
497
+ ]
498
+ },
499
+ {
500
+ "cell_type": "code",
501
+ "execution_count": 20,
502
+ "id": "5d2f59c6-75d3-4970-bac3-dfe0eef00efe",
503
+ "metadata": {},
504
+ "outputs": [],
505
+ "source": [
506
+ "df_ada = pd.DataFrame(ada_val_results)"
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "code",
511
+ "execution_count": 21,
512
+ "id": "7a697cd8-6f39-4d5b-84f4-f08cf58adc4a",
513
+ "metadata": {},
514
+ "outputs": [
515
+ {
516
+ "data": {
517
+ "text/html": [
518
+ "<div>\n",
519
+ "<style scoped>\n",
520
+ " .dataframe tbody tr th:only-of-type {\n",
521
+ " vertical-align: middle;\n",
522
+ " }\n",
523
+ "\n",
524
+ " .dataframe tbody tr th {\n",
525
+ " vertical-align: top;\n",
526
+ " }\n",
527
+ "\n",
528
+ " .dataframe thead th {\n",
529
+ " text-align: right;\n",
530
+ " }\n",
531
+ "</style>\n",
532
+ "<table border=\"1\" class=\"dataframe\">\n",
533
+ " <thead>\n",
534
+ " <tr style=\"text-align: right;\">\n",
535
+ " <th></th>\n",
536
+ " <th>is_hit</th>\n",
537
+ " <th>retrieved</th>\n",
538
+ " <th>expected</th>\n",
539
+ " <th>query</th>\n",
540
+ " </tr>\n",
541
+ " </thead>\n",
542
+ " <tbody>\n",
543
+ " <tr>\n",
544
+ " <th>0</th>\n",
545
+ " <td>False</td>\n",
546
+ " <td>[5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804...</td>\n",
547
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
548
+ " <td>011d84b2-0c26-4c5c-89d1-2a85498f30e0</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <th>1</th>\n",
552
+ " <td>True</td>\n",
553
+ " <td>[6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804...</td>\n",
554
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
555
+ " <td>70c5ddd7-eb86-4a41-af70-a23d2392f48d</td>\n",
556
+ " </tr>\n",
557
+ " <tr>\n",
558
+ " <th>2</th>\n",
559
+ " <td>True</td>\n",
560
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
561
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
562
+ " <td>a8f4290a-1281-4272-aab9-bf089954a45e</td>\n",
563
+ " </tr>\n",
564
+ " <tr>\n",
565
+ " <th>3</th>\n",
566
+ " <td>True</td>\n",
567
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
568
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
569
+ " <td>c1ef991a-1cc6-4dbf-b179-2df688c84301</td>\n",
570
+ " </tr>\n",
571
+ " <tr>\n",
572
+ " <th>4</th>\n",
573
+ " <td>True</td>\n",
574
+ " <td>[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...</td>\n",
575
+ " <td>21778248-2ed9-4147-bdb0-a60337a1a599</td>\n",
576
+ " <td>1ce25e78-c1e1-487e-9455-9418baa0b60c</td>\n",
577
+ " </tr>\n",
578
+ " </tbody>\n",
579
+ "</table>\n",
580
+ "</div>"
581
+ ],
582
+ "text/plain": [
583
+ " is_hit retrieved \\\n",
584
+ "0 False [5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804... \n",
585
+ "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804... \n",
586
+ "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
587
+ "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
588
+ "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n",
589
+ "\n",
590
+ " expected query \n",
591
+ "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n",
592
+ "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n",
593
+ "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n",
594
+ "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n",
595
+ "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c "
596
+ ]
597
+ },
598
+ "execution_count": 21,
599
+ "metadata": {},
600
+ "output_type": "execute_result"
601
+ }
602
+ ],
603
+ "source": [
604
+ "df_ada[:5]"
605
+ ]
606
+ },
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": 22,
610
+ "id": "3f7186fb-f392-4531-8959-25161e3905e4",
611
+ "metadata": {},
612
+ "outputs": [
613
+ {
614
+ "data": {
615
+ "text/plain": [
616
+ "(0.95, 200)"
617
+ ]
618
+ },
619
+ "execution_count": 22,
620
+ "metadata": {},
621
+ "output_type": "execute_result"
622
+ }
623
+ ],
624
+ "source": [
625
+ "hit_rate_ada = df_ada[\"is_hit\"].mean()\n",
626
+ "hit_rate_ada, len(df_ada)"
627
+ ]
628
+ },
629
+ {
630
+ "cell_type": "code",
631
+ "execution_count": null,
632
+ "id": "d044399a-e55b-40b7-a09d-6fb838383bfa",
633
+ "metadata": {},
634
+ "outputs": [],
635
+ "source": []
636
+ },
637
+ {
638
+ "cell_type": "markdown",
639
+ "id": "66746f3e-638a-432c-a38d-7cb99d2093f7",
640
+ "metadata": {},
641
+ "source": [
642
+ "### Using BAAI bge-small model without fine-tuning"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": 23,
648
+ "id": "b2905831-0eb9-4ea7-a0b9-5db286b0965e",
649
+ "metadata": {},
650
+ "outputs": [
651
+ {
652
+ "name": "stderr",
653
+ "output_type": "stream",
654
+ "text": [
655
+ "/var/folders/9p/zqv8rk793ts9cxxfr66p40sh0000gn/T/ipykernel_34681/2760886022.py:11: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n",
656
+ " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n"
657
+ ]
658
+ },
659
+ {
660
+ "data": {
661
+ "application/vnd.jupyter.widget-view+json": {
662
+ "model_id": "ca1ac4b4b54f4169b909e5633b3eb1ad",
663
+ "version_major": 2,
664
+ "version_minor": 0
665
+ },
666
+ "text/plain": [
667
+ "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
668
+ ]
669
+ },
670
+ "metadata": {},
671
+ "output_type": "display_data"
672
+ },
673
+ {
674
+ "data": {
675
+ "application/vnd.jupyter.widget-view+json": {
676
+ "model_id": "4293592aba3244a991fad843f5c881ba",
677
+ "version_major": 2,
678
+ "version_minor": 0
679
+ },
680
+ "text/plain": [
681
+ " 0%| | 0/200 [00:00<?, ?it/s]"
682
+ ]
683
+ },
684
+ "metadata": {},
685
+ "output_type": "display_data"
686
+ }
687
+ ],
688
+ "source": [
689
+ "bge = \"local:BAAI/bge-small-en-v1.5\"\n",
690
+ "bge_val_results = evaluate(val_dataset, bge)"
691
+ ]
692
+ },
693
+ {
694
+ "cell_type": "code",
695
+ "execution_count": 24,
696
+ "id": "4e66270d-d3f6-429e-9e48-e8062866aa02",
697
+ "metadata": {},
698
+ "outputs": [],
699
+ "source": [
700
+ "df_bge = pd.DataFrame(bge_val_results)"
701
+ ]
702
+ },
703
+ {
704
+ "cell_type": "code",
705
+ "execution_count": 25,
706
+ "id": "698c1eb7-eba4-4383-98aa-931fc4ad56a4",
707
+ "metadata": {},
708
+ "outputs": [
709
+ {
710
+ "data": {
711
+ "text/html": [
712
+ "<div>\n",
713
+ "<style scoped>\n",
714
+ " .dataframe tbody tr th:only-of-type {\n",
715
+ " vertical-align: middle;\n",
716
+ " }\n",
717
+ "\n",
718
+ " .dataframe tbody tr th {\n",
719
+ " vertical-align: top;\n",
720
+ " }\n",
721
+ "\n",
722
+ " .dataframe thead th {\n",
723
+ " text-align: right;\n",
724
+ " }\n",
725
+ "</style>\n",
726
+ "<table border=\"1\" class=\"dataframe\">\n",
727
+ " <thead>\n",
728
+ " <tr style=\"text-align: right;\">\n",
729
+ " <th></th>\n",
730
+ " <th>is_hit</th>\n",
731
+ " <th>retrieved</th>\n",
732
+ " <th>expected</th>\n",
733
+ " <th>query</th>\n",
734
+ " </tr>\n",
735
+ " </thead>\n",
736
+ " <tbody>\n",
737
+ " <tr>\n",
738
+ " <th>0</th>\n",
739
+ " <td>False</td>\n",
740
+ " <td>[69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7...</td>\n",
741
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
742
+ " <td>011d84b2-0c26-4c5c-89d1-2a85498f30e0</td>\n",
743
+ " </tr>\n",
744
+ " <tr>\n",
745
+ " <th>1</th>\n",
746
+ " <td>True</td>\n",
747
+ " <td>[6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649...</td>\n",
748
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
749
+ " <td>70c5ddd7-eb86-4a41-af70-a23d2392f48d</td>\n",
750
+ " </tr>\n",
751
+ " <tr>\n",
752
+ " <th>2</th>\n",
753
+ " <td>True</td>\n",
754
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
755
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
756
+ " <td>a8f4290a-1281-4272-aab9-bf089954a45e</td>\n",
757
+ " </tr>\n",
758
+ " <tr>\n",
759
+ " <th>3</th>\n",
760
+ " <td>True</td>\n",
761
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb...</td>\n",
762
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
763
+ " <td>c1ef991a-1cc6-4dbf-b179-2df688c84301</td>\n",
764
+ " </tr>\n",
765
+ " <tr>\n",
766
+ " <th>4</th>\n",
767
+ " <td>True</td>\n",
768
+ " <td>[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...</td>\n",
769
+ " <td>21778248-2ed9-4147-bdb0-a60337a1a599</td>\n",
770
+ " <td>1ce25e78-c1e1-487e-9455-9418baa0b60c</td>\n",
771
+ " </tr>\n",
772
+ " </tbody>\n",
773
+ "</table>\n",
774
+ "</div>"
775
+ ],
776
+ "text/plain": [
777
+ " is_hit retrieved \\\n",
778
+ "0 False [69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7... \n",
779
+ "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649... \n",
780
+ "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
781
+ "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb... \n",
782
+ "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n",
783
+ "\n",
784
+ " expected query \n",
785
+ "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n",
786
+ "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n",
787
+ "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n",
788
+ "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n",
789
+ "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c "
790
+ ]
791
+ },
792
+ "execution_count": 25,
793
+ "metadata": {},
794
+ "output_type": "execute_result"
795
+ }
796
+ ],
797
+ "source": [
798
+ "df_bge[:5]"
799
+ ]
800
+ },
801
+ {
802
+ "cell_type": "code",
803
+ "execution_count": 26,
804
+ "id": "9b1cb546-4605-4c48-bf4e-df812db97f13",
805
+ "metadata": {},
806
+ "outputs": [
807
+ {
808
+ "data": {
809
+ "text/plain": [
810
+ "(0.915, 200)"
811
+ ]
812
+ },
813
+ "execution_count": 26,
814
+ "metadata": {},
815
+ "output_type": "execute_result"
816
+ }
817
+ ],
818
+ "source": [
819
+ "hit_rate_bge = df_bge[\"is_hit\"].mean()\n",
820
+ "hit_rate_bge, len(df_bge)"
821
+ ]
822
+ },
823
+ {
824
+ "cell_type": "code",
825
+ "execution_count": null,
826
+ "id": "7dd69ad1-2153-4df0-93f7-807fc289d3fd",
827
+ "metadata": {},
828
+ "outputs": [],
829
+ "source": []
830
+ },
831
+ {
832
+ "cell_type": "code",
833
+ "execution_count": 27,
834
+ "id": "1b12ca3d-6ca2-41f6-9ddb-b12b9354ca83",
835
+ "metadata": {},
836
+ "outputs": [
837
+ {
838
+ "data": {
839
+ "text/plain": [
840
+ "0.7955697668171072"
841
+ ]
842
+ },
843
+ "execution_count": 27,
844
+ "metadata": {},
845
+ "output_type": "execute_result"
846
+ }
847
+ ],
848
+ "source": [
849
+ "evaluate_st(val_dataset, \"BAAI/bge-small-en-v1.5\", name=\"bge\")"
850
+ ]
851
+ },
852
+ {
853
+ "cell_type": "code",
854
+ "execution_count": null,
855
+ "id": "6023382b-0ff5-4d60-aeac-ad523153f943",
856
+ "metadata": {},
857
+ "outputs": [],
858
+ "source": []
859
+ },
860
+ {
861
+ "cell_type": "code",
862
+ "execution_count": null,
863
+ "id": "adf35a2a-3bb7-4251-9521-f35346a7c6e6",
864
+ "metadata": {},
865
+ "outputs": [],
866
+ "source": []
867
+ },
868
+ {
869
+ "cell_type": "markdown",
870
+ "id": "b3d290c2-784f-4c41-a258-e11d2c5117e7",
871
+ "metadata": {},
872
+ "source": [
873
+ "### Using BAAI bge-small model with `fine-tuning`"
874
+ ]
875
+ },
876
+ {
877
+ "cell_type": "code",
878
+ "execution_count": 28,
879
+ "id": "bd42b288-1f1f-41aa-9fd4-1ae4b1df462b",
880
+ "metadata": {},
881
+ "outputs": [
882
+ {
883
+ "name": "stderr",
884
+ "output_type": "stream",
885
+ "text": [
886
+ "/var/folders/9p/zqv8rk793ts9cxxfr66p40sh0000gn/T/ipykernel_34681/2760886022.py:11: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n",
887
+ " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n"
888
+ ]
889
+ },
890
+ {
891
+ "data": {
892
+ "application/vnd.jupyter.widget-view+json": {
893
+ "model_id": "9ddb31814f674c658e4b509c45104c7a",
894
+ "version_major": 2,
895
+ "version_minor": 0
896
+ },
897
+ "text/plain": [
898
+ "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
899
+ ]
900
+ },
901
+ "metadata": {},
902
+ "output_type": "display_data"
903
+ },
904
+ {
905
+ "data": {
906
+ "application/vnd.jupyter.widget-view+json": {
907
+ "model_id": "6e781eff650b4cd28345ed4a0c919a28",
908
+ "version_major": 2,
909
+ "version_minor": 0
910
+ },
911
+ "text/plain": [
912
+ " 0%| | 0/200 [00:00<?, ?it/s]"
913
+ ]
914
+ },
915
+ "metadata": {},
916
+ "output_type": "display_data"
917
+ }
918
+ ],
919
+ "source": [
920
+ "finetuned = \"local:../models/fine-tuned-embeddings-advanced\"\n",
921
+ "val_results_finetuned = evaluate(val_dataset, finetuned)"
922
+ ]
923
+ },
924
+ {
925
+ "cell_type": "code",
926
+ "execution_count": 29,
927
+ "id": "b1d7112d-b1b8-47db-8a4b-6c024ef99dd6",
928
+ "metadata": {},
929
+ "outputs": [],
930
+ "source": [
931
+ "df_finetuned = pd.DataFrame(val_results_finetuned)"
932
+ ]
933
+ },
934
+ {
935
+ "cell_type": "code",
936
+ "execution_count": 30,
937
+ "id": "62a4dd29-0631-4c5b-88e1-be43d48e1043",
938
+ "metadata": {},
939
+ "outputs": [
940
+ {
941
+ "data": {
942
+ "text/plain": [
943
+ "0.97"
944
+ ]
945
+ },
946
+ "execution_count": 30,
947
+ "metadata": {},
948
+ "output_type": "execute_result"
949
+ }
950
+ ],
951
+ "source": [
952
+ "hit_rate_finetuned = df_finetuned[\"is_hit\"].mean()\n",
953
+ "hit_rate_finetuned"
954
+ ]
955
+ },
956
+ {
957
+ "cell_type": "code",
958
+ "execution_count": 31,
959
+ "id": "4332594b-c861-40fb-a58b-ba36717d0519",
960
+ "metadata": {},
961
+ "outputs": [
962
+ {
963
+ "data": {
964
+ "text/plain": [
965
+ "0.8835191391941393"
966
+ ]
967
+ },
968
+ "execution_count": 31,
969
+ "metadata": {},
970
+ "output_type": "execute_result"
971
+ }
972
+ ],
973
+ "source": [
974
+ "evaluate_st(val_dataset, \"../models/fine-tuned-embeddings-advanced\", name=\"finetuned\")"
975
+ ]
976
+ },
977
+ {
978
+ "cell_type": "code",
979
+ "execution_count": null,
980
+ "id": "b0003812-84a2-4ebd-9372-07bf874a486b",
981
+ "metadata": {},
982
+ "outputs": [],
983
+ "source": []
984
+ },
985
+ {
986
+ "cell_type": "markdown",
987
+ "id": "ae7eb6ff-181b-42c8-975c-ca3320158698",
988
+ "metadata": {},
989
+ "source": [
990
+ "### Summary"
991
+ ]
992
+ },
993
+ {
994
+ "cell_type": "code",
995
+ "execution_count": 32,
996
+ "id": "3ca46cff-b186-463a-847d-a86c310268ec",
997
+ "metadata": {},
998
+ "outputs": [],
999
+ "source": [
1000
+ "df_ada[\"model\"] = \"ada\"\n",
1001
+ "df_bge[\"model\"] = \"bge\"\n",
1002
+ "df_finetuned[\"model\"] = \"fine_tuned\""
1003
+ ]
1004
+ },
1005
+ {
1006
+ "cell_type": "code",
1007
+ "execution_count": 33,
1008
+ "id": "d1d3053e-2395-48a0-af59-fd27180e1e7b",
1009
+ "metadata": {},
1010
+ "outputs": [
1011
+ {
1012
+ "data": {
1013
+ "text/html": [
1014
+ "<div>\n",
1015
+ "<style scoped>\n",
1016
+ " .dataframe tbody tr th:only-of-type {\n",
1017
+ " vertical-align: middle;\n",
1018
+ " }\n",
1019
+ "\n",
1020
+ " .dataframe tbody tr th {\n",
1021
+ " vertical-align: top;\n",
1022
+ " }\n",
1023
+ "\n",
1024
+ " .dataframe thead th {\n",
1025
+ " text-align: right;\n",
1026
+ " }\n",
1027
+ "</style>\n",
1028
+ "<table border=\"1\" class=\"dataframe\">\n",
1029
+ " <thead>\n",
1030
+ " <tr style=\"text-align: right;\">\n",
1031
+ " <th></th>\n",
1032
+ " <th>is_hit</th>\n",
1033
+ " </tr>\n",
1034
+ " <tr>\n",
1035
+ " <th>model</th>\n",
1036
+ " <th></th>\n",
1037
+ " </tr>\n",
1038
+ " </thead>\n",
1039
+ " <tbody>\n",
1040
+ " <tr>\n",
1041
+ " <th>ada</th>\n",
1042
+ " <td>0.950</td>\n",
1043
+ " </tr>\n",
1044
+ " <tr>\n",
1045
+ " <th>bge</th>\n",
1046
+ " <td>0.915</td>\n",
1047
+ " </tr>\n",
1048
+ " <tr>\n",
1049
+ " <th>fine_tuned</th>\n",
1050
+ " <td>0.970</td>\n",
1051
+ " </tr>\n",
1052
+ " </tbody>\n",
1053
+ "</table>\n",
1054
+ "</div>"
1055
+ ],
1056
+ "text/plain": [
1057
+ " is_hit\n",
1058
+ "model \n",
1059
+ "ada 0.950\n",
1060
+ "bge 0.915\n",
1061
+ "fine_tuned 0.970"
1062
+ ]
1063
+ },
1064
+ "execution_count": 33,
1065
+ "metadata": {},
1066
+ "output_type": "execute_result"
1067
+ }
1068
+ ],
1069
+ "source": [
1070
+ "df_all = pd.concat([df_ada, df_bge, df_finetuned])\n",
1071
+ "df_all.groupby(\"model\").mean(\"is_hit\")"
1072
+ ]
1073
+ },
1074
+ {
1075
+ "cell_type": "code",
1076
+ "execution_count": null,
1077
+ "id": "72575c28-a221-4967-8f04-9579dcefa8f8",
1078
+ "metadata": {},
1079
+ "outputs": [],
1080
+ "source": []
1081
+ },
1082
+ {
1083
+ "cell_type": "code",
1084
+ "execution_count": 35,
1085
+ "id": "032cac38-c856-4aeb-9bbb-6d70ed53c614",
1086
+ "metadata": {},
1087
+ "outputs": [],
1088
+ "source": [
1089
+ "df_st_bge = pd.read_csv(\n",
1090
+ " \"../results/Information-Retrieval_evaluation_bge_results.csv\"\n",
1091
+ ")\n",
1092
+ "df_st_finetuned = pd.read_csv(\n",
1093
+ " \"../results/Information-Retrieval_evaluation_finetuned_results.csv\"\n",
1094
+ ")"
1095
+ ]
1096
+ },
1097
+ {
1098
+ "cell_type": "code",
1099
+ "execution_count": null,
1100
+ "id": "a509f239-8b28-4d0a-9101-c8de91c7943b",
1101
+ "metadata": {},
1102
+ "outputs": [],
1103
+ "source": []
1104
+ },
1105
+ {
1106
+ "cell_type": "code",
1107
+ "execution_count": 36,
1108
+ "id": "d2975262-c486-4a9a-a61f-ea535203a0f3",
1109
+ "metadata": {},
1110
+ "outputs": [
1111
+ {
1112
+ "data": {
1113
+ "text/html": [
1114
+ "<div>\n",
1115
+ "<style scoped>\n",
1116
+ " .dataframe tbody tr th:only-of-type {\n",
1117
+ " vertical-align: middle;\n",
1118
+ " }\n",
1119
+ "\n",
1120
+ " .dataframe tbody tr th {\n",
1121
+ " vertical-align: top;\n",
1122
+ " }\n",
1123
+ "\n",
1124
+ " .dataframe thead th {\n",
1125
+ " text-align: right;\n",
1126
+ " }\n",
1127
+ "</style>\n",
1128
+ "<table border=\"1\" class=\"dataframe\">\n",
1129
+ " <thead>\n",
1130
+ " <tr style=\"text-align: right;\">\n",
1131
+ " <th></th>\n",
1132
+ " <th>epoch</th>\n",
1133
+ " <th>steps</th>\n",
1134
+ " <th>cos_sim-Accuracy@1</th>\n",
1135
+ " <th>cos_sim-Accuracy@3</th>\n",
1136
+ " <th>cos_sim-Accuracy@5</th>\n",
1137
+ " <th>cos_sim-Accuracy@10</th>\n",
1138
+ " <th>cos_sim-Precision@1</th>\n",
1139
+ " <th>cos_sim-Recall@1</th>\n",
1140
+ " <th>cos_sim-Precision@3</th>\n",
1141
+ " <th>cos_sim-Recall@3</th>\n",
1142
+ " <th>...</th>\n",
1143
+ " <th>dot_score-Recall@1</th>\n",
1144
+ " <th>dot_score-Precision@3</th>\n",
1145
+ " <th>dot_score-Recall@3</th>\n",
1146
+ " <th>dot_score-Precision@5</th>\n",
1147
+ " <th>dot_score-Recall@5</th>\n",
1148
+ " <th>dot_score-Precision@10</th>\n",
1149
+ " <th>dot_score-Recall@10</th>\n",
1150
+ " <th>dot_score-MRR@10</th>\n",
1151
+ " <th>dot_score-NDCG@10</th>\n",
1152
+ " <th>dot_score-MAP@100</th>\n",
1153
+ " </tr>\n",
1154
+ " <tr>\n",
1155
+ " <th>model</th>\n",
1156
+ " <th></th>\n",
1157
+ " <th></th>\n",
1158
+ " <th></th>\n",
1159
+ " <th></th>\n",
1160
+ " <th></th>\n",
1161
+ " <th></th>\n",
1162
+ " <th></th>\n",
1163
+ " <th></th>\n",
1164
+ " <th></th>\n",
1165
+ " <th></th>\n",
1166
+ " <th></th>\n",
1167
+ " <th></th>\n",
1168
+ " <th></th>\n",
1169
+ " <th></th>\n",
1170
+ " <th></th>\n",
1171
+ " <th></th>\n",
1172
+ " <th></th>\n",
1173
+ " <th></th>\n",
1174
+ " <th></th>\n",
1175
+ " <th></th>\n",
1176
+ " <th></th>\n",
1177
+ " </tr>\n",
1178
+ " </thead>\n",
1179
+ " <tbody>\n",
1180
+ " <tr>\n",
1181
+ " <th>bge</th>\n",
1182
+ " <td>-1</td>\n",
1183
+ " <td>-1</td>\n",
1184
+ " <td>0.705</td>\n",
1185
+ " <td>0.865</td>\n",
1186
+ " <td>0.920</td>\n",
1187
+ " <td>0.96</td>\n",
1188
+ " <td>0.705</td>\n",
1189
+ " <td>0.705</td>\n",
1190
+ " <td>0.288333</td>\n",
1191
+ " <td>0.865</td>\n",
1192
+ " <td>...</td>\n",
1193
+ " <td>0.705</td>\n",
1194
+ " <td>0.288333</td>\n",
1195
+ " <td>0.865</td>\n",
1196
+ " <td>0.184</td>\n",
1197
+ " <td>0.920</td>\n",
1198
+ " <td>0.096</td>\n",
1199
+ " <td>0.96</td>\n",
1200
+ " <td>0.792935</td>\n",
1201
+ " <td>0.833595</td>\n",
1202
+ " <td>0.795570</td>\n",
1203
+ " </tr>\n",
1204
+ " <tr>\n",
1205
+ " <th>bge</th>\n",
1206
+ " <td>-1</td>\n",
1207
+ " <td>-1</td>\n",
1208
+ " <td>0.705</td>\n",
1209
+ " <td>0.865</td>\n",
1210
+ " <td>0.920</td>\n",
1211
+ " <td>0.96</td>\n",
1212
+ " <td>0.705</td>\n",
1213
+ " <td>0.705</td>\n",
1214
+ " <td>0.288333</td>\n",
1215
+ " <td>0.865</td>\n",
1216
+ " <td>...</td>\n",
1217
+ " <td>0.705</td>\n",
1218
+ " <td>0.288333</td>\n",
1219
+ " <td>0.865</td>\n",
1220
+ " <td>0.184</td>\n",
1221
+ " <td>0.920</td>\n",
1222
+ " <td>0.096</td>\n",
1223
+ " <td>0.96</td>\n",
1224
+ " <td>0.792935</td>\n",
1225
+ " <td>0.833595</td>\n",
1226
+ " <td>0.795570</td>\n",
1227
+ " </tr>\n",
1228
+ " <tr>\n",
1229
+ " <th>bge</th>\n",
1230
+ " <td>-1</td>\n",
1231
+ " <td>-1</td>\n",
1232
+ " <td>0.705</td>\n",
1233
+ " <td>0.865</td>\n",
1234
+ " <td>0.920</td>\n",
1235
+ " <td>0.96</td>\n",
1236
+ " <td>0.705</td>\n",
1237
+ " <td>0.705</td>\n",
1238
+ " <td>0.288333</td>\n",
1239
+ " <td>0.865</td>\n",
1240
+ " <td>...</td>\n",
1241
+ " <td>0.705</td>\n",
1242
+ " <td>0.288333</td>\n",
1243
+ " <td>0.865</td>\n",
1244
+ " <td>0.184</td>\n",
1245
+ " <td>0.920</td>\n",
1246
+ " <td>0.096</td>\n",
1247
+ " <td>0.96</td>\n",
1248
+ " <td>0.792935</td>\n",
1249
+ " <td>0.833595</td>\n",
1250
+ " <td>0.795570</td>\n",
1251
+ " </tr>\n",
1252
+ " <tr>\n",
1253
+ " <th>fine_tuned</th>\n",
1254
+ " <td>-1</td>\n",
1255
+ " <td>-1</td>\n",
1256
+ " <td>0.790</td>\n",
1257
+ " <td>0.900</td>\n",
1258
+ " <td>0.970</td>\n",
1259
+ " <td>0.98</td>\n",
1260
+ " <td>0.790</td>\n",
1261
+ " <td>0.790</td>\n",
1262
+ " <td>0.300000</td>\n",
1263
+ " <td>0.900</td>\n",
1264
+ " <td>...</td>\n",
1265
+ " <td>0.790</td>\n",
1266
+ " <td>0.300000</td>\n",
1267
+ " <td>0.900</td>\n",
1268
+ " <td>0.194</td>\n",
1269
+ " <td>0.970</td>\n",
1270
+ " <td>0.098</td>\n",
1271
+ " <td>0.98</td>\n",
1272
+ " <td>0.856264</td>\n",
1273
+ " <td>0.886738</td>\n",
1274
+ " <td>0.857339</td>\n",
1275
+ " </tr>\n",
1276
+ " <tr>\n",
1277
+ " <th>fine_tuned</th>\n",
1278
+ " <td>-1</td>\n",
1279
+ " <td>-1</td>\n",
1280
+ " <td>0.790</td>\n",
1281
+ " <td>0.900</td>\n",
1282
+ " <td>0.970</td>\n",
1283
+ " <td>0.98</td>\n",
1284
+ " <td>0.790</td>\n",
1285
+ " <td>0.790</td>\n",
1286
+ " <td>0.300000</td>\n",
1287
+ " <td>0.900</td>\n",
1288
+ " <td>...</td>\n",
1289
+ " <td>0.790</td>\n",
1290
+ " <td>0.300000</td>\n",
1291
+ " <td>0.900</td>\n",
1292
+ " <td>0.194</td>\n",
1293
+ " <td>0.970</td>\n",
1294
+ " <td>0.098</td>\n",
1295
+ " <td>0.98</td>\n",
1296
+ " <td>0.856264</td>\n",
1297
+ " <td>0.886738</td>\n",
1298
+ " <td>0.857339</td>\n",
1299
+ " </tr>\n",
1300
+ " <tr>\n",
1301
+ " <th>fine_tuned</th>\n",
1302
+ " <td>-1</td>\n",
1303
+ " <td>-1</td>\n",
1304
+ " <td>0.770</td>\n",
1305
+ " <td>0.910</td>\n",
1306
+ " <td>0.965</td>\n",
1307
+ " <td>0.98</td>\n",
1308
+ " <td>0.770</td>\n",
1309
+ " <td>0.770</td>\n",
1310
+ " <td>0.303333</td>\n",
1311
+ " <td>0.910</td>\n",
1312
+ " <td>...</td>\n",
1313
+ " <td>0.770</td>\n",
1314
+ " <td>0.303333</td>\n",
1315
+ " <td>0.910</td>\n",
1316
+ " <td>0.193</td>\n",
1317
+ " <td>0.965</td>\n",
1318
+ " <td>0.098</td>\n",
1319
+ " <td>0.98</td>\n",
1320
+ " <td>0.847542</td>\n",
1321
+ " <td>0.880388</td>\n",
1322
+ " <td>0.848711</td>\n",
1323
+ " </tr>\n",
1324
+ " <tr>\n",
1325
+ " <th>fine_tuned</th>\n",
1326
+ " <td>-1</td>\n",
1327
+ " <td>-1</td>\n",
1328
+ " <td>0.815</td>\n",
1329
+ " <td>0.945</td>\n",
1330
+ " <td>0.970</td>\n",
1331
+ " <td>0.99</td>\n",
1332
+ " <td>0.815</td>\n",
1333
+ " <td>0.815</td>\n",
1334
+ " <td>0.315000</td>\n",
1335
+ " <td>0.945</td>\n",
1336
+ " <td>...</td>\n",
1337
+ " <td>0.815</td>\n",
1338
+ " <td>0.315000</td>\n",
1339
+ " <td>0.945</td>\n",
1340
+ " <td>0.194</td>\n",
1341
+ " <td>0.970</td>\n",
1342
+ " <td>0.099</td>\n",
1343
+ " <td>0.99</td>\n",
1344
+ " <td>0.882935</td>\n",
1345
+ " <td>0.909563</td>\n",
1346
+ " <td>0.883519</td>\n",
1347
+ " </tr>\n",
1348
+ " </tbody>\n",
1349
+ "</table>\n",
1350
+ "<p>7 rows Γ— 32 columns</p>\n",
1351
+ "</div>"
1352
+ ],
1353
+ "text/plain": [
1354
+ " epoch steps cos_sim-Accuracy@1 cos_sim-Accuracy@3 \\\n",
1355
+ "model \n",
1356
+ "bge -1 -1 0.705 0.865 \n",
1357
+ "bge -1 -1 0.705 0.865 \n",
1358
+ "bge -1 -1 0.705 0.865 \n",
1359
+ "fine_tuned -1 -1 0.790 0.900 \n",
1360
+ "fine_tuned -1 -1 0.790 0.900 \n",
1361
+ "fine_tuned -1 -1 0.770 0.910 \n",
1362
+ "fine_tuned -1 -1 0.815 0.945 \n",
1363
+ "\n",
1364
+ " cos_sim-Accuracy@5 cos_sim-Accuracy@10 cos_sim-Precision@1 \\\n",
1365
+ "model \n",
1366
+ "bge 0.920 0.96 0.705 \n",
1367
+ "bge 0.920 0.96 0.705 \n",
1368
+ "bge 0.920 0.96 0.705 \n",
1369
+ "fine_tuned 0.970 0.98 0.790 \n",
1370
+ "fine_tuned 0.970 0.98 0.790 \n",
1371
+ "fine_tuned 0.965 0.98 0.770 \n",
1372
+ "fine_tuned 0.970 0.99 0.815 \n",
1373
+ "\n",
1374
+ " cos_sim-Recall@1 cos_sim-Precision@3 cos_sim-Recall@3 ... \\\n",
1375
+ "model ... \n",
1376
+ "bge 0.705 0.288333 0.865 ... \n",
1377
+ "bge 0.705 0.288333 0.865 ... \n",
1378
+ "bge 0.705 0.288333 0.865 ... \n",
1379
+ "fine_tuned 0.790 0.300000 0.900 ... \n",
1380
+ "fine_tuned 0.790 0.300000 0.900 ... \n",
1381
+ "fine_tuned 0.770 0.303333 0.910 ... \n",
1382
+ "fine_tuned 0.815 0.315000 0.945 ... \n",
1383
+ "\n",
1384
+ " dot_score-Recall@1 dot_score-Precision@3 dot_score-Recall@3 \\\n",
1385
+ "model \n",
1386
+ "bge 0.705 0.288333 0.865 \n",
1387
+ "bge 0.705 0.288333 0.865 \n",
1388
+ "bge 0.705 0.288333 0.865 \n",
1389
+ "fine_tuned 0.790 0.300000 0.900 \n",
1390
+ "fine_tuned 0.790 0.300000 0.900 \n",
1391
+ "fine_tuned 0.770 0.303333 0.910 \n",
1392
+ "fine_tuned 0.815 0.315000 0.945 \n",
1393
+ "\n",
1394
+ " dot_score-Precision@5 dot_score-Recall@5 dot_score-Precision@10 \\\n",
1395
+ "model \n",
1396
+ "bge 0.184 0.920 0.096 \n",
1397
+ "bge 0.184 0.920 0.096 \n",
1398
+ "bge 0.184 0.920 0.096 \n",
1399
+ "fine_tuned 0.194 0.970 0.098 \n",
1400
+ "fine_tuned 0.194 0.970 0.098 \n",
1401
+ "fine_tuned 0.193 0.965 0.098 \n",
1402
+ "fine_tuned 0.194 0.970 0.099 \n",
1403
+ "\n",
1404
+ " dot_score-Recall@10 dot_score-MRR@10 dot_score-NDCG@10 \\\n",
1405
+ "model \n",
1406
+ "bge 0.96 0.792935 0.833595 \n",
1407
+ "bge 0.96 0.792935 0.833595 \n",
1408
+ "bge 0.96 0.792935 0.833595 \n",
1409
+ "fine_tuned 0.98 0.856264 0.886738 \n",
1410
+ "fine_tuned 0.98 0.856264 0.886738 \n",
1411
+ "fine_tuned 0.98 0.847542 0.880388 \n",
1412
+ "fine_tuned 0.99 0.882935 0.909563 \n",
1413
+ "\n",
1414
+ " dot_score-MAP@100 \n",
1415
+ "model \n",
1416
+ "bge 0.795570 \n",
1417
+ "bge 0.795570 \n",
1418
+ "bge 0.795570 \n",
1419
+ "fine_tuned 0.857339 \n",
1420
+ "fine_tuned 0.857339 \n",
1421
+ "fine_tuned 0.848711 \n",
1422
+ "fine_tuned 0.883519 \n",
1423
+ "\n",
1424
+ "[7 rows x 32 columns]"
1425
+ ]
1426
+ },
1427
+ "execution_count": 36,
1428
+ "metadata": {},
1429
+ "output_type": "execute_result"
1430
+ }
1431
+ ],
1432
+ "source": [
1433
+ "df_st_bge[\"model\"] = \"bge\"\n",
1434
+ "df_st_finetuned[\"model\"] = \"fine_tuned\"\n",
1435
+ "df_st_all = pd.concat([df_st_bge, df_st_finetuned])\n",
1436
+ "df_st_all = df_st_all.set_index(\"model\")\n",
1437
+ "df_st_all"
1438
+ ]
1439
+ },
1440
+ {
1441
+ "cell_type": "code",
1442
+ "execution_count": null,
1443
+ "id": "6ed2321b-6618-4a2b-9b1c-028425e91b84",
1444
+ "metadata": {},
1445
+ "outputs": [],
1446
+ "source": []
1447
+ }
1448
+ ],
1449
+ "metadata": {
1450
+ "kernelspec": {
1451
+ "display_name": "Python 3 (ipykernel)",
1452
+ "language": "python",
1453
+ "name": "python3"
1454
+ },
1455
+ "language_info": {
1456
+ "codemirror_mode": {
1457
+ "name": "ipython",
1458
+ "version": 3
1459
+ },
1460
+ "file_extension": ".py",
1461
+ "mimetype": "text/x-python",
1462
+ "name": "python",
1463
+ "nbconvert_exporter": "python",
1464
+ "pygments_lexer": "ipython3",
1465
+ "version": "3.9.18"
1466
+ }
1467
+ },
1468
+ "nbformat": 4,
1469
+ "nbformat_minor": 5
1470
+ }
notebooks/002_persisted-embedding-model-advanced.ipynb ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "8acae3ed-2953-45a3-aba9-0327b6ae3679",
6
+ "metadata": {},
7
+ "source": [
8
+ "### ChromaDB method - create vectorstore based on Chroma"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "7de9c591-5a77-4bbe-80f1-4897e15f0b97",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import chromadb\n",
19
+ "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
20
+ "from llama_index.vector_stores.chroma.base import ChromaVectorStore\n",
21
+ "from llama_index.core import StorageContext\n",
22
+ "from llama_index.core import ServiceContext\n",
23
+ "from llama_index.core import Document\n",
24
+ "\n",
25
+ "from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding\n",
26
+ "from llama_index.core import Settings\n",
27
+ "\n",
28
+ "import nest_asyncio\n",
29
+ "nest_asyncio.apply()\n",
30
+ "\n",
31
+ "import time"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "id": "3e65dff6-77b6-4be8-8857-5cecf3a035bb",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "# load some documents\n",
42
+ "documents = SimpleDirectoryReader(input_files=[\n",
43
+ " \"../raw_documents/qna.txt\",\n",
44
+ " \"../raw_documents/HI Chapter Summary Version 1.3.pdf\",\n",
45
+ " \"../raw_documents/conversation_examples.txt\",\n",
46
+ " \"../raw_documents/HI_Knowledge_Base.pdf\",\n",
47
+ " \"../raw_documents/answers.txt\",\n",
48
+ " ]).load_data()\n",
49
+ "document = Document(text=\"\\n\\n\".join([doc.text for doc in documents]))"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "bd86b3f5-1dfc-4257-bd9c-86d34f02398d",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "# initialize client, setting path to save data\n",
60
+ "db = chromadb.PersistentClient(path=\"../models/chroma_db_advanced\")"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "id": "f568ce7b-bcbf-455c-acf1-6c2cae129fed",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "# create collection\n",
71
+ "chroma_collection = db.get_or_create_collection(\"quickstart\")"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "id": "ed0b018e-1982-46b2-b1b4-04f5c0ce8672",
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "# assign chroma as the vector_store to the context\n",
82
+ "vector_store = ChromaVectorStore(chroma_collection=chroma_collection)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "id": "eb5edab2-30db-4bf7-96b5-4005d3161988",
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": []
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "0946b6ce-96ab-44de-ad75-e424a8429f67",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "Settings.llm = None\n",
101
+ "Settings.chunk_size = 1024\n",
102
+ "Settings.embed_model = \"local:../models/fine-tuned-embeddings-advanced\""
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "id": "b8c73a2c-1129-406a-8046-085afcaf9cbb",
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "nodes = Settings.node_parser.get_nodes_from_documents(documents)"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "75f1c76f-d3e5-4b69-818c-98865adb1457",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "len(nodes)"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "id": "adfe688f-95c0-477c-a9de-e9e77541a1d7",
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": []
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "id": "dab4c6f3-ef67-4d90-b3d5-e290c5d1b6f4",
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "storage_context = StorageContext.from_defaults(vector_store=vector_store)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "6a764113-ad7e-4674-aa57-ebbf405902a8",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "storage_context.docstore.add_documents(nodes)"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "id": "38e7c88d-6c45-4275-8293-d09b4b85a7cf",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": []
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "e492ed4a-23a3-47d6-8b50-51fb48b3aa05",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "start_time = time.time()"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "id": "cbd11b89-9b83-4f08-bb30-160f750f2ffb",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "vector_index = VectorStoreIndex(nodes, storage_context=storage_context)"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "id": "082a0d7e-b025-4db1-be2a-7a0b7bc453b9",
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "vector_query_engine = vector_index.as_query_engine()"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "d3bd848d-9985-4a3d-bdc4-ec340cc69ef3",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "indexing_cost = time.time() - start_time\n",
199
+ "indexing_cost = indexing_cost / 60\n",
200
+ "print(f\"Indexing time: {indexing_cost:.1f} mins\")"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "id": "3290e870-41d7-49c4-9c4f-cb16bd1f469e",
207
+ "metadata": {
208
+ "scrolled": true
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "response = vector_query_engine.query(\"Healthcare System in Singapore consists of?\")\n",
213
+ "response"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "id": "131d907a-0677-4ad8-b3f7-6fc9b9c5d0a5",
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": []
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "08fb2be5-3a44-4bb8-a9fc-61d7f03b7a35",
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": []
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "id": "a7fc01f6-4738-415b-a96b-afd6cf8d789a",
235
+ "metadata": {},
236
+ "source": [
237
+ "### ChromaDB method - load vectorstore based on Chroma"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": null,
243
+ "id": "c1a42c35-5f57-423c-8fb7-7d18b3b466b5",
244
+ "metadata": {},
245
+ "outputs": [],
246
+ "source": [
247
+ "import chromadb\n",
248
+ "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
249
+ "from llama_index.vector_stores.chroma.base import ChromaVectorStore\n",
250
+ "from llama_index.core import StorageContext\n",
251
+ "from llama_index.core import ServiceContext\n",
252
+ "from llama_index.core import Document\n",
253
+ "from llama_index.core import Settings\n",
254
+ "\n",
255
+ "from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding\n",
256
+ "from llama_index.llms.openai import OpenAI\n",
257
+ "from llama_index.core.memory import ChatMemoryBuffer\n",
258
+ "\n",
259
+ "import time"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "id": "72dd0ece-c72d-428a-89b4-9494d948c845",
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": []
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "id": "d38dc953-b923-4128-86a1-c8c6f69af0ed",
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": [
277
+ "fine_tuned_path = \"local:../models/fine-tuned-embeddings-advanced\""
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": null,
283
+ "id": "4c83c613-2cfc-4871-9d07-c82f77a3bd5e",
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "llm = OpenAI(model=\"gpt-4-0125-preview\", temperature=0.0)"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "id": "0583e9b0-d977-488c-8331-46dfa749924c",
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "Settings.llm = llm\n",
298
+ "Settings.embed_model = fine_tuned_path"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "id": "f994f440-f647-48b4-a517-46a79f7561e5",
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": []
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "2159a2b6-494b-41b9-ac54-dd342bfb74ba",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "db = chromadb.PersistentClient(path=\"../models/chroma_db_advanced\")"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": null,
322
+ "id": "1b385644-b46e-4d13-88fa-9f4af39db405",
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "chroma_collection = db.get_or_create_collection(\"quickstart\")"
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": null,
332
+ "id": "93cb53d1-6b8c-4b2d-a839-53501c0d54b2",
333
+ "metadata": {},
334
+ "outputs": [],
335
+ "source": [
336
+ "# assign chroma as the vector_store to the context\n",
337
+ "vector_store = ChromaVectorStore(chroma_collection=chroma_collection)\n",
338
+ "storage_context = StorageContext.from_defaults(vector_store=vector_store)"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": null,
344
+ "id": "c40d59e1-6d42-41f0-8c9b-70aa026093ae",
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "# create your index\n",
349
+ "index = VectorStoreIndex.from_vector_store(\n",
350
+ " vector_store=vector_store,\n",
351
+ " storage_context=storage_context\n",
352
+ ")"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "id": "73ba6d06-ba69-4b5e-962a-9cf7d2dc4d94",
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": []
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "id": "1a506940-c2b4-4d14-ad93-fd451331c582",
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "system_content = (\"You are a helpful study assistant. \"\n",
371
+ " \"You do not respond as 'User' or pretend to be 'User'. \"\n",
372
+ " \"You only respond once as 'Assistant'.\"\n",
373
+ ")"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": null,
379
+ "id": "3f592848-8536-4b4d-b34a-adc32d043432",
380
+ "metadata": {},
381
+ "outputs": [],
382
+ "source": [
383
+ "memory = ChatMemoryBuffer.from_defaults(token_limit=100_000)"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "execution_count": null,
389
+ "id": "6c7df81a-fd2f-42bf-b09c-46d7750f7252",
390
+ "metadata": {},
391
+ "outputs": [],
392
+ "source": [
393
+ "chat_engine = index.as_chat_engine(\n",
394
+ " chat_mode=\"context\",\n",
395
+ " memory=memory,\n",
396
+ " system_prompt=system_content\n",
397
+ ")"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "id": "434f0caf-8b1f-40c6-b9ec-b039cd1ca612",
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "prompt = \"\"\"\n",
408
+ "Question: Which of the following is NOT a characteristic of medical expense insurance?\n",
409
+ "A. Pro ration factor and co-insurance.\n",
410
+ "B. Deductibles apply for all treatments.\n",
411
+ "C. Impose Sub- Limits.\n",
412
+ "D. Can be issued as a rider or stand-alone.\n",
413
+ "\"\"\""
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": null,
419
+ "id": "78abaf95-e52d-445c-9d8e-bc51efb20f06",
420
+ "metadata": {},
421
+ "outputs": [],
422
+ "source": [
423
+ "res = chat_engine.chat(prompt)\n",
424
+ "print(res.response)"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "id": "1e62303c-3a00-448f-ad93-15cb6cee1f24",
431
+ "metadata": {},
432
+ "outputs": [],
433
+ "source": []
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "id": "dad72f9f-7f86-407d-93be-f5724cb30d5c",
439
+ "metadata": {},
440
+ "outputs": [],
441
+ "source": [
442
+ "hi_engine = index.as_query_engine(\n",
443
+ " memory=memory,\n",
444
+ " system_prompt=system_content,\n",
445
+ " similarity_top_k=3,\n",
446
+ " streaming=True\n",
447
+ ")"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "id": "ab778a5d-d438-4f39-88f5-c67a1f1d575e",
454
+ "metadata": {},
455
+ "outputs": [],
456
+ "source": []
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": null,
461
+ "id": "7bb7c21a-7461-40c1-87a7-4a1f92f70153",
462
+ "metadata": {},
463
+ "outputs": [],
464
+ "source": [
465
+ "res = hi_engine.query(\"may I know what is the rationale?\")\n",
466
+ "print(res)"
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "code",
471
+ "execution_count": null,
472
+ "id": "874a39ce-e682-42fa-8085-646bacea6cdb",
473
+ "metadata": {},
474
+ "outputs": [],
475
+ "source": []
476
+ },
477
+ {
478
+ "cell_type": "code",
479
+ "execution_count": null,
480
+ "id": "301e8270-783d-4942-a05f-9683ca96fbda",
481
+ "metadata": {},
482
+ "outputs": [],
483
+ "source": []
484
+ }
485
+ ],
486
+ "metadata": {
487
+ "kernelspec": {
488
+ "display_name": "Python 3 (ipykernel)",
489
+ "language": "python",
490
+ "name": "python3"
491
+ },
492
+ "language_info": {
493
+ "codemirror_mode": {
494
+ "name": "ipython",
495
+ "version": 3
496
+ },
497
+ "file_extension": ".py",
498
+ "mimetype": "text/x-python",
499
+ "name": "python",
500
+ "nbconvert_exporter": "python",
501
+ "pygments_lexer": "ipython3",
502
+ "version": "3.9.18"
503
+ }
504
+ },
505
+ "nbformat": 4,
506
+ "nbformat_minor": 5
507
+ }
notebooks/002_persisted-embedding-model.ipynb CHANGED
@@ -271,7 +271,7 @@
271
  "metadata": {},
272
  "outputs": [],
273
  "source": [
274
- "llm = OpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0.0)"
275
  ]
276
  },
277
  {
@@ -391,7 +391,23 @@
391
  "metadata": {},
392
  "outputs": [],
393
  "source": [
394
- "res = chat_engine.chat(\"what is the healthcare philosophy in singapore\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  "print(res.response)"
396
  ]
397
  },
@@ -413,7 +429,7 @@
413
  "hi_engine = index.as_query_engine(\n",
414
  " memory=memory,\n",
415
  " system_prompt=system_content,\n",
416
- " similarity_top_k=3,\n",
417
  " streaming=True\n",
418
  ")"
419
  ]
@@ -433,7 +449,7 @@
433
  "metadata": {},
434
  "outputs": [],
435
  "source": [
436
- "res = hi_engine.query(\"What is llama2?\")\n",
437
  "print(res)"
438
  ]
439
  },
 
271
  "metadata": {},
272
  "outputs": [],
273
  "source": [
274
+ "llm = OpenAI(model=\"gpt-4-0125-preview\", temperature=0.0)"
275
  ]
276
  },
277
  {
 
391
  "metadata": {},
392
  "outputs": [],
393
  "source": [
394
+ "prompt = \"\"\"\n",
395
+ "Question: Which of the following is NOT a characteristic of medical expense insurance?\n",
396
+ "A. Pro ration factor and co-insurance.\n",
397
+ "B. Deductibles apply for all treatments.\n",
398
+ "C. Impose Sub- Limits.\n",
399
+ "D. Can be issued as a rider or stand-alone.\n",
400
+ "\"\"\""
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "id": "9563515b-8a95-4dc8-a312-f57f9b59da86",
407
+ "metadata": {},
408
+ "outputs": [],
409
+ "source": [
410
+ "res = chat_engine.chat(prompt)\n",
411
  "print(res.response)"
412
  ]
413
  },
 
429
  "hi_engine = index.as_query_engine(\n",
430
  " memory=memory,\n",
431
  " system_prompt=system_content,\n",
432
+ " similarity_top_k=10,\n",
433
  " streaming=True\n",
434
  ")"
435
  ]
 
449
  "metadata": {},
450
  "outputs": [],
451
  "source": [
452
+ "res = hi_engine.query(prompt)\n",
453
  "print(res)"
454
  ]
455
  },
raw_documents/answers.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7d01aaa6a0000c46cf93b1572ad15464480260dbc8fa8dc718f4718a3ba7598
3
+ size 41317
raw_documents/conversation_examples.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd354c1b6691627a6598f124f76ef43d29a1c7108124d8d833180b8efbd207a4
3
+ size 47902
raw_documents/qna.txt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:da72ca56312ecb78d7cf6c9288b16a520baa2286136b4677cf09f36ee4f07b36
3
- size 56792
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62f7746092d2d52d8028fb13471427e220aae0ab411771eda56883e9bfdc75ce
3
+ size 75976
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  aiohttp==3.9.1
2
  aiosignal==1.3.1
3
  alembic==1.13.1
@@ -28,6 +29,7 @@ charset-normalizer==3.3.2
28
  chroma-hnswlib==0.7.3
29
  chromadb==0.4.22
30
  click==8.1.7
 
31
  coloredlogs==15.0.1
32
  comm==0.2.0
33
  contourpy==1.2.0
@@ -45,6 +47,7 @@ exceptiongroup==1.2.0
45
  executing==2.0.1
46
  Faker==22.0.0
47
  fastapi==0.109.0
 
48
  fastjsonschema==2.19.1
49
  favicon==0.7.0
50
  filelock==3.13.1
@@ -58,6 +61,7 @@ gitdb==4.0.11
58
  GitPython==3.1.40
59
  google-auth==2.27.0
60
  googleapis-common-protos==1.62.0
 
61
  greenlet==3.0.3
62
  grpcio==1.60.0
63
  h11==0.14.0
@@ -101,19 +105,28 @@ langchain==0.0.354
101
  langchain-community==0.0.8
102
  langchain-core==0.1.23
103
  langsmith==0.0.87
104
- llama-index==0.10.1
105
- llama-index-agent-openai==0.1.1
106
- llama-index-core==0.10.1
 
 
107
  llama-index-embeddings-huggingface==0.1.1
108
- llama-index-embeddings-openai==0.1.1
 
 
109
  llama-index-legacy==0.9.48
110
- llama-index-llms-openai==0.1.1
111
- llama-index-multi-modal-llms-openai==0.1.1
 
112
  llama-index-packs-auto-merging-retriever==0.1.2
113
- llama-index-program-openai==0.1.1
114
- llama-index-question-gen-openai==0.1.1
115
- llama-index-readers-file==0.1.2
 
 
116
  llama-index-vector-stores-chroma==0.1.1
 
 
117
  lxml==5.1.0
118
  Mako==1.3.0
119
  Markdown==3.5.1
@@ -176,7 +189,7 @@ pyarrow==14.0.2
176
  pyasn1==0.5.1
177
  pyasn1-modules==0.3.0
178
  pycparser==2.21
179
- pydantic==2.5.3
180
  pydantic_core==2.14.6
181
  pydeck==0.8.1b0
182
  Pygments==2.17.2
@@ -268,4 +281,4 @@ websockets==12.0
268
  widgetsnbextension==4.0.9
269
  wrapt==1.16.0
270
  yarl==1.9.4
271
- zipp==3.17.0
 
1
+ aenum==3.1.15
2
  aiohttp==3.9.1
3
  aiosignal==1.3.1
4
  alembic==1.13.1
 
29
  chroma-hnswlib==0.7.3
30
  chromadb==0.4.22
31
  click==8.1.7
32
+ cohere==4.49
33
  coloredlogs==15.0.1
34
  comm==0.2.0
35
  contourpy==1.2.0
 
47
  executing==2.0.1
48
  Faker==22.0.0
49
  fastapi==0.109.0
50
+ fastavro==1.9.1
51
  fastjsonschema==2.19.1
52
  favicon==0.7.0
53
  filelock==3.13.1
 
61
  GitPython==3.1.40
62
  google-auth==2.27.0
63
  googleapis-common-protos==1.62.0
64
+ gradientai==1.7.0
65
  greenlet==3.0.3
66
  grpcio==1.60.0
67
  h11==0.14.0
 
105
  langchain-community==0.0.8
106
  langchain-core==0.1.23
107
  langsmith==0.0.87
108
+ llama-index==0.10.12
109
+ llama-index-agent-openai==0.1.5
110
+ llama-index-cli==0.1.5
111
+ llama-index-core==0.10.12
112
+ llama-index-embeddings-adapter==0.1.3
113
  llama-index-embeddings-huggingface==0.1.1
114
+ llama-index-embeddings-openai==0.1.6
115
+ llama-index-finetuning==0.1.4
116
+ llama-index-indices-managed-llama-cloud==0.1.3
117
  llama-index-legacy==0.9.48
118
+ llama-index-llms-gradient==0.1.2
119
+ llama-index-llms-openai==0.1.6
120
+ llama-index-multi-modal-llms-openai==0.1.4
121
  llama-index-packs-auto-merging-retriever==0.1.2
122
+ llama-index-postprocessor-cohere-rerank==0.1.2
123
+ llama-index-program-openai==0.1.4
124
+ llama-index-question-gen-openai==0.1.3
125
+ llama-index-readers-file==0.1.5
126
+ llama-index-readers-llama-parse==0.1.3
127
  llama-index-vector-stores-chroma==0.1.1
128
+ llama-parse==0.3.4
129
+ llamaindex-py-client==0.1.13
130
  lxml==5.1.0
131
  Mako==1.3.0
132
  Markdown==3.5.1
 
189
  pyasn1==0.5.1
190
  pyasn1-modules==0.3.0
191
  pycparser==2.21
192
+ pydantic==1.10.14
193
  pydantic_core==2.14.6
194
  pydeck==0.8.1b0
195
  Pygments==2.17.2
 
281
  widgetsnbextension==4.0.9
282
  wrapt==1.16.0
283
  yarl==1.9.4
284
+ zipp==3.17.0
streamlit_app.py CHANGED
@@ -7,6 +7,7 @@ import base64
7
  from io import BytesIO
8
  import sqlite3
9
  import uuid
 
10
 
11
  import chromadb
12
  from llama_index.core import (
@@ -39,14 +40,14 @@ nest_asyncio.apply()
39
  st.set_page_config(page_title="πŸ»πŸ“š Study Bear 🍯")
40
  openai_api = os.getenv("OPENAI_API_KEY")
41
 
42
- # "./raw_documents/HI_Knowledge_Base.pdf"
43
- image_prompt = False
44
- input_files = ["./raw_documents/HI Chapter Summary Version 1.3.pdf",
45
- "./raw_documents/qna.txt"]
46
- embedding_model = "BAAI/bge-small-en-v1.5"
47
- persisted_vector_db = "./models/chroma_db"
48
- fine_tuned_path = "local:models/fine-tuned-embeddings"
49
- questionaire_db_path = "./database/mock_qna.sqlite"
50
 
51
  data_df = pd.DataFrame(
52
  {
@@ -109,6 +110,9 @@ if "init" not in st.session_state.keys():
109
  st.session_state.init = {"warm_started": "No"}
110
  st.session_state.feedback = False
111
 
 
 
 
112
  # Store LLM generated responses
113
  if "messages" not in st.session_state.keys():
114
  st.session_state.messages = [{"role": "assistant",
@@ -341,19 +345,19 @@ if prompt := st.chat_input(disabled=not openai_api):
341
  # Retrieve text prompt from image submission
342
  if prompt is None and \
343
  st.session_state.messages[-1]["role"] == "admin":
344
- image_prompt = True
345
  prompt = st.session_state.messages[-1]["content"]
346
 
347
  # Generate a new response if last message is not from assistant
348
  if st.session_state.messages[-1]["role"] != "assistant":
349
  with st.chat_message("assistant", avatar=bear_img_path):
350
  with st.spinner("πŸ§ΈπŸ’€ Thinking... πŸ»πŸ’­"):
351
- if image_prompt:
352
  response = generate_llm_response(
353
  prompt,
354
  tool_choice="health_insurance_textbook_query_engine"
355
  )
356
- image_prompt = False
357
  else:
358
  response = generate_llm_response(prompt, tool_choice="auto")
359
  placeholder = st.empty()
 
7
  from io import BytesIO
8
  import sqlite3
9
  import uuid
10
+ import yaml
11
 
12
  import chromadb
13
  from llama_index.core import (
 
40
  st.set_page_config(page_title="πŸ»πŸ“š Study Bear 🍯")
41
  openai_api = os.getenv("OPENAI_API_KEY")
42
 
43
+ with open("./config/model_config.yml", "r") as file_reader:
44
+ model_config = yaml.safe_load(file_reader)
45
+
46
+ input_files = model_config["input_data"]["source"]
47
+ embedding_model = model_config["embeddings"]["embedding_base_model"]
48
+ fine_tuned_path = model_config["embeddings"]["fine_tuned_embedding_model"]
49
+ persisted_vector_db = model_config["vector_store"]["persisted_path"]
50
+ questionaire_db_path = model_config["questionaire_data"]["db_path"]
51
 
52
  data_df = pd.DataFrame(
53
  {
 
110
  st.session_state.init = {"warm_started": "No"}
111
  st.session_state.feedback = False
112
 
113
+ if "image_prompt" not in st.session_state.keys():
114
+ st.session_state.image_prompt = False
115
+
116
  # Store LLM generated responses
117
  if "messages" not in st.session_state.keys():
118
  st.session_state.messages = [{"role": "assistant",
 
345
  # Retrieve text prompt from image submission
346
  if prompt is None and \
347
  st.session_state.messages[-1]["role"] == "admin":
348
+ st.session_state.image_prompt = True
349
  prompt = st.session_state.messages[-1]["content"]
350
 
351
  # Generate a new response if last message is not from assistant
352
  if st.session_state.messages[-1]["role"] != "assistant":
353
  with st.chat_message("assistant", avatar=bear_img_path):
354
  with st.spinner("πŸ§ΈπŸ’€ Thinking... πŸ»πŸ’­"):
355
+ if st.session_state.image_prompt:
356
  response = generate_llm_response(
357
  prompt,
358
  tool_choice="health_insurance_textbook_query_engine"
359
  )
360
+ st.session_state.image_prompt = False
361
  else:
362
  response = generate_llm_response(prompt, tool_choice="auto")
363
  placeholder = st.empty()
vision_api.py CHANGED
@@ -9,6 +9,14 @@ def get_transcribed_text(base64_image):
9
  "Content-Type": "application/json",
10
  "Authorization": f"Bearer {OPENAI_API_KEY}"
11
  }
 
 
 
 
 
 
 
 
12
 
13
  payload = {
14
  "model": "gpt-4-vision-preview",
@@ -18,7 +26,7 @@ def get_transcribed_text(base64_image):
18
  "content": [
19
  {
20
  "type": "text",
21
- "text": "transcribe the image into text for me."
22
  },
23
  {
24
  "type": "image_url",
 
9
  "Content-Type": "application/json",
10
  "Authorization": f"Bearer {OPENAI_API_KEY}"
11
  }
12
+ image_prompt = (
13
+ "Understand and interpret the image properly, there could be "
14
+ "handwritten notes or scribbles beside the electronic text. "
15
+ "Once you have sufficient understanding of the image, "
16
+ "transcribed them into text. If the content is a question, "
17
+ "convert the question into text."
18
+ )
19
+ print(image_prompt)
20
 
21
  payload = {
22
  "model": "gpt-4-vision-preview",
 
26
  "content": [
27
  {
28
  "type": "text",
29
+ "text": image_prompt
30
  },
31
  {
32
  "type": "image_url",