JPBianchi committed on
Commit 30ffb9e
1 Parent(s): a0cb228

temp before HF pull

.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "name": "Python 3",
3
+ "image": "mcr.microsoft.com/devcontainers/python:3.10",
4
+
5
+ // Features to add to the dev container. More info: https://containers.dev/features.
6
+ //"features": {}
7
+ // Configure tool-specific properties.
8
+ "customizations": {
9
+ // Configure properties specific to VS Code.
10
+ "vscode": {
11
+ "settings": {"terminal.integrated.shell.linux": "/bin/bash"},
12
+ "extensions": [
13
+ "ms-toolsai.jupyter"
14
+ ]
15
+ }
16
+ },
17
+ "forwardPorts": [8501, 8888],
18
+ "portsAttributes": {
19
+ "8501": {
20
+ "label": "Streamlit App",
21
+ "onAutoForward": "openBrowser"
22
+ },
23
+ "8888": {
24
+ "label": "Jupyter Notebook",
25
+ "onAutoForward": "openBrowser"
26
+ }
27
+ },
28
+ "postCreateCommand": "pip install -r requirements.txt"
29
+ }
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/impact_theory_data.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,179 @@
1
+ # Large Files
2
+ # models/
3
+ eval_results/
4
+ # models/all-mpnet*
5
+ # models/finetuned-all-MiniLM*
6
+ # models/finetuned-WhereIsAI-UAE*
7
+ models/*
8
+ # !models/finetuned-all-mpnet-base-v2-300
9
+
10
+ data/*.parquet
11
+ .DS_Store
12
+ secrets.toml
13
+ TODO.md
14
+
15
+ assets/*
16
+ !assets/it_tom_bilyeu.png
17
+
18
+ # Byte-compiled / optimized / DLL files
19
+ __pycache__/
20
+ *.py[cod]
21
+ *$py.class
22
+ *copy*
23
+ # C extensions
24
+ *.so
25
+
26
+
27
+
28
+ # Distribution / packaging
29
+ .Python
30
+ build/
31
+ develop-eggs/
32
+ dist/
33
+ downloads/
34
+ eggs/
35
+ .eggs/
36
+ lib/
37
+ lib64/
38
+ parts/
39
+ sdist/
40
+ var/
41
+ wheels/
42
+ share/python-wheels/
43
+ *.egg-info/
44
+ .installed.cfg
45
+ *.egg
46
+ MANIFEST
47
+
48
+ # PyInstaller
49
+ # Usually these files are written by a python script from a template
50
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
51
+ *.manifest
52
+ *.spec
53
+
54
+ # Installer logs
55
+ pip-log.txt
56
+ pip-delete-this-directory.txt
57
+
58
+ # Unit test / coverage reports
59
+ htmlcov/
60
+ .tox/
61
+ .nox/
62
+ .coverage
63
+ .coverage.*
64
+ .cache
65
+ nosetests.xml
66
+ coverage.xml
67
+ *.cover
68
+ *.py,cover
69
+ .hypothesis/
70
+ .pytest_cache/
71
+ cover/
72
+
73
+ # Translations
74
+ *.mo
75
+ *.pot
76
+
77
+ # Django stuff:
78
+ *.log
79
+ local_settings.py
80
+ db.sqlite3
81
+ db.sqlite3-journal
82
+
83
+ # Flask stuff:
84
+ instance/
85
+ .webassets-cache
86
+
87
+ # Scrapy stuff:
88
+ .scrapy
89
+
90
+ # Sphinx documentation
91
+ docs/_build/
92
+
93
+ # PyBuilder
94
+ .pybuilder/
95
+ target/
96
+
97
+ # Jupyter Notebook
98
+ .ipynb_checkpoints
99
+
100
+ # IPython
101
+ profile_default/
102
+ ipython_config.py
103
+
104
+ # pyenv
105
+ # For a library or package, you might want to ignore these files since the code is
106
+ # intended to run in multiple environments; otherwise, check them in:
107
+ # .python-version
108
+
109
+ # pipenv
110
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
111
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
112
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
113
+ # install all needed dependencies.
114
+ #Pipfile.lock
115
+
116
+ # poetry
117
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
118
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
119
+ # commonly ignored for libraries.
120
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
121
+ #poetry.lock
122
+
123
+ # pdm
124
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
125
+ #pdm.lock
126
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
127
+ # in version control.
128
+ # https://pdm.fming.dev/#use-with-ide
129
+ .pdm.toml
130
+
131
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
132
+ __pypackages__/
133
+
134
+ # Celery stuff
135
+ celerybeat-schedule
136
+ celerybeat.pid
137
+
138
+ # SageMath parsed files
139
+ *.sage.py
140
+
141
+ # Environments
142
+ .env
143
+ .venv
144
+ env/
145
+ venv/
146
+ ENV/
147
+ env.bak/
148
+ venv.bak/
149
+
150
+ # Spyder project settings
151
+ .spyderproject
152
+ .spyproject
153
+
154
+ # Rope project settings
155
+ .ropeproject
156
+
157
+ # mkdocs documentation
158
+ /site
159
+
160
+ # mypy
161
+ .mypy_cache/
162
+ .dmypy.json
163
+ dmypy.json
164
+
165
+ # Pyre type checker
166
+ .pyre/
167
+
168
+ # pytype static type analyzer
169
+ .pytype/
170
+
171
+ # Cython debug symbols
172
+ cython_debug/
173
+
174
+ # PyCharm
175
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
176
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
177
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
178
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
179
+ #.idea/
.streamlit/config.toml ADDED
@@ -0,0 +1,11 @@
1
+ [theme]
2
+ base="dark"
3
+ primaryColor="purple" # border of textboxes !!??
4
+ #primaryColor="#2d59b3"
5
+
6
+ backgroundColor="#000000"
7
+ secondaryBackgroundColor= "#0e404d" # should be identical to blue in banner # "#2d59b3" light blue
8
+ textColor="#FFFFFF"
9
+ font="sans serif"
10
+
11
+
.vscode/launch.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "Python: Current File",
6
+ "type": "python",
7
+ "request": "launch",
8
+ "program": "${file}",
9
+ "console": "integratedTerminal",
10
+ "justMyCode": false
11
+ }
12
+ ]
13
+ }
README.md CHANGED
@@ -10,3 +10,31 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+ ---
14
+
15
+ See the app @ [jpb-vectorsearch.streamlit.app](https://jpb-vectorsearch.streamlit.app/)
16
+
17
+ Note that the online app sometimes crashes, especially when computing the metrics.
18
+
19
+ <p align="left">
20
+ <img src="assets/screenshot_frontpage_with_finetune.png" width=800/>
21
+ </p>
22
+
23
+ <!-- <p align="center">
24
+ <img src="assets/screenshot_frontpage_online.png"/>
25
+ </p> -->
26
+
27
+ ## Activity on Modal backend during finetuning
28
+
29
+ <p align="left">
30
+ <img src="assets/modal_finetuning1.png" width=800/>
31
+ </p>
32
+
33
+ <p align="left">
34
+ <img src="assets/modal_finetuning2.png" width=800/>
35
+ </p>
36
+
37
+ <p align="left">
38
+ <img src="assets/modal_finetuning_activity.png" width=800/>
39
+ </p>
40
+
app.py ADDED
@@ -0,0 +1,755 @@
1
+ #%%
2
+ from tiktoken import get_encoding, encoding_for_model
3
+ from weaviate_interface import WeaviateClient, WhereFilter
4
+ from sentence_transformers import SentenceTransformer
5
+ from prompt_templates import question_answering_prompt_series, question_answering_system
6
+ from openai_interface import GPT_Turbo
7
+ from app_features import (convert_seconds, generate_prompt_series, search_result,
8
+ validate_token_threshold, load_content_cache, load_data,
9
+ expand_content)
10
+ from retrieval_evaluation import execute_evaluation, calc_hit_rate_scores
11
+ from llama_index.finetuning import EmbeddingQAFinetuneDataset
12
+ from weaviate_interface import WeaviateClient
13
+ from openai import BadRequestError
14
+ from reranker import ReRanker
15
+ from loguru import logger
16
+ import streamlit as st
17
+ from streamlit_option_menu import option_menu
18
+ import hydralit_components as hc
19
+ import sys
20
+ import json
21
+ import os, time, requests, re
22
+ from datetime import timedelta
23
+ import pathlib
24
+ import gdown
25
+ import tempfile
26
+ import base64
27
+ import shutil
28
+
29
+ def get_base64_of_bin_file(bin_file):
30
+ with open(bin_file, 'rb') as file:
31
+ data = file.read()
32
+ return base64.b64encode(data).decode()
33
+
34
+ from dotenv import load_dotenv, find_dotenv
35
+ load_dotenv(find_dotenv('env'), override=True)
36
+
37
+ # I use a key that I increment each time I want to change a text_input
38
+ if 'key' not in st.session_state:
39
+ st.session_state.key = 0
40
+ # key = st.session_state['key']
41
+
42
+ if not pathlib.Path('models').exists():
43
+ os.mkdir('models')
44
+
45
+ # I should cache these things but no time left
46
+
47
+ # I put a file local.txt in my desktop models folder to find out if it's running online
48
+ we_are_online = not pathlib.Path("models/local.txt").exists()
49
+ we_are_not_online = not we_are_online
50
+
51
+ golden_dataset = EmbeddingQAFinetuneDataset.from_json("data/golden_100.json")
52
+
53
+ # shutil.rmtree("models/models") # remove it - I wanted to clear the space on streamlit online
54
+
55
+ ## PAGE CONFIGURATION
56
+ st.set_page_config(page_title="Ask Impact Theory",
57
+ page_icon="assets/impact-theory-logo-only.png",
58
+ layout="wide",
59
+ initial_sidebar_state="collapsed",
60
+ menu_items={'Report a bug': "https://www.extremelycoolapp.com/bug"})
61
+
62
+
63
+ image = "https://is2-ssl.mzstatic.com/image/thumb/Music122/v4/bd/34/82/bd348260-314c-5898-26c0-bef2e0388ebe/source/1200x1200bb.png"
64
+
65
+
66
+ def add_bg_from_local(image_file):
67
+ bin_str = get_base64_of_bin_file(image_file)
68
+ page_bg_img = f'''
69
+ <style>
70
+ .stApp {{
71
+ background-image: url("data:image/png;base64,{bin_str}");
72
+ background-size: 100% auto;
73
+ background-repeat: no-repeat;
74
+ background-attachment: fixed;
75
+ }}
76
+ </style>
77
+ '''
78
+
79
+ st.markdown(page_bg_img, unsafe_allow_html=True)
80
+
81
+ # COMMENT: I tried to create a dropdown menu but it's harder than it looks, so I gave up
82
+ # https://discuss.streamlit.io/t/streamlit-option-menu-is-a-simple-streamlit-component-that-allows-users-to-select-a-single-item-from-a-list-of-options-in-a-menu/20514
83
+ # not great, but it works
84
+ # selected = option_menu("About", ["Improvements","This"], #"Main Menu", ["Home", 'Settings'],
85
+ # icons=['house', 'gear'],
86
+ # menu_icon="cast",
87
+ # default_index=1)
88
+
89
+ # # Custom HTML/CSS for the banner
90
+ # base64_img = get_base64_of_bin_file("assets/it_tom_bilyeu.png")
91
+ # banner_menu_html = f"""
92
+ # <div class="banner">
93
+ # <img src= "data:image/png;base64,{base64_img}" alt="Banner Image">
94
+ # </div>
95
+ # <style>
96
+ # .banner {{
97
+ # width: 100%;
98
+ # height: auto;
99
+ # overflow: hidden;
100
+ # display: flex;
101
+ # justify-content: center;
102
+ # }}
103
+ # .banner img {{
104
+ # width: 130%;
105
+ # height: auto;
106
+ # object-fit: contain;
107
+ # }}
108
+ # </style>
109
+ # """
110
+ # st.components.v1.html(banner_menu_html)
111
+
112
+
113
+ # specify the primary menu definition
114
+ # it gives a vertical menu inside a navigation bar !!!
115
+ # menu_data = [
116
+ # {'icon': "far fa-copy", 'label':"Left End"},
117
+ # {'id':'Copy','icon':"🐙",'label':"Copy"},
118
+ # {'icon': "far fa-chart-bar", 'label':"Chart"},#no tooltip message
119
+ # {'icon': "far fa-address-book", 'label':"Book"},
120
+ # {'id':' Crazy return value 💀','icon': "💀", 'label':"Calendar"},
121
+ # {'icon': "far fa-clone", 'label':"Component"},
122
+ # {'icon': "fas fa-tachometer-alt", 'label':"Dashboard",'ttip':"I'm the Dashboard tooltip!"}, #can add a tooltip message
123
+ # {'icon': "far fa-copy", 'label':"Right End"},
124
+ # ]
125
+ # # we can override any part of the primary colors of the menu
126
+ # over_theme = {'txc_inactive': '#FFFFFF','menu_background':'red','txc_active':'yellow','option_active':'blue'}
127
+ # # over_theme = {'txc_inactive': '#FFFFFF'}
128
+ # menu_id = hc.nav_bar(menu_definition=menu_data,
129
+ # home_name='Home',
130
+ # override_theme=over_theme)
131
+ #get the id of the menu item clicked
132
+ # st.info(f"{menu_id=}")
133
+ ## RERANKER
134
+ reranker = ReRanker('cross-encoder/ms-marco-MiniLM-L-6-v2')
135
+ ## ENCODING --> tiktoken library
136
+ model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
137
+ model_nameGPT = model_ids[1]
138
+ encoding = encoding_for_model(model_nameGPT)
139
+ # = get_encoding('gpt-3.5-turbo-0613')
140
+ ##############
141
+ data_path = './data/impact_theory_data.json'
142
+ cache_path = 'data/impact_theory_cache.parquet'
143
+ data = load_data(data_path)
144
+ cache = None # load_content_cache(cache_path)
145
+
146
+ try:
147
+ # st.write("Loading secrets from secrets.toml")
148
+ Wapi_key = st.secrets['secrets']['WEAVIATE_API_KEY']
149
+ url = st.secrets['secrets']['WEAVIATE_ENDPOINT']
150
+ openai_api_key = st.secrets['secrets']['OPENAI_API_KEY']
151
+
152
+ hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN_chris']
153
+ hf_endpoint = st.secrets['secrets']['LLAMA2_ENDPOINT_UPLIMIT']
154
+ # st.write("Secrets loaded from secrets.toml")
155
+ # st.write("HF_TOKEN", hf_token)
156
+ except Exception:
157
+ st.write("Loading secrets from environment variables")
158
+ Wapi_key = os.environ['WEAVIATE_API_KEY']  # must match the name used when creating the Weaviate client below
159
+ url = os.environ['WEAVIATE_ENDPOINT']
160
+ openai_api_key = os.environ['OPENAI_API_KEY']
161
+
162
+ hf_token = os.environ['LLAMA2_ENDPOINT_HF_TOKEN_chris']
163
+ hf_endpoint = os.environ['LLAMA2_ENDPOINT_UPLIMIT']
164
+ #%%
165
+ # model_default = 'sentence-transformers/all-mpnet-base-v2'
166
+ model_default = 'models/finetuned-all-mpnet-base-v2-300' if we_are_not_online \
167
+ else 'sentence-transformers/all-mpnet-base-v2'
168
+
169
+ available_models = ['sentence-transformers/all-mpnet-base-v2',
170
+ 'sentence-transformers/all-MiniLM-L6-v2',
171
+ 'models/finetuned-all-mpnet-base-v2-300']
172
+
173
+ #%%
174
+ models_urls = {'models/finetuned-all-mpnet-base-v2-300': "https://drive.google.com/drive/folders/1asJ37-AUv5nytLtH6hp6_bVV3_cZOXfj"}
175
+
176
+ def download_model_from_Gdrive(model_name_or_path, model_full_path):
177
+ print("Downloading model from Google Drive")
178
+ st.write("Downloading model from Google Drive")
179
+ assert model_name_or_path in models_urls, f"Model {model_name_or_path} not found in models_urls"
180
+ url = models_urls[model_name_or_path]
181
+ gdown.download_folder(url, output=model_full_path, quiet=False, use_cookies=False)
182
+ print("Model downloaded and saved to models folder")
183
+ # st.write("Model downloaded")
184
+
185
+ def download_model(model_name_or_path, model_full_path):
186
+
187
+ if model_name_or_path.startswith("models/"):
188
+ download_model_from_Gdrive(model_name_or_path, model_full_path)
189
+ print(f"Model {model_full_path} downloaded")
190
+ models_urls[model_name_or_path] = model_full_path
191
+ # st.sidebar.write(f"Model {model_full_path} downloaded")
192
+
193
+ elif model_name_or_path.startswith("sentence-transformers/"):
194
+ st.sidebar.write(f"Downloading Sentence Transformer model {model_name_or_path}")
195
+ model = SentenceTransformer(model_name_or_path) # HF looks into its own models folder/path
196
+ models_urls[model_name_or_path] = model_full_path
197
+ # st.sidebar.write(f"Model {model_name_or_path} downloaded")
198
+ model.save(model_full_path)
199
+ # st.sidebar.write(f"Model {model_name_or_path} saved to {model_full_path}")
200
+
201
+ # if 'modelspath' not in st.session_state:
202
+ # st.session_state['modelspath'] = None
203
+ # if st.session_state.modelspath is None:
204
+ # # let's create a temp folder on the first run
205
+ # persistent_dir = pathlib.Path("path/to/persistent_dir")
206
+ # persistent_dir.mkdir(parents=True, exist_ok=True)
207
+ # with tempfile.TemporaryDirectory() as temp_dir:
208
+ # st.session_state.modelspath = temp_dir
209
+ # print(f"Temporary directory created at {temp_dir}")
210
+ # # the temp folder disappears with the context, but not the one we've created manually
211
+ # else:
212
+ # temp_dir = st.session_state.modelspath
213
+ # print(f"Temporary directory already exists at {temp_dir}")
214
+ # # st.write(os.listdir(temp_dir))
215
+
216
+ #%%
217
+ # for streamlit online, we must download the model from google drive
218
+ # because github LFS doesn't work on forked repos
219
+ def check_model(model_name_or_path):
220
+
221
+ model_path = pathlib.Path(model_name_or_path)
222
+ model_full_path = str(pathlib.Path("models") / model_path) # this creates a models folder inside /models
223
+ model_full_path = model_full_path.replace("sentence-transformers/", "models/") # all are saved in models folder
224
+
225
+ if pathlib.Path(model_full_path).exists():
226
+ # let's use the model that's already there
227
+ print(f"Model {model_full_path} already exists")
228
+
229
+
230
+ # but delete everything else in we are online because
231
+ # streamlit online has limited space (and will shut down the app if it's full)
232
+ if we_are_online:
233
+ # st.sidebar.write(f"Model {model_full_path} already exists")
234
+ # st.sidebar.write(f"Deleting other models")
235
+ dirs = os.listdir("models/models")
236
+ # we get only the folder name, not the full path
237
+ dirs.remove(model_full_path.split('/')[-1])
238
+ for p in dirs:
239
+ dirpath = pathlib.Path("models/models") / p
240
+ if dirpath.is_dir():
241
+ shutil.rmtree(dirpath)
242
+ else:
243
+
244
+ if we_are_online:
245
+ # space issues on streamlit online, let's not leave anything behind
246
+ # and re-download the model every time
247
+ print("Deleting models/models folder")
248
+ if pathlib.Path('models/models').exists():
249
+ shutil.rmtree("models/models") # make room, if other models are there
250
+ # st.sidebar.write(f"models/models folder deleted")
251
+
252
+ download_model(model_name_or_path, model_full_path)
253
+
254
+ return model_full_path
255
+
256
+ #%% instantiate Weaviate client
257
+ def get_weaviate_client(api_key, url, model_name_or_path, openai_api_key):
258
+ client = WeaviateClient(api_key, url,
259
+ model_name_or_path=model_name_or_path,
260
+ openai_api_key=openai_api_key)
261
+ client.display_properties.append('summary')
262
+ available_classes = sorted(client.show_classes())
263
+ # st.write(f"Available classes: {available_classes}")
264
+ # st.write(f"Available classes type: {type(available_classes)}")
265
+ logger.info(available_classes)
266
+ return client, available_classes
267
+
268
+
269
+ ##############
270
+ # data = load_data(data_path)
271
+ # guests list for sidebar
272
+ guest_list = sorted(list(set([d['guest'] for d in data])))
273
+
274
+ def main():
275
+
276
+ with st.sidebar:
277
+ # moved it to main area
278
+ # guest = st.selectbox('Select Guest',
279
+ # options=guest_list,
280
+ # index=None,
281
+ # placeholder='Select Guest')
282
+ _, center, _ = st.columns([3, 5, 3])
283
+ with center:
284
+ st.text("Search Lab")
285
+
286
+ _, center, _ = st.columns([2, 5, 3])
287
+ with center:
288
+ if we_are_online:
289
+ st.text("Running ONLINE")
290
+ st.text("(UNSTABLE)")
291
+ else:
292
+ st.text("Running OFFLINE")
293
+ st.write("----------")
294
+
295
+ alpha_input = st.slider(label='Alpha',min_value=0.00, max_value=1.00, value=0.40, step=0.05)
296
+ retrieval_limit = st.slider(label='Hybrid Search Results', min_value=10, max_value=300, value=10, step=10)
297
+
298
+ hybrid_filter = st.toggle('Filter Guest', True) # i.e. look only at guests' data
299
+
300
+ rerank = st.toggle('Use Reranker', True)
301
+ if rerank:
302
+ reranker_topk = st.slider(label='Reranker Top K',min_value=1, max_value=5, value=3, step=1)
303
+ else:
304
+ # needed to not fill the LLM with too many responses (> context size)
305
+ # we could make it dependent on the model
306
+ reranker_topk = 3
307
+
308
+ rag_it = st.toggle('RAG it', True)
309
+ if rag_it:
310
+ st.sidebar.write(f"Using LLM '{model_nameGPT}'")
311
+ llm_temperature = st.slider(label='LLM T˚', min_value=0.0, max_value=2.0, value=0.01, step=0.10 )
312
+
313
+ model_name_or_path = st.selectbox(label='Model Name:', options=available_models,
314
+ index=available_models.index(model_default),
315
+ placeholder='Select Model')
316
+
317
+ st.write("Experimental and time-limited (2 min)")
318
+ finetune_model = st.toggle('Finetune on Modal A100 GPU', False)
319
+ if finetune_model:
320
+ from finetune_backend import finetune
321
+ if 'finetuned' in model_name_or_path:
322
+ st.write("Model already finetuned")
323
+ elif model_name_or_path.startswith("models/"):
324
+ st.write("Sentence Transformers models only!")
325
+ else:
326
+ try:
327
+ if 'finetuned' in model_name_or_path:
328
+ st.write("Model already finetuned")
329
+ else:
330
+ model_path = finetune(model_name_or_path, savemodel=True, outpath='models')
331
+ if model_path is not None:
332
+ if model_name_or_path.split('/')[-1] not in model_path:
333
+ st.write(model_path) # a warning from finetuning in this case
334
+ elif model_path not in available_models:
335
+ # finetuning generated a model, let's add it
336
+ available_models.append(model_path)
337
+ st.write("Model saved!")
338
+ except Exception:
339
+ st.write("Model not found on HF or error")
340
+
341
+ model_name_or_path = check_model(model_name_or_path)
342
+ client, available_classes = get_weaviate_client(Wapi_key, url, model_name_or_path, openai_api_key)
343
+
344
+ start_class = 'Impact_theory_all_mpnet_base_v2_finetuned'
345
+
346
+ class_name = st.selectbox(
347
+ label='Class Name:',
348
+ options=available_classes,
349
+ index=available_classes.index(start_class),
350
+ placeholder='Select Class Name'
351
+ )
352
+
353
+ st.write("----------")
354
+
355
+ c1,c2 = st.columns([8,1])
356
+ with c1:
357
+ show_metrics = st.toggle('Show Metrics on Golden set', False)
358
+ if show_metrics:
359
+ # _, center, _ = st.columns([3, 5, 3])
360
+ # with center:
361
+ # st.text("Metrics")
362
+ with c2:
363
+ with st.spinner(''):
364
+ metrics = execute_evaluation(golden_dataset, class_name, client, alpha=alpha_input)
365
+ if show_metrics:
366
+ kw_hit_rate = metrics['kw_hit_rate']
367
+ kw_mrr = metrics['kw_mrr']
368
+ hybrid_hit_rate = metrics['hybrid_hit_rate']
369
+ vector_hit_rate = metrics['vector_hit_rate']
370
+ vector_mrr = metrics['vector_mrr']
371
+ total_misses = metrics['total_misses']
372
+
373
+ st.text(f"KW hit rate: {kw_hit_rate}")
374
+ st.text(f"Vector hit rate: {vector_hit_rate}")
375
+ st.text(f"Hybrid hit rate: {hybrid_hit_rate}")
376
+ st.text(f"Vector MRR: {vector_mrr}")
377
+ st.text(f"Total misses: {total_misses}")
378
+
379
+ st.write("----------")
380
+
381
+ st.title("Chat with the Impact Theory podcasts!")
382
+ # st.image('./assets/impact-theory-logo.png', width=400)
383
+ st.image('assets/it_tom_bilyeu.png', use_column_width=True)
384
+ # st.subheader(f"Chat with the Impact Theory podcast: ")
385
+ st.write('\n')
386
+ # st.stop()
387
+
388
+
389
+ st.write("\u21D0 Open the sidebar to change Search settings \n ") # https://home.unicode.org also 21E0, 21B0 B2 D0
390
+ guest = st.selectbox('Select A Guest',
391
+ options=guest_list,
392
+ index=None,
393
+ placeholder='Select Guest')
394
+
395
+
396
+ col1, col2 = st.columns([7,3])
397
+ with col1:
398
+ if guest is None:
399
+ msg = f'Select a guest before asking your question:'
400
+ else:
401
+ msg = f'Enter your question about {guest}:'
402
+
403
+ textbox = st.empty()
404
+ # best solution I found to be able to change the text inside a text_input box afterwards, using a key
405
+ query = textbox.text_input(msg,
406
+ value="",
407
+ placeholder="You can refer to the guest with pronoun or drop the question mark",
408
+ key=st.session_state.key)
409
+
410
+ # st.write(f"Guest = {guest}")
411
+ # st.write(f"key = {st.session_state.key}")
412
+
413
+ st.write('\n\n\n\n\n')
414
+
415
+ reworded_query = {'changed': False, 'status': 'error'} # at start, the query is empty
416
+ valid_response = [] # at start, the query is empty, so prevent the search
417
+
418
+ if query:
419
+
420
+ if guest is None:
421
+ st.session_state.key += 1
422
+ query = textbox.text_input(msg,
423
+ value="",
424
+ placeholder="YOU MUST SELECT A GUEST BEFORE ASKING A QUESTION",
425
+ key=st.session_state.key)
426
+ # st.write(f"key = {st.session_state.key}")
427
+ st.stop()
428
+ else:
429
+ # st.write(f'It looks like you selected {guest} as a filter (It is ignored for now).')
430
+
431
+ with col2:
432
+ # let's add a nice pulse bar while generating the response
433
+ with hc.HyLoader('', hc.Loaders.pulse_bars, primary_color= 'red', height=50): #"#0e404d" for image green
434
+ # with st.spinner('Generating Response...'):
435
+
436
+ with col1:
437
+
438
+ # let's use Llama2 here
439
+ reworded_query = reword_query(query, guest,
440
+ model_name='llama2-13b-chat')
441
+ query = reworded_query['rewritten_question']
442
+
443
+ # we can arrive here only if a guest was selected
444
+ where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
445
+ if hybrid_filter else None
446
+
447
+ hybrid_response = client.hybrid_search(query,
448
+ class_name,
449
+ # properties=['content'], #['title', 'summary', 'content'],
450
+ alpha=alpha_input,
451
+ display_properties=client.display_properties,
452
+ where_filter=where_filter,
453
+ limit=retrieval_limit)
454
+ response = hybrid_response
455
+
456
+ if rerank:
457
+ # rerank results with cross encoder
458
+ ranked_response = reranker.rerank(response, query,
459
+ apply_sigmoid=True, # score between 0 and 1
460
+ top_k=reranker_topk)
461
+ logger.info(ranked_response)
462
+ expanded_response = expand_content(ranked_response, cache,
463
+ content_key='doc_id',
464
+ create_new_list=True)
465
+
466
+ response = expanded_response
467
+
468
+ # make sure token count < threshold
469
+ token_threshold = 8000 if model_nameGPT == model_ids[0] else 3500
470
+ valid_response = validate_token_threshold(response,
471
+ question_answering_prompt_series,
472
+ query=query,
473
+ tokenizer= encoding,# variable from ENCODING,
474
+ token_threshold=token_threshold,
475
+ verbose=True)
476
+ # st.write(f"Number of results: {len(valid_response)}")
477
+
478
+
479
+ # I jump out of col1 to get all page width, so need to retest query
480
+ if query is not None and reworded_query['status'] != 'error':
481
+ show_query = st.toggle('Show rewritten query', False)
482
+ if show_query: # or reworded_query['changed']:
483
+ st.write(f"Rewritten query: {query}")
484
+
485
+ # creates container for LLM response to position it above search results
486
+ chat_container, response_box = [], st.empty()
487
+ # # RAG time !! execute chat call to LLM
488
+ if rag_it:
489
+ # st.subheader("Response from Impact Theory (context)")
490
+ # will appear under the answer, moved it into the response box
491
+
492
+ # generate LLM prompt
493
+ prompt = generate_prompt_series(query=query, results=valid_response)
494
+
495
+
496
+ GPTllm = GPT_Turbo(model=model_nameGPT,
497
+ api_key=st.secrets['secrets']['OPENAI_API_KEY'])
498
+ try:
499
+ # inserts chat stream from LLM
500
+ for resp in GPTllm.get_chat_completion(prompt=prompt,
501
+ temperature=llm_temperature,
502
+ max_tokens=350,
503
+ show_response=True,
504
+ stream=True):
505
+
506
+ with response_box:
507
+ content = resp.choices[0].delta.content
508
+ if content:
509
+ chat_container.append(content)
510
+ result = "".join(chat_container).strip()
511
+ response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
512
+ except BadRequestError as e:
513
+ logger.info('Making request with smaller context')
514
+
515
+ valid_response = validate_token_threshold(response,
516
+ question_answering_prompt_series,
517
+ query=query,
518
+ tokenizer=encoding,
519
+ token_threshold=3500,
520
+ verbose=True)
521
+ # if reranker is off, we may receive a LOT of responses
522
+ # so we must reduce the context size manually
523
+ if not rerank:
524
+ valid_response = valid_response[:reranker_topk]
525
+
526
+ prompt = generate_prompt_series(query=query, results=valid_response)
527
+ for resp in GPTllm.get_chat_completion(prompt=prompt,
528
+ temperature=llm_temperature,
529
+ max_tokens=350, # expand for more verbose answers
530
+ show_response=True,
531
+ stream=True):
532
+ try:
533
+ # inserts chat stream from LLM
534
+ with response_box:
535
+ content = resp.choices[0].delta.content
536
+ if content:
537
+ chat_container.append(content)
538
+ result = "".join(chat_container).strip()
539
+ response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
540
+ except Exception as e:
541
+ print(e)
542
+
543
+ st.markdown("----")
544
+ st.subheader("Search Results")
545
+
546
+ for i, hit in enumerate(valid_response):
547
+ col1, col2 = st.columns([7, 3], gap='large')
548
+ image = hit['thumbnail_url'] # get thumbnail_url
549
+ episode_url = hit['episode_url'] # get episode_url
550
+ title = hit["title"] # get title
551
+ show_length = hit["length"] # get length
552
+ time_string = str(timedelta(seconds=show_length)) # convert show_length to readable time string
553
+
554
+ with col1:
555
+ st.write(search_result(i=i,
556
+ url=episode_url,
557
+ guest=hit['guest'],
558
+ title=title,
559
+ content='',
560
+ length=time_string),
561
+ unsafe_allow_html=True)
562
+ st.write('\n\n')
563
+
564
+ with col2:
565
+ #st.write(f"<a href={episode_url} <img src={image} width='200'></a>",
566
+ # unsafe_allow_html=True)
567
+ #st.markdown(f"[![{title}]({image})]({episode_url})")
568
+ # st.markdown(f'<a href="{episode_url}">'
569
+ # f'<img src={image} '
570
+ # f'caption={title.split("|")[0]} width=200, use_column_width=False />'
571
+ # f'</a>',
572
+ # unsafe_allow_html=True)
573
+
574
+ st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
575
+ # let's use all width for the content
576
+ st.write(hit['content'])
577
+
578
+
579
+ def get_answer(query, valid_response, GPTllm):
580
+
581
+ # generate LLM prompt
582
+ prompt = generate_prompt_series(query=query,
583
+ results=valid_response)
584
+
585
+ return GPTllm.get_chat_completion(prompt=prompt,
586
+ system_message='answer this question based on the podcast material',
587
+ temperature=0,
588
+ max_tokens=500,
589
+ stream=False,
590
+ show_response=False)
591
+
592
+ def reword_query(query, guest, model_name='llama2-13b-chat', response_processing=True):
593
+ """ Asks LLM to rewrite the query when the guest name is missing.
594
+
595
+ Args:
596
+ query (str): user query
597
+ guest (str): guest name
598
+ model_name (str, optional): name of a LLM model to be used
599
+ """
600
+
601
+ # tags = {'llama2-13b-chat': {'start': '<s>', 'end': '</s>', 'instruction': '[INST]', 'system': '[SYS]'},
602
+ # 'gpt-3.5-turbo-0613': {'start': '<|startoftext|>', 'end': '', 'instruction': "```", 'system': ```}}
603
+
604
+ prompt_fields = {
605
+ "you_are":f"You are an expert in linguistics and semantics, analyzing the question asked by a user to a vector search system, \
606
+ and making sure that the question is well formulated and that the system can understand it.",
607
+
608
+ "your_task":f"Your task is to detect if the name of the guest ({guest}) is mentioned in the user's question, \
609
+ and if that is not the case, rewrite the question using the guest name, \
610
+ without changing the meaning of the question. \
611
+ Most of the time, the user will have used a pronoun to designate the guest, in which case, \
612
+ simply replace the pronoun with the guest name.",
613
+
614
+ "question":f"If the user mentions the guest name, ie {query}, just return his question as is. \
615
+ If the user does not mention the guest name, rewrite the question using the guest name.",
616
+
617
+ "final_instruction":f"Only return the requested rewritten question or the original, WITHOUT ANY COMMENT OR REPHRASING. \
618
+ Your answer must be as close as possible to the original question, \
619
+ and exactly identical, word for word, if the user mentions the guest name, i.e. {guest}.",
620
+ }
621
+
622
+ # prompt created by chatGPT :-)
623
+ # and Llama still outputs the original question and precedes the answer with 'rewritten question'
624
+ prompt_fields2 = {
625
+ "you_are": (
626
+ "You are an expert in linguistics and semantics. Your role is to analyze questions asked to a vector search system."
627
+ ),
628
+ "your_task": (
629
+ f"Detect if the guest's FULL name, {guest}, is mentioned in the user's question. "
630
+ "If not, rewrite the question by replacing pronouns or indirect references with the guest's name." \
631
+ " If yes, return the original question as is, without any change at all, not even punctuation,"
632
+ " except a question mark that you MUST add if it's missing."
633
+ ),
634
+ "question": (
635
+ f"Original question: '{query}'. "
636
+ "Rewrite this question to include the guest's FULL name if it's not already mentioned. "
637
+ "The only thing you can and MUST add is a question mark if it's missing."
638
+ ),
639
+ "final_instruction": (
640
+ "Create a rewritten question or keep the original question as is. "
641
+ "Do not include any labels, titles, or additional text before or after the question. "
642
+ "The only thing you can and MUST add is a question mark if it's missing. "
643
+ "Return a JSON object, with the key 'original_question' for the original question, \
644
+ and 'rewritten_question' for the rewritten question \
645
+ and 'changed' being True if you changed the answer, otherwise False."
646
+ ),
647
+ }
648
+
649
+
650
+ if model_name == 'llama2-13b-chat':
651
+ # special tags are used:
652
+ # `<s>` - start prompt tag
653
+ # `[INST], [/INST]` - Opening and closing model instruction tags
654
+ # `<<SYS>>, <</SYS>>` - Opening and closing system prompt tags
655
+ llama_prompt = """
656
+ <s>[INST] <<SYS>>
657
+ {you_are}
658
+ <</SYS>>
659
+ {your_task}\n
660
+
661
+ ```
662
+ \n\n
663
+ Question: {question}\n
664
+ {final_instruction} [/INST]
665
+
666
+ Answer:
667
+ """
668
+ prompt = llama_prompt.format(**prompt_fields2)
669
+
670
+ hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN_chris']
671
+ # hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN']
672
+
673
+ hf_endpoint = st.secrets['secrets']['LLAMA2_ENDPOINT_UPLIMIT']
674
+
675
+ headers = {"Authorization": f"Bearer {hf_token}",
676
+ "Content-Type": "application/json",}
677
+
678
+ json_body = {
679
+ "inputs": prompt,
680
+ "parameters": {"max_new_tokens":400,
681
+ "repetition_penalty": 1.0,
682
+ "temperature":0.01}
683
+ }
684
+
685
+ response = requests.request("POST", hf_endpoint, headers=headers, data=json.dumps(json_body))
686
+ response = json.loads(response.content.decode("utf-8"))
687
+ # ^ will not process the badly formatted generated text, so we do it ourselves
688
+
689
+ if isinstance(response, dict) and 'error' in response:
690
+ print("Found error")
691
+ print(response)
692
+ # return {'error': response['error'], 'rewritten_question': query, 'changed': False, 'status': 'error'}
693
+ # I test this here otherwise it gets in col 2 or 1, which are too narrow
694
+ # if reworded_query['status'] == 'error':
695
+ # st.write(f"Error in LLM response: 'error':{reworded_query['error']}")
696
+ # st.write("The LLM could not connect to the server. Please try again later.")
697
+ # st.stop()
698
+ return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
699
+
700
+ if response_processing:
701
+ if isinstance(response, list) and isinstance(response[0], dict) and 'generated_text' in response[0]:
702
+ print("Found generated text")
703
+ response0 = response[0]['generated_text']
704
+ pattern = r'\"(\w+)\":\s*(\".*?\"|\w+)'
705
+
706
+ matches = re.findall(pattern, response0)
707
+ # let's build a dictionary
708
+ result = {key: json.loads(value) if value.startswith("\"") else value for key, value in matches}
709
+ return result | {'status': 'success'}
710
+ else:
711
+ print("Found no answer")
712
+ return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
713
+ # return {'original_question': query, 'rewritten_question': query, 'changed': False, 'status': 'no properly formatted answer' }
714
+ else:
715
+ return response
716
+ # return response
717
+ # assert 'error' not in response, f"Error in LLM response: {response['error']}"
718
+ # assert 'generated_text' in response[0], f"Error in LLM response: {response}, no 'generated_text' field"
719
+ # # let's extract the rewritten question
720
+ # return response[0]['generated_text'] .split("Rewritten question: '")[-1][:-1]
721
+
722
+ else:
723
+ # assume openai
724
+ model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
725
+ model_name = model_ids[1]
726
+ GPTllm = GPT_Turbo(model=model_name,
727
+ api_key=st.secrets['secrets']['OPENAI_API_KEY'])
728
+
729
+ openai_prompt = """
730
+ {your_task}\n
731
+ ```
732
+ \n\n
733
+ Question: {question}\n
734
+ {final_instruction}
735
+
736
+ Answer:
737
+ """
738
+ prompt = openai_prompt.format(**prompt_fields)
739
+
740
+ try:
741
+ resp = GPTllm.get_chat_completion(prompt=prompt,
742
+ system_message=prompt_fields['you_are'],
743
+ temperature=0.01,
744
+ max_tokens=1500, # it's a question...
745
+ show_response=True,
746
+ stream=False)
747
+ return {'rewritten_question': resp.choices[0].message.content,
748
+ 'changed': True, 'status': 'success'}
749
+ except Exception:
750
+ return {'rewritten_question': query, 'changed': False, 'status': 'not success'}
751
+
752
+
753
+ if __name__ == '__main__':
754
+ main()
755
+ # %%
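
One non-obvious pattern in app.py is the "increment the widget key" trick mentioned in the comment near the top ("I use a key that I increment each time I want to change a text_input"): bumping the key stored in session state forces Streamlit to render a fresh, empty `text_input` in the same placeholder. A stripped-down sketch of the idea outside the app (widget labels are illustrative only):

```python
# Minimal sketch of the key-increment trick used in app.py to reset a text_input.
import streamlit as st

if 'key' not in st.session_state:
    st.session_state.key = 0

box = st.empty()
query = box.text_input("Ask a question", value="", key=st.session_state.key)

if query and st.button("Clear question"):
    # a new key makes Streamlit treat this as a brand-new, empty widget
    st.session_state.key += 1
    query = box.text_input("Ask a question", value="", key=st.session_state.key)
```
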
backend.py ADDED
@@ -0,0 +1,185 @@
1
+ import modal
2
+
3
+ from typing import List, Dict, Tuple, Union, Callable
4
+ # from preprocessing import FileIO
5
+
6
+ # assets = modal.Mount.from_local_dir(
7
+ # "./data",
8
+ # # condition=lambda pth: not ".venv" in pth,
9
+ # remote_path="./data",
10
+ # )
11
+
12
+
13
+ stub = modal.Stub("vector-search-project")
14
+ vector_search = modal.Image.debian_slim().pip_install(
15
+ "sentence_transformers==2.2.2", "llama_index==0.9.6.post1", "angle_emb==0.1.5"
16
+ )
17
+
18
+ stub.volume = modal.Volume.new()
19
+
20
+
21
+ @stub.function(image=vector_search,
22
+ gpu="A100",
23
+ timeout=600,
24
+ volumes={"/root/models": stub.volume}
25
+ # secrets are available in the environment with os.environ["SECRET_NAME"]
26
+ # secret=modal.Secret.from_name("my-huggingface-secret")
27
+ )
28
+ def encode_content_splits(content_splits,
29
+ model=None, # path or name of model
30
+ **kwargs
31
+ ):
32
+ """ kwargs provided in case encode method has extra arguments """
33
+ from sentence_transformers import SentenceTransformer
34
+
35
+ import os, time
36
+ models_list = os.listdir('/root/models')
37
+ print("Models:", models_list)
38
+
39
+ if isinstance(model, str) and model[-1] == "/":
40
+ model = model[:-1]
41
+
42
+ if isinstance(model, str):
43
+ model = model.split('/')[-1]
44
+
45
+ if isinstance(model, str) and model in models_list:
46
+
47
+ if "UAE-Large-V1-300" in model:
48
+ print("Loading finetuned UAE-Large-V1-300 model from Modal Volume")
49
+
50
+ from angle_emb import AnglE
51
+ model = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1',
52
+ pretrained_model_path=os.path.join('/root/models', model),
53
+ pooling_strategy='cls').cuda()
54
+ kwargs['to_numpy'] = True
55
+
56
+ # this model doesn't accept list of lists
57
+ if isinstance(content_splits[0], list):
58
+ content_splits = [chunk for episode in content_splits for chunk in episode]
59
+
60
+ else:
61
+ print(f"Loading model {model} from Modal volume")
62
+ model = SentenceTransformer(os.path.join('/root/models', model))
63
+
64
+ elif isinstance(model, str):
65
+ if model in models_list:
66
+ print(f"Loading model {model} from Modal volume")
67
+ model = SentenceTransformer(os.path.join('/root/models', model))
68
+ else:
69
+ print(f"Model {model} not found in Modal volume, loading from HuggingFace")
70
+ model = SentenceTransformer(model)
71
+
72
+ else:
73
+ print(f"Using model provided as argument")
74
+ if 'save' in kwargs:
75
+ if isinstance(kwargs['save'], str) and kwargs['save'][-1] == '/':
76
+ kwargs['save'] = kwargs['save'][:-1]
77
+ kwargs['save'] = kwargs['save'].split('/')[-1]
78
+ fname = os.path.join('/root/models', kwargs['save'])
79
+ print(f"Saving model in {fname}")
80
+ # model.save(fname)
81
+ print(f"Model saved in {fname}")
82
+ kwargs.pop('save')
83
+
84
+ print("Starting encoding")
85
+ start = time.perf_counter()
86
+
87
+ emb = [list(zip(episode, model.encode(episode, **kwargs))) for episode in content_splits]
88
+ end = time.perf_counter() - start
89
+ print(f"GPU processing lasted {end:.2f} seconds")
90
+ print("Encoding finished")
91
+
92
+ return emb
93
+
94
+
95
+ @stub.function(image=vector_search, gpu="A100", timeout=120,
96
+ mounts=[modal.Mount.from_local_dir("./data",
97
+ remote_path="/root/data",
98
+ condition=lambda pth: ".json" in pth)],
99
+ volumes={"/root/models": stub.volume}
100
+ )
101
+ def finetune(training_path='./data/training_data_300.json',
102
+ valid_path='./data/validation_data_100.json',
103
+ model_id=None):
104
+
105
+ import os
106
+ print("Data:", os.listdir('/root/data'))
107
+ print("Models:", os.listdir('/root/models'))
108
+
109
+ if model_id is None:
110
+ print("No model ID provided")
111
+ return None
112
+ elif isinstance(model_id, str) and model_id[-1] == "/":
113
+ model_id = model_id[:-1]
114
+
115
+
116
+ from llama_index.finetuning import EmbeddingQAFinetuneDataset
117
+
118
+ training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
119
+ valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)
120
+ print("Datasets loaded")
121
+
122
+ num_training_examples = len(training_set.queries)
123
+ print(f"Training examples: {num_training_examples}")
124
+
125
+ from llama_index.finetuning import SentenceTransformersFinetuneEngine
126
+
127
+ print(f"Model Name is {model_id}")
128
+ model_ext = model_id.split('/')[1]
129
+
130
+ ft_model_name = f'finetuned-{model_ext}-{num_training_examples}'
131
+ model_outpath = os.path.join("/root/models", ft_model_name)
132
+
133
+ print(f'Model ID: {model_id}')
134
+ print(f'Model Outpath: {model_outpath}')
135
+
136
+ finetune_engine = SentenceTransformersFinetuneEngine(
137
+ training_set,
138
+ batch_size=32,
139
+ model_id=model_id,
140
+ model_output_path=model_outpath,
141
+ val_dataset=valid_set,
142
+ epochs=10
143
+ )
144
+ import io, os, zipfile, glob, time
145
+ try:
146
+ start = time.perf_counter()
147
+ finetune_engine.finetune()
148
+ end = time.perf_counter() - start
149
+ print(f"GPU processing lasted {end:.2f} seconds")
150
+
151
+ print(os.listdir('/root/models'))
152
+ stub.volume.commit() # Persist changes, i.e. the finetuned model
153
+
154
+ # TODO SHARE THE MODEL ON HUGGINGFACE
155
+ # https://huggingface.co/docs/transformers/v4.15.0/model_sharing
156
+
157
+ folder_to_zip = model_outpath
158
+ # Zip the contents of the folder at 'folder_path' and return a BytesIO object.
159
+ bytes_buffer = io.BytesIO()
160
+
161
+ with zipfile.ZipFile(bytes_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
162
+ for file_path in glob.glob(folder_to_zip + "/**", recursive=True):
163
+ print(f"Processed file {file_path}")
164
+ zip_file.write(file_path, os.path.relpath(file_path, start=folder_to_zip))
165
+
166
+ # Move the pointer to the start of the BytesIO buffer before returning
167
+ bytes_buffer.seek(0)
168
+ # You can now return this zipped_folder object, write it to a file, send it over a network, etc.
169
+ # Replace with the path to the folder you want to zip
170
+ zippedio = bytes_buffer
171
+
172
+ return zippedio
173
+ except Exception:
174
+ return "Finetuning failed"
175
+
176
+
177
+ @stub.local_entrypoint()
178
+ def test_method(content_splits=[["a"]]):
179
+ output = encode_content_splits.remote(content_splits)
180
+ return output
181
+
182
+ # deploy it with
183
+ # modal token set --token-id ak-xxxxxx --token-secret as-xxxxx # given when we create a new token
184
+ # modal deploy podcast/1/backend.py
185
+ # View Deployment: https://modal.com/apps/jpbianchi/falcon_hackaton-project <<< use this project name
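
The comments above cover deployment; for completeness, here is a sketch of calling the deployed `encode_content_splits` function from a local process, using the same `modal.Function.lookup` / `.remote` pattern that finetune_backend.py uses for `finetune` (the model name and sample chunks are placeholders):

```python
# Sketch: invoke the deployed A100 encoder from a local script.
import modal

encode = modal.Function.lookup("vector-search-project", "encode_content_splits")

content_splits = [["first chunk of episode one", "second chunk of episode one"],
                  ["first chunk of episode two"]]

# Runs remotely on the GPU container defined above and returns, per episode,
# a list of (chunk, embedding) tuples.
embeddings = encode.remote(content_splits,
                           model="sentence-transformers/all-mpnet-base-v2")
```
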
class_templates.py ADDED
@@ -0,0 +1,46 @@
1
+ impact_theory_class_properties = [
2
+ {'name': 'title',
3
+ 'dataType': ['text'],
4
+ 'indexFilterable': True,
5
+ 'indexSearchable': True},
6
+ {'name': 'video_id',
7
+ 'dataType': ['text'],
8
+ 'indexFilterable': True,
9
+ 'indexSearchable': False},
10
+ {'name': 'length',
11
+ 'dataType': ['int'],
12
+ 'indexFilterable': True,
13
+ 'indexSearchable': False},
14
+ {'name': 'thumbnail_url',
15
+ 'dataType': ['text'],
16
+ 'indexFilterable': False,
17
+ 'indexSearchable': False},
18
+ {'name': 'views',
19
+ 'dataType': ['int'],
20
+ 'indexFilterable': True,
21
+ 'indexSearchable': False},
22
+ {'name': 'episode_url',
23
+ 'dataType': ['text'],
24
+ 'indexFilterable': False,
25
+ 'indexSearchable': False},
26
+ {'name': 'doc_id',
27
+ 'dataType': ['text'],
28
+ 'indexFilterable': True,
29
+ 'indexSearchable': False},
30
+ {'name': 'guest',
31
+ 'dataType': ['text'],
32
+ 'indexFilterable': True,
33
+ 'indexSearchable': True},
34
+ {'name': 'summary',
35
+ 'dataType': ['text'],
36
+ 'indexFilterable': False,
37
+ 'indexSearchable': True},
38
+ {'name': 'content',
39
+ 'dataType': ['text'],
40
+ 'indexFilterable': False,
41
+ 'indexSearchable': True},
42
+ ]
43
+ # {'name': 'publish_date',
44
+ # 'dataType': ['date'],
45
+ # 'indexFilterable': True,
46
+ # 'indexSearchable': False},
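
The repo's `WeaviateClient` wrapper is not part of this commit, so as a hedged illustration only, this is roughly how `impact_theory_class_properties` could be registered with the plain `weaviate-client` v3 API (the class name matches `start_class` in app.py; the endpoint and key come from the same environment variables app.py reads):

```python
# Illustrative sketch, not the repo's own WeaviateClient wrapper.
import os
import weaviate
from class_templates import impact_theory_class_properties

client = weaviate.Client(
    url=os.environ["WEAVIATE_ENDPOINT"],
    auth_client_secret=weaviate.AuthApiKey(os.environ["WEAVIATE_API_KEY"]),
)

class_obj = {
    "class": "Impact_theory_all_mpnet_base_v2_finetuned",  # class name used in app.py
    "properties": impact_theory_class_properties,
    "vectorizer": "none",  # vectors are supplied by the SentenceTransformer models
}

client.schema.create_class(class_obj)  # raises if the class already exists
```
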
finetune_backend.py ADDED
@@ -0,0 +1,55 @@
1
+ #%%
2
+ import os, time, io, zipfile
3
+ from preprocessing import FileIO
4
+ import shutil
5
+ import modal
6
+ from llama_index.finetuning import EmbeddingQAFinetuneDataset
7
+
8
+ from dotenv import load_dotenv, find_dotenv
9
+ env = load_dotenv(find_dotenv('env'), override=True)
10
+
11
+ #%%
12
+ training_path = 'data/training_data_300.json'
13
+ valid_path = 'data/validation_data_100.json'
14
+
15
+ training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
16
+ valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)
17
+
18
+ def finetune(model='all-mpnet-base-v2', savemodel=False, outpath='.'):
19
+ """ Finetunes a model on Modal GPU A100.
20
+ The model is saved in /root/models on a Modal volume
21
+ and can be stored locally.
22
+
23
+ Args:
24
+ model (str): the Sentence Transformer model name
25
+ savemodel (bool, optional): whether to save the model or not.
26
+
27
+ Returns:
28
+ path of the saved model (when saved)
29
+ """
30
+ f = modal.Function.lookup("vector-search-project", "finetune")
31
+ model = model.replace('/','')
32
+
33
+ if 'sentence-transformers' not in model:
34
+ model = f"sentence-transformers/{model}"
35
+
36
+ fullpath = os.path.join(outpath, f"finetuned-{model}-300")
37
+
38
+ if os.path.exists(fullpath):
39
+ msg = "Model already exists!"
40
+ print(msg)
41
+ return msg
42
+
43
+ start = time.perf_counter()
44
+ finetuned_model = f.remote(training_path, valid_path, model_id=model)
45
+
46
+ end = time.perf_counter() - start
47
+ print(f"Finetuning with GPU lasted {end:.2f} seconds")
48
+
49
+ if savemodel:
50
+
51
+ with open(fullpath, 'wb') as file:
52
+ # Write the contents of the BytesIO object to a new file
53
+ file.write(finetuned_model.getbuffer())
54
+ print(f"Model saved in {fullpath}")
55
+ return fullpath
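
Note that backend.py's `finetune` returns the finetuned model folder as an in-memory zip (`BytesIO`), while the `savemodel` branch above writes the raw buffer to a single file at `fullpath`. A sketch of unpacking that buffer into a directory that `SentenceTransformer` can load directly (function and directory names are illustrative, not part of the commit):

```python
# Sketch only: extract the zipped model folder returned by the Modal finetune job.
import io
import zipfile
from pathlib import Path

def save_finetuned_model(zipped: io.BytesIO, outdir: str) -> str:
    """Unpack the zip returned by the remote job into `outdir` and return the path."""
    out = Path(outdir)
    out.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zipped) as zf:
        zf.extractall(out)  # recreates config.json, model weights, etc.
    return str(out)

# Hypothetical usage:
# path = save_finetuned_model(finetuned_model, "models/finetuned-all-mpnet-base-v2-300")
# model = SentenceTransformer(path)
```
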
helpers.py ADDED
@@ -0,0 +1,112 @@
1
+ from typing import List, Tuple, Dict, Any
2
+ import time
3
+ from tqdm.notebook import tqdm
4
+ from rich import print
5
+
6
+ from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params
7
+ from llama_index.finetuning import EmbeddingQAFinetuneDataset
8
+ from weaviate_interface import WeaviateClient
9
+
10
+
11
+ def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset,
12
+ class_name: str,
13
+ retriever: WeaviateClient,
14
+ retrieve_limit: int=5,
15
+ chunk_size: int=256,
16
+ hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
17
+ display_properties: List[str]=['doc_id', 'guest', 'content'],
18
+ dir_outpath: str='./eval_results',
19
+ include_miss_info: bool=False,
20
+ user_def_params: Dict[str,Any]=None
21
+ ) -> Dict[str, str|int|float]:
22
+ '''
23
+ Given a dataset and a retriever, evaluate the performance of the retriever. Returns a dict of kw and vector
24
+ hit rates and mrr scores. If include_miss_info is True, will also return a list of kw and vector responses
25
+ and their associated queries that did not return a hit, for deeper analysis. Text file with results output
26
+ is automatically saved in the dir_outpath directory.
27
+
28
+ Args:
29
+ -----
30
+ dataset: EmbeddingQAFinetuneDataset
31
+ Dataset to be used for evaluation
32
+ class_name: str
33
+ Name of Class on Weaviate host to be used for retrieval
34
+ retriever: WeaviateClient
35
+ WeaviateClient object to be used for retrieval
36
+ retrieve_limit: int=5
37
+ Number of documents to retrieve from Weaviate host
38
+ chunk_size: int=256
39
+ Number of tokens used to chunk text. This value is purely for results
40
+ recording purposes and does not affect results.
41
+ display_properties: List[str]=['doc_id', 'content']
42
+ List of properties to be returned from Weaviate host for display in response
43
+ dir_outpath: str='./eval_results'
44
+ Directory path for saving results. Directory will be created if it does not
45
+ already exist.
46
+ include_miss_info: bool=False
47
+ Option to include queries and their associated kw and vector response values
48
+ for queries that are "total misses"
49
+ user_def_params : dict=None
50
+ Option for user to pass in a dictionary of user-defined parameters and their values.
51
+ '''
52
+
53
+ results_dict = {'n':retrieve_limit,
54
+ 'Retriever': retriever.model_name_or_path,
55
+ 'chunk_size': chunk_size,
56
+ 'kw_hit_rate': 0,
57
+ 'kw_mrr': 0,
58
+ 'vector_hit_rate': 0,
59
+ 'vector_mrr': 0,
60
+ 'total_misses': 0,
61
+ 'total_questions':0
62
+ }
63
+ #add hnsw configs and user defined params (if any)
64
+ results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
65
+
66
+ start = time.perf_counter()
67
+ miss_info = []
68
+ for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
69
+ results_dict['total_questions'] += 1
70
+ hit = False
71
+ #make Keyword, Vector, and Hybrid calls to Weaviate host
72
+ try:
73
+ kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
74
+ vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
75
+
76
+ #collect doc_ids and position of doc_ids to check for document matches
77
+ kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response, 1)}
78
+ vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response, 1)}
79
+
80
+ #extract doc_id for scoring purposes
81
+ doc_id = dataset.relevant_docs[query_id][0]
82
+
83
+ #increment hit_rate counters and mrr scores
84
+ if doc_id in kw_doc_ids:
85
+ results_dict['kw_hit_rate'] += 1
86
+ results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
87
+ hit = True
88
+ if doc_id in vector_doc_ids:
89
+ results_dict['vector_hit_rate'] += 1
90
+ results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
91
+ hit = True
92
+
93
+ # if no hits, let's capture that
94
+ if not hit:
95
+ results_dict['total_misses'] += 1
96
+ miss_info.append({'query': q, 'kw_response': kw_response, 'vector_response': vector_response})
97
+ except Exception as e:
98
+ print(e)
99
+ continue
100
+
101
+
102
+ #use raw counts to calculate final scores
103
+ calc_hit_rate_scores(results_dict)
104
+ calc_mrr_scores(results_dict)
105
+
106
+ end = time.perf_counter() - start
107
+ print(f'Total Processing Time: {round(end/60, 2)} minutes')
108
+ record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
109
+
110
+ if include_miss_info:
111
+ return results_dict, miss_info
112
+ return results_dict
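
`calc_hit_rate_scores` and `calc_mrr_scores` are imported from retrieval_evaluation, which is not included in this commit. Based on how the counters are accumulated above (one increment per hit, one reciprocal-rank term per hit), a plausible minimal sketch of what they do is simply a division by the number of questions:

```python
# Plausible sketch of the scoring helpers; the real versions live in retrieval_evaluation.py.
from typing import Any, Dict

def calc_hit_rate_scores(results: Dict[str, Any]) -> None:
    """Convert raw hit counts into hit rates, in place."""
    n = max(results['total_questions'], 1)
    for key in ('kw_hit_rate', 'vector_hit_rate'):
        results[key] = round(results[key] / n, 2)

def calc_mrr_scores(results: Dict[str, Any]) -> None:
    """Convert accumulated reciprocal ranks into mean reciprocal rank, in place."""
    n = max(results['total_questions'], 1)
    for key in ('kw_mrr', 'vector_mrr'):
        results[key] = round(results[key] / n, 2)
```
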
llama_test.ipynb ADDED
@@ -0,0 +1,80 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Note: you may need to restart the kernel to use updated packages.\n",
13
+ "Note: you may need to restart the kernel to use updated packages.\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "%pip install huggingface_hub --q\n",
19
+ "%pip install ipywidgets --q"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 3,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "from transformers.pipelines.text_generation import TextGenerationPipeline\n",
29
+ "from transformers import AutoConfig\n",
30
+ "import transformers"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 2,
36
+ "metadata": {},
37
+ "outputs": [
38
+ {
39
+ "data": {
40
+ "application/vnd.jupyter.widget-view+json": {
41
+ "model_id": "f9c842f1bd7146e5a4e4d517450531ee",
42
+ "version_major": 2,
43
+ "version_minor": 0
44
+ },
45
+ "text/plain": [
46
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
47
+ ]
48
+ },
49
+ "metadata": {},
50
+ "output_type": "display_data"
51
+ }
52
+ ],
53
+ "source": [
54
+ "from huggingface_hub import notebook_login\n",
55
+ "notebook_login() #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"
56
+ ]
57
+ }
58
+ ],
59
+ "metadata": {
60
+ "kernelspec": {
61
+ "display_name": "venv",
62
+ "language": "python",
63
+ "name": "python3"
64
+ },
65
+ "language_info": {
66
+ "codemirror_mode": {
67
+ "name": "ipython",
68
+ "version": 3
69
+ },
70
+ "file_extension": ".py",
71
+ "mimetype": "text/x-python",
72
+ "name": "python",
73
+ "nbconvert_exporter": "python",
74
+ "pygments_lexer": "ipython3",
75
+ "version": "3.11.5"
76
+ }
77
+ },
78
+ "nbformat": 4,
79
+ "nbformat_minor": 2
80
+ }
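Hard-coding a Hub token in a notebook cell is an easy way to leak credentials; the token can instead be read from the environment. A minimal sketch, assuming the `HF_TOKEN` variable from the `.env` file described in readme2.md:

```
# Log in to the Hugging Face Hub without pasting a token into the notebook.
import os
from dotenv import load_dotenv, find_dotenv
from huggingface_hub import login

load_dotenv(find_dotenv('.env'), override=True)
login(token=os.environ['HF_TOKEN'])  # assumes HF_TOKEN is set in .env
```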
openai_interface.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from openai import OpenAI
3
+ from typing import List, Any, Tuple
4
+ from tqdm import tqdm
5
+ import streamlit as st
6
+ from concurrent.futures import ThreadPoolExecutor, as_completed
7
+
8
+ from dotenv import load_dotenv, find_dotenv
9
+ load_dotenv(find_dotenv('env'), override=True)
10
+
11
+ try:
12
+ api_key = st.secrets['secrets']['OPENAI_API_KEY']
13
+ except Exception:
14
+ api_key = os.environ['OPENAI_API_KEY']
15
+ class GPT_Turbo:
16
+
17
+ def __init__(self, model: str="gpt-3.5-turbo-0613", api_key: str=api_key):
18
+ self.model = model
19
+ self.client = OpenAI(api_key=api_key)
20
+
21
+ def get_chat_completion(self,
22
+ prompt: str,
23
+ system_message: str='You are a helpful assistant.',
24
+ temperature: int=0,
25
+ max_tokens: int=500,
26
+ stream: bool=False,
27
+ show_response: bool=False
28
+ ) -> str:
29
+ messages = [
30
+ {'role': 'system', 'content': system_message},
31
+ {'role': 'user', 'content': prompt}
32
+ ]
33
+
34
+ response = self.client.chat.completions.create( model=self.model,
35
+ messages=messages,
36
+ temperature=temperature,
37
+ max_tokens=max_tokens,
38
+ stream=stream)
39
+ if show_response:
40
+ return response
41
+ return response.choices[0].message.content
42
+
43
+ def multi_thread_request(self,
44
+ filepath: str,
45
+ prompt: str,
46
+ content: List[str],
47
+ temperature: int=0
48
+ ) -> List[Any]:
49
+
50
+ data = []
51
+ with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
52
+ futures = [exec.submit(self.get_chat_completion, f'{prompt} ```{c}```', 'You are a helpful assistant.', temperature, 500, False) for c in content]
53
+ with open(filepath, 'a') as f:
54
+ for future in as_completed(futures):
55
+ result = future.result()
56
+ if len(data) % 10 == 0:
57
+ print(f'{len(data)} of {len(content)} completed.')
58
+ if result:
59
+ data.append(result)
60
+ self.write_to_file(file_handle=f, data=result)
61
+ return [res for res in data if res]
62
+
63
+ def generate_question_context_pairs(self,
64
+ context_tuple: Tuple[str, str],
65
+ num_questions_per_chunk: int=2,
66
+ max_words_per_question: int=10
67
+ ) -> List[str]:
68
+
69
+ doc_id, context = context_tuple
70
+ prompt = f'Context information is included below enclosed in triple backticks. Given the context information and not prior knowledge, generate questions based on the below query.\n\nYou are an end user querying for information about your favorite podcast. \
71
+ Your task is to set up {num_questions_per_chunk} questions that can be answered using only the given context. The questions should be diverse in nature across the document and be no longer than {max_words_per_question} words. \
72
+ Restrict the questions to the context information provided.\n\
73
+ ```{context}```\n\n'
74
+
75
+ response = self.get_chat_completion(prompt=prompt, temperature=0, max_tokens=500, show_response=True)
76
+ questions = response.choices[0].message.content
77
+ return (doc_id, questions)
78
+
79
+ def batch_generate_question_context_pairs(self,
80
+ context_tuple_list: List[Tuple[str, str]],
81
+ num_questions_per_chunk: int=2,
82
+ max_words_per_question: int=10
83
+ ) -> List[Tuple[str, str]]:
84
+ data = []
85
+ progress = tqdm(unit="Generated Questions", total=len(context_tuple_list))
86
+ with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
87
+ futures = [exec.submit(self.generate_question_context_pairs, context_tuple, num_questions_per_chunk, max_words_per_question) for context_tuple in context_tuple_list]
88
+ for future in as_completed(futures):
89
+ result = future.result()
90
+ if result:
91
+ data.append(result)
92
+ progress.update(1)
93
+ return data
94
+
95
+ def get_embedding(self):
96
+ pass
97
+
98
+ def write_to_file(self, file_handle, data: str) -> None:
99
+ file_handle.write(data)
100
+ file_handle.write('\n')
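For reference, a minimal usage sketch of the `GPT_Turbo` wrapper defined above; the question text is illustrative and the call assumes `OPENAI_API_KEY` is available via Streamlit secrets or the environment, as handled at the top of the module:

```
from openai_interface import GPT_Turbo

gpt = GPT_Turbo()  # defaults to gpt-3.5-turbo-0613

answer = gpt.get_chat_completion(
    prompt='In one sentence, what is Retrieval Augmented Generation?',
    temperature=0,
    max_tokens=100,
)
print(answer)
```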
preprocessing.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pandas as pd
4
+ from typing import List, Union, Dict
5
+ from loguru import logger
6
7
+ import pathlib
8
+
9
+
10
+ ## Set of helper functions that support data preprocessing
11
+ class FileIO:
12
+ '''
13
+ Convenience class for saving and loading data in parquet and
14
+ json formats to/from disk.
15
+ '''
16
+
17
+ def save_as_parquet(self,
18
+ file_path: str,
19
+ data: Union[List[dict], pd.DataFrame],
20
+ overwrite: bool=False) -> None:
21
+ '''
22
+ Saves DataFrame to disk as a parquet file. Removes the index.
23
+
24
+ Args:
25
+ -----
26
+ file_path : str
27
+ Output path to save file, if not included "parquet" will be appended
28
+ as file extension.
29
+ data : Union[List[dict], pd.DataFrame]
30
+ Data to save as parquet file. If data is a list of dicts, it will be
31
+ converted to a DataFrame before saving.
32
+ overwrite : bool
33
+ Overwrite existing file if True, otherwise raise FileExistsError.
34
+ '''
35
+ if isinstance(data, list):
36
+ data = self._convert_toDataFrame(data)
37
+ if not file_path.endswith('parquet'):
38
+ file_path = self._rename_file_extension(file_path, 'parquet')
39
+ self._check_file_path(file_path, overwrite=overwrite)
40
+ data.to_parquet(file_path, index=False)
41
+ logger.info(f'DataFrame saved as parquet file here: {file_path}')
42
+
43
+ def _convert_toDataFrame(self, data: List[dict]) -> pd.DataFrame:
44
+ return pd.DataFrame().from_dict(data)
45
+
46
+ def _rename_file_extension(self, file_path: str, extension: str):
47
+ '''
48
+ Renames file with appropriate extension if file_path
49
+ does not already have correct extension.
50
+ '''
51
+ prefix = os.path.splitext(file_path)[0]
52
+ file_path = prefix + '.' + extension
53
+ return file_path
54
+
55
+ def _check_file_path(self, file_path: str, overwrite: bool) -> None:
56
+ '''
57
+ Checks for existence of file and overwrite permissions.
58
+ '''
59
+ if os.path.exists(file_path) and not overwrite:
60
+ raise FileExistsError(f'File by name {file_path} already exists, try using another file name or set overwrite to True.')
61
+ elif os.path.exists(file_path):
62
+ os.remove(file_path)
63
+ else:
64
+ file_name = os.path.basename(file_path)
65
+ dir_structure = file_path.replace(file_name, '')
66
+ pathlib.Path(dir_structure).mkdir(parents=True, exist_ok=True)
67
+
68
+ def load_parquet(self, file_path: str, verbose: bool=True) -> List[dict]:
69
+ '''
70
+ Loads parquet from disk, converts to pandas DataFrame as intermediate
71
+ step and outputs a list of dicts (docs).
72
+ '''
73
+ df = pd.read_parquet(file_path)
74
+ vector_labels = ['content_vector', 'image_vector', 'content_embedding']
75
+ for label in vector_labels:
76
+ if label in df.columns:
77
+ df[label] = df[label].apply(lambda x: x.tolist())
78
+ if verbose:
79
+ memory_usage = round(df.memory_usage().sum()/(1024*1024),2)
80
+ print(f'Shape of data: {df.values.shape}')
81
+ print(f'Memory Usage: {memory_usage}+ MB')
82
+ list_of_dicts = df.to_dict('records')
83
+ return list_of_dicts
84
+
85
+ def load_json(self, file_path: str):
86
+ '''
87
+ Loads json file from disk.
88
+ '''
89
+ with open(file_path) as f:
90
+ data = json.load(f)
91
+ return data
92
+
93
+ def save_as_json(self,
94
+ file_path: str,
95
+ data: Union[List[dict], dict],
96
+ indent: int=4,
97
+ overwrite: bool=False
98
+ ) -> None:
99
+ '''
100
+ Saves data to disk as a json file. Data can be a list of dicts or a single dict.
101
+ '''
102
+ if not file_path.endswith('json'):
103
+ file_path = self._rename_file_extension(file_path, 'json')
104
+ self._check_file_path(file_path, overwrite=overwrite)
105
+ with open(file_path, 'w') as f:
106
+ json.dump(data, f, indent=indent)
107
+ logger.info(f'Data saved as json file here: {file_path}')
108
+
109
+ class Utilities:
110
+
111
+ def create_video_url(self, video_id: str, playlist_id: str):
112
+ '''
113
+ Creates a hyperlink to a video episode given a video_id and playlist_id.
114
+
115
+ Args:
116
+ -----
117
+ video_id : str
118
+ Video id of the episode from YouTube
119
+ playlist_id : str
120
+ Playlist id of the episode from YouTube
121
+ '''
122
+ return f'https://www.youtube.com/watch?v={video_id}&list={playlist_id}'
123
+
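A short round-trip sketch for the `FileIO` helper above; the file path and toy records are illustrative only:

```
from preprocessing import FileIO

docs = [{'doc_id': 'abc_1', 'content': 'first chunk of transcript text'},
        {'doc_id': 'abc_2', 'content': 'second chunk of transcript text'}]

io = FileIO()
io.save_as_parquet(file_path='./data/toy_chunks.parquet', data=docs, overwrite=True)

reloaded = io.load_parquet('./data/toy_chunks.parquet')  # -> list of dicts
assert reloaded[0]['doc_id'] == 'abc_1'
```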
prompt_templates.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ question_answering_system = '''
2
+ You are the host of the show Impact Theory, and your name is Tom Bilyeu. The description of your show is as follows:
3
+ If you’re looking to thrive in uncertain times, achieve unprecedented goals, and improve the most meaningful aspects of your life, then Impact Theory is the show for you. Hosted by Tom Bilyeu, a voracious learner and hyper-successful entrepreneur, the show investigates and analyzes the most useful topics with the world’s most sought-after guests.
4
+ Bilyeu attacks each episode with a clear desire to further evolve the holistic skillset that allowed him to co-found the billion-dollar company Quest Nutrition, generate over half a billion organic views on his content, build a thriving marriage of over 20 years, and quantifiably improve the lives of over 10,000 people through his school, Impact Theory University.
5
+ Bilyeu’s insatiable hunger for knowledge gives the show urgency, relevance, and depth while leaving listeners with the knowledge, tools, and empowerment to take control of their lives and develop true personal power.
6
+ '''
7
+
8
+ question_answering_prompt_single = '''
9
+ Use the below context enclosed in triple back ticks to answer the question. If the context does not provide enough information to answer the question, then use any knowledge you have to answer the question.\n
10
+ ```{context}```\n
11
+ Question:\n
12
+ {question}.\n
13
+ Answer:
14
+ '''
15
+
16
+ question_answering_prompt_series = '''
17
+ Your task is to synthesize and reason over a series of transcripts of an interview between Tom Bilyeu and his guest(s).
18
+ After your synthesis, use the series of transcripts to answer the below question. The series will be in the following format:\n
19
+ ```
20
+ Show Summary: <summary>
21
+ Show Guest: <guest>
22
+ Transcript: <transcript>
23
+ ```\n\n
24
+ Start Series:
25
+ ```
26
+ {series}
27
+ ```
28
+ Question:\n
29
+ {question}\n
30
+ Answer the question and provide reasoning if necessary to explain the answer.\n
31
+ If the context does not provide enough information to answer the question, then \n
32
+ state that you cannot answer the question with the provided context.\n
33
+
34
+ Answer:
35
+ '''
36
+
37
+ context_block = '''
38
+ Show Summary: {summary}
39
+ Show Guest: {guest}
40
+ Transcript: {transcript}
41
+ '''
42
+
43
+ qa_generation_prompt = '''
44
+ Impact Theory episode summary and episode guest are below:
45
+
46
+ ---------------------
47
+ Summary: {summary}
48
+ ---------------------
49
+ Guest: {guest}
50
+ ---------------------
51
+ Given the Summary and Guest of the episode as context \
52
+ use the following randomly selected transcript section \
53
+ of the episode and not prior knowledge, generate questions that can \
54
+ be answered by the transcript section:
55
+
56
+ ---------------------
57
+ Transcript: {transcript}
58
+ ---------------------
59
+
60
+ Your task is to create {num_questions_per_chunk} questions that can \
61
+ only be answered given the previous context and transcript details. \
62
+ The question should randomly start with How, Why, or What.
63
+ '''
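The templates above are plain Python format strings; a quick sketch of how they compose, with placeholder summary, guest, and transcript text:

```
from prompt_templates import question_answering_prompt_series, context_block

chunks = [
    {'summary': 'Episode on discipline.', 'guest': 'Guest Name', 'content': '...transcript text...'},
    {'summary': 'Episode on discipline.', 'guest': 'Guest Name', 'content': '...more transcript...'},
]

series = '\n'.join(context_block.format(summary=c['summary'],
                                        guest=c['guest'],
                                        transcript=c['content'])
                   for c in chunks)

prompt = question_answering_prompt_series.format(
    series=series,
    question='What does the guest say about discipline?')
print(prompt)
```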
prompt_templates_luis.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ question_answering_system = '''
2
+ You are the host of the show Impact Theory, and your name is Tom Bilyeu. The description of your show is as follows:
3
+ If you’re looking to thrive in uncertain times, achieve unprecedented goals, and improve the most meaningful aspects of your life, then Impact Theory is the show for you. Hosted by Tom Bilyeu, a voracious learner and hyper-successful entrepreneur, the show investigates and analyzes the most useful topics with the world’s most sought-after guests.
4
+ Bilyeu attacks each episode with a clear desire to further evolve the holistic skillset that allowed him to co-found the billion-dollar company Quest Nutrition, generate over half a billion organic views on his content, build a thriving marriage of over 20 years, and quantifiably improve the lives of over 10,000 people through his school, Impact Theory University.
5
+ Bilyeu’s insatiable hunger for knowledge gives the show urgency, relevance, and depth while leaving listeners with the knowledge, tools, and empowerment to take control of their lives and develop true personal power.
6
+ '''
7
+
8
+ question_answering_prompt_single = '''
9
+ Use the below context enclosed in triple back ticks to answer the question. If the context does not provide enough information to answer the question, then use any knowledge you have to answer the question.\n
10
+ ```{context}```\n
11
+ Question:\n
12
+ {question}.\n
13
+ Answer:
14
+ '''
15
+
16
+ question_answering_prompt_series = '''
17
+ Your task is to synthesize and reason over a series of transcripts of an interview between Tom Bilyeu and his guest(s).
18
+ After your synthesis, use the series of transcripts to answer the below question. The series will be in the following format:\n
19
+ ```
20
+ Show Summary: <summary>
21
+ Show Guest: <guest>
22
+ Transcript: <transcript>
23
+ ```\n\n
24
+ Start Series:
25
+ ```
26
+ {series}
27
+ ```
28
+ Question:\n
29
+ {question}\n
30
+ Answer the question and provide reasoning if necessary to explain the answer.\n
31
+ If the context does not provide enough information to answer the question, then \n
32
+ state that you cannot answer the question with the provided context.\n
33
+
34
+ Answer:
35
+ '''
36
+
37
+ context_block = '''
38
+ Show Summary: {summary}
39
+ Show Guest: {guest}
40
+ Transcript: {transcript}
41
+ '''
42
+
43
+ qa_generation_prompt = '''
44
+ Impact Theory episode summary and episode guest are below:
45
+
46
+ ---------------------
47
+ Summary: {summary}
48
+ ---------------------
49
+ Guest: {guest}
50
+ ---------------------
51
+ Given the Summary and Guest of the episode as context \
52
+ use the following randomly selected transcript section \
53
+ of the episode and not prior knowledge, generate questions that can \
54
+ be answered by the transcript section:
55
+
56
+ ---------------------
57
+ Transcript: {transcript}
58
+ ---------------------
59
+
60
+ Your task is to create {num_questions_per_chunk} questions that can \
61
+ only be answered given the previous context and transcript details. \
62
+ The question should randomly start with How, Why, or What.
63
+ '''
readme2.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Welcome to Vector Search Applications with LLMs
2
+ This is the course repository for Vector Search Applications with LLMs taught by [Chris Sanchez](https://www.linkedin.com/in/excellenceisahabit/) with assistance from [Matias Weber](https://www.linkedin.com/in/matiasweber/).
3
+ The course is designed to teach search and discovery industry best practices, culminating in a demo Retrieval Augmented Generation (RAG) application. Along the way students will learn all of the components of a RAG system, including data preprocessing, embedding creation, vector database selection, indexing, retrieval systems, reranking, retrieval evaluation, question answering through an LLM, and UI implementation through Streamlit.
4
+
5
+ # Prerequisites - Technical Experience
6
+ Students are expected to have the following technical skills prior to enrolling. Students who do not meet these prerequisites will likely have an overly challenging learning experience:
7
+ - Minimum of 1 year of experience coding in Python. Skill sets should include object-oriented programming, dictionary and list comprehensions, lambda functions, setting up virtual environments, and comfort with git version control.
8
+ - Professional or academic experience working with search engines.
9
+ - Ability to comfortably navigate the command line, including familiarity with Docker.
10
+ - Nice to have but not strictly required:
11
+ - experience fine-tuning a ML model
12
+ - familiarity with the Streamlit API
13
+ - familiarity with making inference calls to a Generative LLM (OpenAI or Llama-2)
14
+ # Prerequisites - Administrative
15
+ 1. Students will need access to their own compute environment, whether locally or remote. There are no hard requirements for RAM or CPU processing power, but in general the more punch the better.
16
+ 2. Students will need accounts with the following organizations:
17
+ - Either an [OpenAI](https://openai.com) account **(RECOMMENDED)** or a [HuggingFace](https://huggingface.co/join) account. Students have the option of either using a paid LLM service (OpenAI) or using the open source `meta-llama/Llama-2-7b-chat-hf` model. Students choosing the latter option will first need to [register with Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) to request access to the Llama-2 model.
18
+ - An account with [weaviate.io](https://weaviate.io). The current iteration of this course will use Weaviate as a sparse and dense vector database. Weaviate offers free cloud instance cluster resources for 21 days (as of November 2023). **Students are advised to NOT CREATE** a Weaviate cloud cluster until the course officially starts.
19
+ - A standard [Github](https://github.com/) account in order to fork this repo, clone a copy, and submit commits to the fork as needed throughout the course.
20
+
21
+ # Setup
22
+ 1. Fork this course repo (see upper right hand corner of the repo web page).
23
+ <img src="assets/forkbutton.png" alt="fork button" width="300" height="auto">
24
+ 2. Clone a copy of the forked repo into the dev environment of your choice. Navigate into the cloned `vectorsearch-applications` directory.
25
+ 3. Create a python virtual environment using your library of choice. Here's an example using [`conda`](https://docs.conda.io/projects/miniconda/en/latest/):
26
+ ```
27
+ conda create --name impactenv -y python=3.10
28
+ ```
29
+ 4. Once the environment is created, activate the environment and install dependencies.
30
+ ```
31
+ conda activate impactenv
32
+
33
+ pip install -r requirements.txt
34
+ ```
35
+ 5. Last but not least create a `.env` text file in your cloned repo. At a minimum, add the following environment variables:
36
+ ```
37
+ OPENAI_API_KEY= "your OpenAI account API Key"
38
+ HF_TOKEN= "your HuggingFace account token" <--- Optional: not required if using OpenAI
39
+ WEAVIATE_API_KEY= "your Weaviate cluster API Key" <--- you will get this on Day One of the course
40
+ WEAVIATE_ENDPOINT= "your Weaviate cluster endpoint" <--- you will get this on Day One of the course
41
+ ```
42
+ 6. If you've made it this far, you are ready to start the course. Enjoy the process!
43
+ <img src="assets/getsome.jpg" alt="jocko" width="500" height="auto">
requirements.txt ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ beautifulsoup4==4.12.2
2
+ datasets==2.14.3
3
+ huggingface-hub==0.16.4
4
+ ipython==8.14.0
5
+ ipywidgets==8.1.1
6
+ jedi==0.19.0
7
+ jupyter-events==0.7.0
8
+ jupyter-lsp==2.2.0
9
+ jupyter_client==8.3.0
10
+ jupyter_core==5.3.1
11
+ jupyter_server==2.7.0
12
+ jupyter_server_terminals==0.4.4
13
+ jupyterlab==4.0.4
14
+ jupyterlab-pygments==0.2.2
15
+ jupyterlab-widgets==3.0.9
16
+ jupyterlab_server==2.24.0
17
+ langchain==0.0.310
18
+ langcodes==3.3.0
19
+ langsmith==0.0.43
20
+ llama-hub==0.0.47post1
21
+ llama-index==0.9.6.post1
22
+ loguru==0.7.0
23
+ matplotlib==3.7.2
24
+ matplotlib-inline==0.1.6
25
+ numpy==1.24.4
26
+ openai==1.3.5
27
+ pandas==2.0.3
28
+ protobuf==4.23.4
29
+ pyarrow==12.0.1
30
+ python-dotenv==1.0.0
31
+ rank-bm25==0.2.2
32
+ requests==2.31.0
33
+ requests-oauthlib==1.3.1
34
+ rich==13.7.0
35
+ sentence-transformers==2.2.2
36
+ streamlit==1.28.2
37
+ tiktoken==0.5.1
38
+ tokenizers==0.13.3
39
+ torch==2.0.1
40
+ tqdm==4.66.1
41
+ transformers==4.33.1
42
+ weaviate-client==3.25.3
43
+ polars>=0.19
44
+ plotly
45
+ angle-emb==0.1.5 # for UAE-Large-V1 model
46
+ streamlit-option-menu==0.3.6
47
+ hydralit_components==1.0.10
48
+ pathlib
49
+ gdown
50
+ modal
reranker.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder
2
+ from torch.nn import Sigmoid
3
+ from typing import List, Union
4
+ import numpy as np
5
+ from loguru import logger
6
+
7
+ class ReRanker(CrossEncoder):
8
+ '''
9
+ Cross-Encoder models achieve higher performance than Bi-Encoders,
10
+ however, they do not scale well to large datasets. The lack of scalability
11
+ is due to the underlying cross-attention mechanism, which is computationally
12
+ expensive. Thus a Bi-Encoder is best used for 1st-stage document retrieval and
13
+ a Cross-Encoder is used to re-rank the retrieved documents.
14
+
15
+ https://www.sbert.net/examples/applications/cross-encoder/README.html
16
+ '''
17
+
18
+ def __init__(self,
19
+ model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2',
20
+ **kwargs
21
+ ):
22
+ super().__init__(model_name=model_name,
23
+ **kwargs)
24
+ self.model_name = model_name
25
+ self.score_field = 'cross_score'
26
+ self.activation_fct = Sigmoid()
27
+
28
+ def _cross_encoder_score(self,
29
+ results: List[dict],
30
+ query: str,
31
+ hit_field: str='content',
32
+ apply_sigmoid: bool=True,
33
+ return_scores: bool=False
34
+ ) -> Union[np.array, None]:
35
+ '''
36
+ Given a list of hits from a Retriever:
37
+ 1. Scores hits by passing query and results through CrossEncoder model.
38
+ 2. Adds cross-score key to results dictionary.
39
+ 3. If desired returns np.array of Cross Encoder scores.
40
+ '''
41
+ activation_fct = self.activation_fct if apply_sigmoid else None
42
+ #build query/content list
43
+ cross_inp = [[query, hit[hit_field]] for hit in results]
44
+ #get scores
45
+ cross_scores = self.predict(cross_inp, activation_fct=activation_fct)
46
+ for i, result in enumerate(results):
47
+ result[self.score_field]=cross_scores[i]
48
+
49
+ if return_scores: return cross_scores
50
+
51
+ def rerank(self,
52
+ results: List[dict],
53
+ query: str,
54
+ top_k: int=10,
55
+ apply_sigmoid: bool=True,
56
+ threshold: float=None
57
+ ) -> List[dict]:
58
+ '''
59
+ Given a list of hits from a Retriever:
60
+ 1. Scores hits by passing query and results through CrossEncoder model.
61
+ 2. Adds cross_score key to results dictionary.
62
+ 3. Returns reranked results limited by either a threshold value or top_k.
63
+
64
+ Args:
65
+ -----
66
+ results : List[dict]
67
+ List of results from the Weaviate client
68
+ query : str
69
+ User query
70
+ top_k : int=10
71
+ Number of results to return
72
+ apply_sigmoid : bool=True
73
+ Whether to apply sigmoid activation to cross-encoder scores. If False,
74
+ returns raw cross-encoder scores (logits).
75
+ threshold : float=None
76
+ Minimum cross-encoder score to return. If no hits are above threshold,
77
+ returns top_k hits.
78
+ '''
79
+ # Sort results by the cross-encoder scores
80
+ self._cross_encoder_score(results=results, query=query, apply_sigmoid=apply_sigmoid)
81
+
82
+ sorted_hits = sorted(results, key=lambda x: x[self.score_field], reverse=True)
83
+ if threshold is not None:
84
+ filtered_hits = [hit for hit in sorted_hits if hit[self.score_field] >= threshold]
85
+ if not any(filtered_hits):
86
+ logger.warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
87
+ return sorted_hits[:top_k]
88
+ return filtered_hits
89
+ return sorted_hits[:top_k]
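A minimal two-stage retrieval-plus-rerank sketch using the classes in this repo; the class name, query, and credentials are placeholders for values from your own Weaviate cluster:

```
import os
from weaviate_interface import WeaviateClient
from reranker import ReRanker

client = WeaviateClient(api_key=os.environ['WEAVIATE_API_KEY'],
                        endpoint=os.environ['WEAVIATE_ENDPOINT'])
reranker = ReRanker()  # cross-encoder/ms-marco-MiniLM-L-6-v2 by default

query = 'How does the guest define personal power?'
hits = client.hybrid_search(request=query, class_name='Impact_theory_minilm_256', limit=100)

# Re-score the first-stage hits with the cross-encoder and keep the top 5.
top_hits = reranker.rerank(results=hits, query=query, top_k=5)
for hit in top_hits:
    print(round(float(hit['cross_score']), 3), hit['doc_id'])
```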
retrieval_evaluation.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #external files
2
+ from openai_interface import GPT_Turbo
3
+ from weaviate_interface import WeaviateClient
4
+ from llama_index.finetuning import EmbeddingQAFinetuneDataset
5
+ from prompt_templates import qa_generation_prompt
6
+ from reranker import ReRanker
7
+
8
+ #standard library imports
9
+ import json
10
+ import time
11
+ import uuid
12
+ import os
13
+ import re
14
+ import random
15
+ from datetime import datetime
16
+ from typing import List, Dict, Tuple, Union, Literal
17
+
18
+ #misc
19
+ from tqdm import tqdm
20
+
21
+
22
+ class QueryContextGenerator:
23
+ '''
24
+ Class designed for the generation of query/context pairs using a
25
+ Generative LLM. The LLM is used to generate questions from a given
26
+ corpus of text. The query/context pairs can be used to fine-tune
27
+ an embedding model using a MultipleNegativesRankingLoss loss function
28
+ or can be used to create evaluation datasets for retrieval models.
29
+ '''
30
+ def __init__(self, openai_key: str, model_id: str='gpt-3.5-turbo-0613'):
31
+ self.llm = GPT_Turbo(model=model_id, api_key=openai_key)
32
+
33
+ def clean_validate_data(self,
34
+ data: List[dict],
35
+ valid_fields: List[str]=['content', 'summary', 'guest', 'doc_id'],
36
+ total_chars: int=950
37
+ ) -> List[dict]:
38
+ '''
39
+ Strip original data chunks so they only contain valid_fields.
40
+ Remove any chunks less than total_chars in size. Prevents LLM
41
+ from asking questions from sparse content.
42
+ '''
43
+ clean_docs = [{k:v for k,v in d.items() if k in valid_fields} for d in data]
44
+ valid_docs = [d for d in clean_docs if len(d['content']) > total_chars]
45
+ return valid_docs
46
+
47
+ def train_val_split(self,
48
+ data: List[dict],
49
+ n_train_questions: int,
50
+ n_val_questions: int,
51
+ n_questions_per_chunk: int=2,
52
+ total_chars: int=950):
53
+ '''
54
+ Splits corpus into training and validation sets. Training and
55
+ validation samples are randomly selected from the corpus. total_chars
56
+ parameter is set based on pre-analysis of average doc length in the
57
+ training corpus.
58
+ '''
59
+ clean_data = self.clean_validate_data(data, total_chars=total_chars)
60
+ random.shuffle(clean_data)
61
+ train_index = n_train_questions//n_questions_per_chunk
62
+ valid_index = n_val_questions//n_questions_per_chunk
63
+ end_index = valid_index + train_index
64
+ if end_index > len(clean_data):
65
+ raise ValueError('Cannot create dataset with desired number of questions, try using a larger dataset')
66
+ train_data = clean_data[:train_index]
67
+ valid_data = clean_data[train_index:end_index]
68
+ print(f'Length Training Data: {len(train_data)}')
69
+ print(f'Length Validation Data: {len(valid_data)}')
70
+ return train_data, valid_data
71
+
72
+ def generate_qa_embedding_pairs(
73
+ self,
74
+ data: List[dict],
75
+ generate_prompt_tmpl: str=None,
76
+ num_questions_per_chunk: int = 2,
77
+ ) -> EmbeddingQAFinetuneDataset:
78
+ """
79
+ Generate query/context pairs from a list of documents. The query/context pairs
80
+ can be used for fine-tuning an embedding model using a MultipleNegativesRankingLoss
81
+ or can be used to create an evaluation dataset for retrieval models.
82
+
83
+ This function was adapted for this course from the llama_index.finetuning.common module:
84
+ https://github.com/run-llama/llama_index/blob/main/llama_index/finetuning/embeddings/common.py
85
+ """
86
+ generate_prompt_tmpl = qa_generation_prompt if not generate_prompt_tmpl else generate_prompt_tmpl
87
+ queries = {}
88
+ relevant_docs = {}
89
+ corpus = {chunk['doc_id'] : chunk['content'] for chunk in data}
90
+ for chunk in tqdm(data):
91
+ summary = chunk['summary']
92
+ guest = chunk['guest']
93
+ transcript = chunk['content']
94
+ node_id = chunk['doc_id']
95
+ query = generate_prompt_tmpl.format(summary=summary,
96
+ guest=guest,
97
+ transcript=transcript,
98
+ num_questions_per_chunk=num_questions_per_chunk)
99
+ try:
100
+ response = self.llm.get_chat_completion(prompt=query, temperature=0.1, max_tokens=100)
101
+ except Exception as e:
102
+ print(e)
103
+ continue
104
+ result = str(response).strip().split("\n")
105
+ questions = [
106
+ re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
107
+ ]
108
+ questions = [question for question in questions if len(question) > 0]
109
+
110
+ for question in questions:
111
+ question_id = str(uuid.uuid4())
112
+ queries[question_id] = question
113
+ relevant_docs[question_id] = [node_id]
114
+
115
+ # construct dataset
116
+ return EmbeddingQAFinetuneDataset(
117
+ queries=queries, corpus=corpus, relevant_docs=relevant_docs
118
+ )
119
+
120
+ def execute_evaluation(dataset: EmbeddingQAFinetuneDataset,
121
+ class_name: str,
122
+ retriever: WeaviateClient,
123
+ reranker: ReRanker=None,
124
+ alpha: float=0.5,
125
+ retrieve_limit: int=100,
126
+ top_k: int=5,
127
+ chunk_size: int=256,
128
+ hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
129
+ search_type: Literal['kw', 'vector', 'hybrid', 'all']='all',
130
+ display_properties: List[str]=['doc_id', 'content'],
131
+ dir_outpath: str='./eval_results',
132
+ include_miss_info: bool=False,
133
+ user_def_params: dict=None
134
+ ) -> Union[dict, Tuple[dict, List[dict]]]:
135
+ '''
136
+ Given a dataset, a retriever, and a reranker, evaluate the performance of the retriever and reranker.
137
+ Returns a dict of kw, vector, and hybrid hit rates and MRR scores. If include_miss_info is True, will
138
+ also return a list of kw and vector responses and their associated queries that did not return a hit.
139
+
140
+ Args:
141
+ -----
142
+ dataset: EmbeddingQAFinetuneDataset
143
+ Dataset to be used for evaluation
144
+ class_name: str
145
+ Name of Class on Weaviate host to be used for retrieval
146
+ retriever: WeaviateClient
147
+ WeaviateClient object to be used for retrieval
148
+ reranker: ReRanker
149
+ ReRanker model to be used for results reranking
150
+ alpha: float=0.5
151
+ Weighting factor for BM25 and Vector search.
152
+ alpha can be any number from 0 to 1, defaulting to 0.5:
153
+ alpha = 0 executes a pure keyword search method (BM25)
154
+ alpha = 0.5 weighs the BM25 and vector methods evenly
155
+ alpha = 1 executes a pure vector search method
156
+ retrieve_limit: int=100
157
+ Number of documents to retrieve from Weaviate host
158
+ top_k: int=5
159
+ Number of top results to evaluate
160
+ chunk_size: int=256
161
+ Number of tokens used to chunk text
162
+ hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef']
163
+ List of keys to be used for retrieving HNSW Index parameters from Weaviate host
164
+ search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
165
+ Type of search to be evaluated. Options are 'kw', 'vector', 'hybrid', or 'all'
166
+ display_properties: List[str]=['doc_id', 'content']
167
+ List of properties to be returned from Weaviate host for display in response
168
+ dir_outpath: str='./eval_results'
169
+ Directory path for saving results. Directory will be created if it does not
170
+ already exist.
171
+ include_miss_info: bool=False
172
+ Option to include queries and their associated search response values
173
+ for queries that are "total misses"
174
+ user_def_params : dict=None
175
+ Option for user to pass in a dictionary of user-defined parameters and their values.
176
+ Will be automatically added to the results_dict if correct type is passed.
177
+ '''
178
+
179
+ reranker_name = reranker.model_name if reranker else "None"
180
+
181
+ results_dict = {'n':retrieve_limit,
182
+ 'top_k': top_k,
183
+ 'alpha': alpha,
184
+ 'Retriever': retriever.model_name_or_path,
185
+ 'Ranker': reranker_name,
186
+ 'chunk_size': chunk_size,
187
+ 'kw_hit_rate': 0,
188
+ 'kw_mrr': 0,
189
+ 'vector_hit_rate': 0,
190
+ 'vector_mrr': 0,
191
+ 'hybrid_hit_rate':0,
192
+ 'hybrid_mrr': 0,
193
+ 'total_misses': 0,
194
+ 'total_questions':0
195
+ }
196
+ #add extra params to results_dict
197
+ results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
198
+
199
+ start = time.perf_counter()
200
+ miss_info = []
201
+ for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
202
+ results_dict['total_questions'] += 1
203
+ hit = False
204
+ #make Keyword, Vector, and Hybrid calls to Weaviate host
205
+ try:
206
+ kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
207
+ vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
208
+ hybrid_response = retriever.hybrid_search(request=q, class_name=class_name, alpha=alpha, limit=retrieve_limit, display_properties=display_properties)
209
+ #rerank returned responses if reranker is provided
210
+ if reranker:
211
+ kw_response = reranker.rerank(kw_response, q, top_k=top_k)
212
+ vector_response = reranker.rerank(vector_response, q, top_k=top_k)
213
+ hybrid_response = reranker.rerank(hybrid_response, q, top_k=top_k)
214
+
215
+ #collect doc_ids to check for document matches (include only results_top_k)
216
+ kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response[:top_k], 1)}
217
+ vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response[:top_k], 1)}
218
+ hybrid_doc_ids = {result['doc_id']:i for i, result in enumerate(hybrid_response[:top_k], 1)}
219
+
220
+ #extract doc_id for scoring purposes
221
+ doc_id = dataset.relevant_docs[query_id][0]
222
+
223
+ #increment hit_rate counters and mrr scores
224
+ if doc_id in kw_doc_ids:
225
+ results_dict['kw_hit_rate'] += 1
226
+ results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
227
+ hit = True
228
+ if doc_id in vector_doc_ids:
229
+ results_dict['vector_hit_rate'] += 1
230
+ results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
231
+ hit = True
232
+ if doc_id in hybrid_doc_ids:
233
+ results_dict['hybrid_hit_rate'] += 1
234
+ results_dict['hybrid_mrr'] += 1/hybrid_doc_ids[doc_id]
235
+ hit = True
236
+ # if no hits, let's capture that
237
+ if not hit:
238
+ results_dict['total_misses'] += 1
239
+ miss_info.append({'query': q,
240
+ 'answer': dataset.corpus[doc_id],
241
+ 'doc_id': doc_id,
242
+ 'kw_response': kw_response,
243
+ 'vector_response': vector_response,
244
+ 'hybrid_response': hybrid_response})
245
+ except Exception as e:
246
+ print(e)
247
+ continue
248
+
249
+ #use raw counts to calculate final scores
250
+ calc_hit_rate_scores(results_dict, search_type=search_type)
251
+ calc_mrr_scores(results_dict, search_type=search_type)
252
+
253
+ end = time.perf_counter() - start
254
+ print(f'Total Processing Time: {round(end/60, 2)} minutes')
255
+ record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
256
+
257
+ if include_miss_info:
258
+ return results_dict, miss_info
259
+ return results_dict
260
+
261
+ def calc_hit_rate_scores(results_dict: Dict[str, Union[str, int]],
262
+ search_type: Union[Literal['kw', 'vector', 'hybrid', 'all'], List[str]]=['kw', 'vector']
263
+ ) -> None:
264
+ if isinstance(search_type, str):
265
+ search_type = ['kw', 'vector', 'hybrid'] if search_type == 'all' else [search_type]
266
+ for prefix in search_type:
267
+ results_dict[f'{prefix}_hit_rate'] = round(results_dict[f'{prefix}_hit_rate']/results_dict['total_questions'],2)
268
+
269
+ def calc_mrr_scores(results_dict: Dict[str, Union[str, int]],
270
+ search_type: Union[Literal['kw', 'vector', 'hybrid', 'all'], List[str]]=['kw', 'vector']
271
+ ) -> None:
272
+ if isinstance(search_type, str):
273
+ search_type = ['kw', 'vector', 'hybrid'] if search_type == 'all' else [search_type]
274
+ for prefix in search_type:
275
+ results_dict[f'{prefix}_mrr'] = round(results_dict[f'{prefix}_mrr']/results_dict['total_questions'],2)
276
+
277
+ def create_dir(dir_path: str) -> None:
278
+ '''
279
+ Checks if directory exists, and creates new directory
280
+ if it does not exist
281
+ '''
282
+ if not os.path.exists(dir_path):
283
+ os.makedirs(dir_path)
284
+
285
+ def record_results(results_dict: Dict[str, Union[str, int]],
286
+ chunk_size: int,
287
+ dir_outpath: str='./eval_results',
288
+ as_text: bool=False
289
+ ) -> None:
290
+ '''
291
+ Write results to output file in either txt or json format
292
+
293
+ Args:
294
+ -----
295
+ results_dict: Dict[str, Union[str, int]]
296
+ Dictionary containing results of evaluation
297
+ chunk_size: int
298
+ Size of text chunks in tokens
299
+ dir_outpath: str
300
+ Path to output directory. Directory only, filename is hardcoded
301
+ as part of this function.
302
+ as_text: bool
303
+ If True, write results as text file. If False, write as json file.
304
+ '''
305
+ create_dir(dir_outpath)
306
+ time_marker = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
307
+ ext = 'txt' if as_text else 'json'
308
+ path = os.path.join(dir_outpath, f'retrieval_eval_{chunk_size}_{time_marker}.{ext}')
309
+ if as_text:
310
+ with open(path, 'a') as f:
311
+ f.write(f"{results_dict}\n")
312
+ else:
313
+ with open(path, 'w') as f:
314
+ json.dump(results_dict, f, indent=4)
315
+
316
+ def add_params(client: WeaviateClient,
317
+ class_name: str,
318
+ results_dict: dict,
319
+ param_options: dict,
320
+ hnsw_config_keys: List[str]
321
+ ) -> dict:
322
+ hnsw_params = {k:v for k,v in client.show_class_config(class_name)['vectorIndexConfig'].items() if k in hnsw_config_keys}
323
+ if hnsw_params:
324
+ results_dict = {**results_dict, **hnsw_params}
325
+ if param_options and isinstance(param_options, dict):
326
+ results_dict = {**results_dict, **param_options}
327
+ return results_dict
328
+
329
+
330
+
331
+
332
+
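Putting the pieces above together, a hedged end-to-end sketch of an evaluation run; the dataset path, class name, and credentials are placeholders:

```
import os
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface import WeaviateClient
from reranker import ReRanker
from retrieval_evaluation import execute_evaluation

dataset = EmbeddingQAFinetuneDataset.from_json('./data/golden_100.json')  # placeholder path
retriever = WeaviateClient(api_key=os.environ['WEAVIATE_API_KEY'],
                           endpoint=os.environ['WEAVIATE_ENDPOINT'])

results = execute_evaluation(dataset=dataset,
                             class_name='Impact_theory_minilm_256',       # placeholder class
                             retriever=retriever,
                             reranker=ReRanker(),
                             retrieve_limit=100,
                             top_k=5)
print(results['kw_hit_rate'], results['vector_hit_rate'], results['hybrid_hit_rate'])
```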
unitesting_utils.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import urllib.request
3
+
4
+ def load_impact_theory_data():
5
+ '''
6
+ Loads impact_theory_data.json data by trying three options:
7
+ 1. Assumes user is in Google Colab environment and loads file from content dir.
8
+ 2. If 1st option doesn't work, assumes user is in course repo and loads from data dir.
9
+ 3. If 2nd option doesn't work, assumes user does not have direct access to data so
10
+ downloads data direct from course repo.
11
+ '''
12
+ try:
13
+ path = '/content/impact_theory_data.json'
14
+ with open(path) as f:
15
+ data = json.load(f)
16
+ return data
17
+ except Exception:
18
+ print(f"Data not available at {path}")
19
+ try:
20
+ path = './data/impact_theory_data.json'
21
+ with open(path) as f:
22
+ data = json.load(f)
23
+ print(f'OK, data available at {path}')
24
+ return data
25
+ except Exception:
26
+ print(f'Data not available at {path}, downloading from source')
27
+ try:
28
+ with urllib.request.urlopen("https://raw.githubusercontent.com/americanthinker/vectorsearch-applications/main/data/impact_theory_data.json") as url:
29
+ data = json.load(url)
30
+ return data
31
+ except Exception:
32
+ print('Data cannot be loaded from source, please move data file to one of these paths to run this test:\n\
33
+ 1. "/content/impact_theory_data.json" --> if you are in Google Colab\n\
34
+ 2. "./data/impact_theory_data.json" --> if you are in a local environment\n')
utilities/install_kernel.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ echo "Installing Jupyter kernel named $1 with display name $2"
4
+ ipython kernel install --name "$1" --user --display-name "$2"
weaviate_interface.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from weaviate import Client, AuthApiKey
2
+ from dataclasses import dataclass
3
+ from openai import OpenAI
4
+ from sentence_transformers import SentenceTransformer
5
+ from typing import List, Union, Callable
6
+ from torch import cuda
7
+ from tqdm import tqdm
8
+ import time
9
+
10
+ class WeaviateClient(Client):
11
+ '''
12
+ A python native Weaviate Client class that encapsulates Weaviate functionalities
13
+ in one object. Several convenience methods are added for ease of use.
14
+
15
+ Args
16
+ ----
17
+ api_key: str
18
+ The API key for the Weaviate Cloud Service (WCS) instance.
19
+ https://console.weaviate.cloud/dashboard
20
+
21
+ endpoint: str
22
+ The url endpoint for the Weaviate Cloud Service instance.
23
+
24
+ model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2'
25
+ The name or path of the SentenceTransformer model to use for vector search.
26
+ Will also support OpenAI text-embedding-ada-002 model. This param enables
27
+ the use of most leading models on MTEB Leaderboard:
28
+ https://huggingface.co/spaces/mteb/leaderboard
29
+ openai_api_key: str=None
30
+ The API key for the OpenAI API. Only required if using OpenAI text-embedding-ada-002 model.
31
+ '''
32
+ def __init__(self,
33
+ api_key: str,
34
+ endpoint: str,
35
+ model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
36
+ openai_api_key: str=None,
37
+ **kwargs
38
+ ):
39
+ auth_config = AuthApiKey(api_key=api_key)
40
+ super().__init__(auth_client_secret=auth_config,
41
+ url=endpoint,
42
+ **kwargs)
43
+ self.model_name_or_path = model_name_or_path
44
+ self.openai_model = False
45
+ if self.model_name_or_path == 'text-embedding-ada-002':
46
+ if not openai_api_key:
47
+ raise ValueError(f'OpenAI API key must be provided to use this model: {self.model_name_or_path}')
48
+ self.model = OpenAI(api_key=openai_api_key)
49
+ self.openai_model = True
50
+ else:
51
+ self.model = SentenceTransformer(self.model_name_or_path) if self.model_name_or_path else None
52
+
53
+ self.display_properties = ['title', 'video_id', 'length', 'thumbnail_url', 'views', 'episode_url', \
54
+ 'doc_id', 'guest', 'content'] # 'playlist_id', 'channel_id', 'author'
55
+
56
+ def show_classes(self) -> Union[List[dict], str]:
57
+ '''
58
+ Shows all available classes (indexes) on the Weaviate instance.
59
+ '''
60
+ classes = self.cluster.get_nodes_status()[0]['shards']
61
+ if classes:
62
+ return [d['class'] for d in classes]
63
+ else:
64
+ return "No classes found on cluster."
65
+
66
+ def show_class_info(self) -> Union[List[dict], str]:
67
+ '''
68
+ Shows all information related to the classes (indexes) on the Weaviate instance.
69
+ '''
70
+ classes = self.cluster.get_nodes_status()[0]['shards']
71
+ if classes:
72
+ return [d for d in classes]
73
+ else:
74
+ return "No classes found on cluster."
75
+
76
+ def show_class_properties(self, class_name: str) -> Union[dict, str]:
77
+ '''
78
+ Shows all properties of a class (index) on the Weaviate instance.
79
+ '''
80
+ classes = self.schema.get()
81
+ if classes:
82
+ all_classes = classes['classes']
83
+ for d in all_classes:
84
+ if d['class'] == class_name:
85
+ return d['properties']
86
+ return f'Class "{class_name}" not found on host'
87
+ return f'No Classes found on host'
88
+
89
+ def show_class_config(self, class_name: str) -> Union[dict, str]:
90
+ '''
91
+ Shows all configuration of a class (index) on the Weaviate instance.
92
+ '''
93
+ classes = self.schema.get()
94
+ if classes:
95
+ all_classes = classes['classes']
96
+ for d in all_classes:
97
+ if d['class'] == class_name:
98
+ return d
99
+ return f'Class "{class_name}" not found on host'
100
+ return f'No Classes found on host'
101
+
102
+ def delete_class(self, class_name: str) -> str:
103
+ '''
104
+ Deletes a class (index) on the Weaviate instance, if it exists.
105
+ '''
106
+ available = self._check_class_availability(class_name)
107
+ if isinstance(available, bool):
108
+ if available:
109
+ self.schema.delete_class(class_name)
110
+ not_deleted = self._check_class_availability(class_name)
111
+ if isinstance(not_deleted, bool):
112
+ if not_deleted:
113
+ return f'Class "{class_name}" was not deleted. Try again.'
114
+ else:
115
+ return f'Class "{class_name}" deleted'
116
+ return f'Class "{class_name}" deleted and there are no longer any classes on host'
117
+ return f'Class "{class_name}" not found on host'
118
+ return available
119
+
120
+ def _check_class_availability(self, class_name: str) -> Union[bool, str]:
121
+ '''
122
+ Checks if a class (index) exists on the Weaviate instance.
123
+ '''
124
+ classes = self.schema.get()
125
+ if classes:
126
+ all_classes = classes['classes']
127
+ for d in all_classes:
128
+ if d['class'] == class_name:
129
+ return True
130
+ return False
131
+ else:
132
+ return f'No Classes found on host'
133
+
134
+ def format_response(self,
135
+ response: dict,
136
+ class_name: str
137
+ ) -> List[dict]:
138
+ '''
139
+ Formats json response from Weaviate into a list of dictionaries.
140
+ Expands _additional fields if present into top-level dictionary.
141
+ '''
142
+ if response.get('errors'):
143
+ return response['errors'][0]['message']
144
+ results = []
145
+ hits = response['data']['Get'][class_name]
146
+ for d in hits:
147
+ temp = {k:v for k,v in d.items() if k != '_additional'}
148
+ if d.get('_additional'):
149
+ for key in d['_additional']:
150
+ temp[key] = d['_additional'][key]
151
+ results.append(temp)
152
+ return results
153
+
154
+ def update_ef_value(self, class_name: str, ef_value: int) -> str:
155
+ '''
156
+ Updates ef_value for a class (index) on the Weaviate instance.
157
+ '''
158
+ self.schema.update_config(class_name=class_name, config={'vectorIndexConfig': {'ef': ef_value}})
159
+ print(f'ef_value updated to {ef_value} for class {class_name}')
160
+ return self.show_class_config(class_name)['vectorIndexConfig']
161
+
162
+ def keyword_search(self,
163
+ request: str,
164
+ class_name: str,
165
+ properties: List[str]=['content'],
166
+ limit: int=10,
167
+ where_filter: dict=None,
168
+ display_properties: List[str]=None,
169
+ return_raw: bool=False) -> Union[dict, List[dict]]:
170
+ '''
171
+ Executes Keyword (BM25) search.
172
+
173
+ Args
174
+ ----
175
+ query: str
176
+ User query.
177
+ class_name: str
178
+ Class (index) to search.
179
+ properties: List[str]
180
+ List of properties to search across.
181
+ limit: int=10
182
+ Number of results to return.
183
+ display_properties: List[str]=None
184
+ List of properties to return in response.
185
+ If None, returns all properties.
186
+ return_raw: bool=False
187
+ If True, returns raw response from Weaviate.
188
+ '''
189
+ display_properties = display_properties if display_properties else self.display_properties
190
+ response = (self.query
191
+ .get(class_name, display_properties)
192
+ .with_bm25(query=request, properties=properties)
193
+ .with_additional(['score', "id"])
194
+ .with_limit(limit)
195
+ )
196
+ response = response.with_where(where_filter).do() if where_filter else response.do()
197
+ if return_raw:
198
+ return response
199
+ else:
200
+ return self.format_response(response, class_name)
201
+
202
+ def vector_search(self,
203
+ request: str,
204
+ class_name: str,
205
+ limit: int=10,
206
+ where_filter: dict=None,
207
+ display_properties: List[str]=None,
208
+ return_raw: bool=False,
209
+ device: str='cuda:0' if cuda.is_available() else 'cpu'
210
+ ) -> Union[dict, List[dict]]:
211
+ '''
212
+ Executes vector search using embedding model defined on instantiation
213
+ of WeaviateClient instance.
214
+
215
+ Args
216
+ ----
217
+ query: str
218
+ User query.
219
+ class_name: str
220
+ Class (index) to search.
221
+ limit: int=10
222
+ Number of results to return.
223
+ display_properties: List[str]=None
224
+ List of properties to return in response.
225
+ If None, returns all properties.
226
+ return_raw: bool=False
227
+ If True, returns raw response from Weaviate.
228
+ '''
229
+ display_properties = display_properties if display_properties else self.display_properties
230
+ query_vector = self._create_query_vector(request, device=device)
231
+ response = (
232
+ self.query
233
+ .get(class_name, display_properties)
234
+ .with_near_vector({"vector": query_vector})
235
+ .with_limit(limit)
236
+ .with_additional(['distance'])
237
+ )
238
+ response = response.with_where(where_filter).do() if where_filter else response.do()
239
+ if return_raw:
240
+ return response
241
+ else:
242
+ return self.format_response(response, class_name)
243
+
244
+ def _create_query_vector(self, query: str, device: str) -> List[float]:
245
+ '''
246
+ Creates embedding vector from text query.
247
+ '''
248
+ return self.get_openai_embedding(query) if self.openai_model else self.model.encode(query, device=device).tolist()
249
+
250
+ def get_openai_embedding(self, query: str) -> List[float]:
251
+ '''
252
+ Gets embedding from OpenAI API for query.
253
+ '''
254
+ embedding = self.model.embeddings.create(input=query, model='text-embedding-ada-002').model_dump()
255
+ if embedding:
256
+ return embedding['data'][0]['embedding']
257
+ else:
258
+ raise ValueError(f'No embedding found for query: {query}')
259
+
260
+ def hybrid_search(self,
261
+ request: str,
262
+ class_name: str,
263
+ properties: List[str]=['content'],
264
+ alpha: float=0.5,
265
+ limit: int=10,
266
+ where_filter: dict=None,
267
+ display_properties: List[str]=None,
268
+ return_raw: bool=False,
269
+ device: str='cuda:0' if cuda.is_available() else 'cpu'
270
+ ) -> Union[dict, List[dict]]:
271
+ '''
272
+ Executes Hybrid (BM25 + Vector) search.
273
+
274
+ Args
275
+ ----
276
+ query: str
277
+ User query.
278
+ class_name: str
279
+ Class (index) to search.
280
+ properties: List[str]
281
+ List of properties to search across (using BM25)
282
+ alpha: float=0.5
283
+ Weighting factor for BM25 and Vector search.
284
+ alpha can be any number from 0 to 1, defaulting to 0.5:
285
+ alpha = 0 executes a pure keyword search method (BM25)
286
+ alpha = 0.5 weighs the BM25 and vector methods evenly
287
+ alpha = 1 executes a pure vector search method
288
+ limit: int=10
289
+ Number of results to return.
290
+ display_properties: List[str]=None
291
+ List of properties to return in response.
292
+ If None, returns all properties.
293
+ return_raw: bool=False
294
+ If True, returns raw response from Weaviate.
295
+ '''
296
+ display_properties = display_properties if display_properties else self.display_properties
297
+ query_vector = self._create_query_vector(request, device=device)
298
+ response = (
299
+ self.query
300
+ .get(class_name, display_properties)
301
+ .with_hybrid(query=request,
302
+ alpha=alpha,
303
+ vector=query_vector,
304
+ properties=properties,
305
+ fusion_type='relativeScoreFusion') #hard coded option for now
306
+ .with_additional(["score", "explainScore"])
307
+ .with_limit(limit)
308
+ )
309
+
310
+ response = response.with_where(where_filter).do() if where_filter else response.do()
311
+ if return_raw:
312
+ return response
313
+ else:
314
+ return self.format_response(response, class_name)
315
+
316
+
317
+ class WeaviateIndexer:
318
+
319
+ def __init__(self,
320
+ client: WeaviateClient,
321
+ batch_size: int=150,
322
+ num_workers: int=4,
323
+ dynamic: bool=True,
324
+ creation_time: int=5,
325
+ timeout_retries: int=3,
326
+ connection_error_retries: int=3,
327
+ callback: Callable=None,
328
+ ):
329
+ '''
330
+ Class designed to batch index documents into Weaviate. Instantiating
331
+ this class will automatically configure the Weaviate batch client.
332
+ '''
333
+ self._client = client
334
+ self._callback = callback if callback else self._default_callback
335
+
336
+ self._client.batch.configure(batch_size=batch_size,
337
+ num_workers=num_workers,
338
+ dynamic=dynamic,
339
+ creation_time=creation_time,
340
+ timeout_retries=timeout_retries,
341
+ connection_error_retries=connection_error_retries,
342
+ callback=self._callback
343
+ )
344
+
345
+ def _default_callback(self, results: dict):
346
+ """
347
+ Check batch results for errors.
348
+
349
+ Parameters
350
+ ----------
351
+ results : dict
352
+ The Weaviate batch creation return value.
353
+ """
354
+
355
+ if results is not None:
356
+ for result in results:
357
+ if "result" in result and "errors" in result["result"]:
358
+ if "error" in result["result"]["errors"]:
359
+ print(result["result"])
360
+
361
+ def batch_index_data(self,
362
+ data: List[dict],
363
+ class_name: str,
364
+ vector_property: str='content_embedding'
365
+ ) -> None:
366
+ '''
367
+ Batch function for fast indexing of data onto Weaviate cluster.
368
+ This method assumes that self._client.batch is already configured.
369
+ '''
370
+ start = time.perf_counter()
371
+ with self._client.batch as batch:
372
+ for d in tqdm(data):
373
+
374
+ #define single document
375
+ properties = {k:v for k,v in d.items() if k != vector_property}
376
+ try:
377
+ #add data object to batch
378
+ batch.add_data_object(
379
+ data_object=properties,
380
+ class_name=class_name,
381
+ vector=d[vector_property]
382
+ )
383
+ except Exception as e:
384
+ print(e)
385
+ continue
386
+
387
+ end = time.perf_counter() - start
388
+
389
+ print(f'Batch job completed in {round(end/60, 2)} minutes.')
390
+ class_info = self._client.show_class_info()
391
+ for i, c in enumerate(class_info):
392
+ if c['class'] == class_name:
393
+ print(class_info[i])
394
+ self._client.batch.shutdown()
395
+
396
+ @dataclass
397
+ class WhereFilter:
398
+
399
+ '''
400
+ Simplified interface for constructing a WhereFilter object.
401
+
402
+ Args
403
+ ----
404
+ path: List[str]
405
+ List of properties to filter on.
406
+ operator: str
407
+ Operator to use for filtering. Options: ['And', 'Or', 'Equal', 'NotEqual',
408
+ 'GreaterThan', 'GreaterThanEqual', 'LessThan', 'LessThanEqual', 'Like',
409
+ 'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
410
+ value[dataType]: Union[int, bool, str, float, datetime]
411
+ Value to filter on. The dataType suffix must match the data type of the
412
+ property being filtered on. At least and only one value type must be provided.
413
+ '''
414
+ path: List[str]
415
+ operator: str
416
+ valueInt: int=None
417
+ valueBoolean: bool=None
418
+ valueText: str=None
419
+ valueNumber: float=None
420
+ valueDate: str=None
421
+
422
+ def __post_init__(self):
423
+ operators = ['And', 'Or', 'Equal', 'NotEqual','GreaterThan', 'GreaterThanEqual', 'LessThan',\
424
+ 'LessThanEqual', 'Like', 'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
425
+ if self.operator not in operators:
426
+ raise ValueError(f'operator must be one of: {operators}, got {self.operator}')
427
+ values = [self.valueInt, self.valueBoolean, self.valueText, self.valueNumber, self.valueDate]
428
+ if all(v is None for v in values):
429
+ raise ValueError('At least one value must be provided.')
430
+ if len([v for v in values if v is not None]) > 1:
431
+ raise ValueError('At most one value can be provided.')
432
+
433
+ def todict(self):
434
+ return {k:v for k,v in self.__dict__.items() if v is not None}
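Finally, a hedged sketch of a filtered keyword search through the classes above; the guest name, class name, and credentials are placeholders:

```
import os
from weaviate_interface import WeaviateClient, WhereFilter

client = WeaviateClient(api_key=os.environ['WEAVIATE_API_KEY'],
                        endpoint=os.environ['WEAVIATE_ENDPOINT'])

# Restrict the BM25 search to a single guest using the WhereFilter dataclass.
guest_filter = WhereFilter(path=['guest'], operator='Equal', valueText='Guest Name').todict()

hits = client.keyword_search(request='What is the role of consciousness?',
                             class_name='Impact_theory_minilm_256',  # placeholder class
                             where_filter=guest_filter,
                             limit=5)
for hit in hits:
    print(hit['guest'], '-', hit['title'])
```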