# Spaces: Sleeping
# (Hugging Face Spaces page-status residue captured with the file; not part of the program.)
#%% | |
from tiktoken import get_encoding, encoding_for_model | |
from weaviate_interface import WeaviateClient, WhereFilter | |
from sentence_transformers import SentenceTransformer | |
from prompt_templates import question_answering_prompt_series, question_answering_system | |
from openai_interface import GPT_Turbo | |
from app_features import (convert_seconds, generate_prompt_series, search_result, | |
validate_token_threshold, load_content_cache, load_data, | |
expand_content) | |
from retrieval_evaluation import execute_evaluation, calc_hit_rate_scores | |
from llama_index.finetuning import EmbeddingQAFinetuneDataset | |
from weaviate_interface import WeaviateClient | |
from openai import BadRequestError | |
from reranker import ReRanker | |
from loguru import logger | |
import streamlit as st | |
from streamlit_option_menu import option_menu | |
import hydralit_components as hc | |
import sys | |
import json | |
import os, time, requests, re | |
from datetime import timedelta | |
import pathlib | |
import gdown | |
import tempfile | |
import base64 | |
import shutil | |
def get_base64_of_bin_file(bin_file):
    """Read ``bin_file`` in binary mode and return its content as a base64 text string."""
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode()
from dotenv import load_dotenv, find_dotenv | |
# Environment variables from the 'env' file override any already set in the process.
load_dotenv(find_dotenv('env'), override=True)

# Widget key of the question text_input: bumping it forces Streamlit to create
# a fresh (empty) input box on the next render.
if 'key' not in st.session_state:
    st.session_state.key = 0

# Local folder where embedding models are stored.
if not pathlib.Path('models').exists():
    os.mkdir('models')

# TODO: these should be cached.
# A marker file models/local.txt exists only on the developer's machine;
# its absence means the app is running online.
we_are_not_online = pathlib.Path("models/local.txt").exists()
we_are_online = not we_are_not_online

golden_dataset = EmbeddingQAFinetuneDataset.from_json("data/golden_100.json")
# shutil.rmtree("models/models")  # was used to free disk space on streamlit online

## PAGE CONFIGURATION
st.set_page_config(
    page_title="Ask Impact Theory",
    page_icon="assets/impact-theory-logo-only.png",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items={'Report a bug': "https://www.extremelycoolapp.com/bug"},
)
image = "https://is2-ssl.mzstatic.com/image/thumb/Music122/v4/bd/34/82/bd348260-314c-5898-26c0-bef2e0388ebe/source/1200x1200bb.png"
def add_bg_from_local(image_file):
    """Set a local image file as the fixed, full-width background of the app.

    The file is embedded as a base64 PNG data URI in a CSS rule targeting
    ``.stApp`` and injected into the page with ``st.markdown``.
    """
    encoded_image = get_base64_of_bin_file(image_file)
    background_css = f'''
<style>
.stApp {{
background-image: url("data:image/png;base64,{encoded_image}");
background-size: 100% auto;
background-repeat: no-repeat;
background-attachment: fixed;
}}
</style>
'''
    st.markdown(background_css, unsafe_allow_html=True)
# COMMENT: I tried to create a dropdown menu but it's harder than it looks, so I gave up | |
# https://discuss.streamlit.io/t/streamlit-option-menu-is-a-simple-streamlit-component-that-allows-users-to-select-a-single-item-from-a-list-of-options-in-a-menu/20514 | |
# not great, but it works | |
# selected = option_menu("About", ["Improvements","This"], #"Main Menu", ["Home", 'Settings'], | |
# icons=['house', 'gear'], | |
# menu_icon="cast", | |
# default_index=1) | |
# # Custom HTML/CSS for the banner | |
# base64_img = get_base64_of_bin_file("assets/it_tom_bilyeu.png") | |
# banner_menu_html = f""" | |
# <div class="banner"> | |
# <img src= "data:image/png;base64,{base64_img}" alt="Banner Image"> | |
# </div> | |
# <style> | |
# .banner {{ | |
# width: 100%; | |
# height: auto; | |
# overflow: hidden; | |
# display: flex; | |
# justify-content: center; | |
# }} | |
# .banner img {{ | |
# width: 130%; | |
# height: auto; | |
# object-fit: contain; | |
# }} | |
# </style> | |
# """ | |
# st.components.v1.html(banner_menu_html) | |
# specify the primary menu definition | |
# it gives a vertical menu inside a navigation bar !!! | |
# menu_data = [ | |
# {'icon': "far fa-copy", 'label':"Left End"}, | |
# {'id':'Copy','icon':"🐙",'label':"Copy"}, | |
# {'icon': "far fa-chart-bar", 'label':"Chart"},#no tooltip message | |
# {'icon': "far fa-address-book", 'label':"Book"}, | |
# {'id':' Crazy return value 💀','icon': "💀", 'label':"Calendar"}, | |
# {'icon': "far fa-clone", 'label':"Component"}, | |
# {'icon': "fas fa-tachometer-alt", 'label':"Dashboard",'ttip':"I'm the Dashboard tooltip!"}, #can add a tooltip message | |
# {'icon': "far fa-copy", 'label':"Right End"}, | |
# ] | |
# # we can override any part of the primary colors of the menu | |
# over_theme = {'txc_inactive': '#FFFFFF','menu_background':'red','txc_active':'yellow','option_active':'blue'} | |
# # over_theme = {'txc_inactive': '#FFFFFF'} | |
# menu_id = hc.nav_bar(menu_definition=menu_data, | |
# home_name='Home', | |
# override_theme=over_theme) | |
#get the id of the menu item clicked | |
# st.info(f"{menu_id=}") | |
## RERANKER (cross-encoder used to re-score the hybrid search hits)
reranker = ReRanker('cross-encoder/ms-marco-MiniLM-L-6-v2')

## ENCODING --> tiktoken library
# The second entry is the chat model used by the app; the tokenizer must match it.
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
model_nameGPT = model_ids[1]
encoding = encoding_for_model(model_nameGPT)
# = get_encoding('gpt-3.5-turbo-0613')

##############
## DATA
data_path = './data/impact_theory_data.json'
cache_path = 'data/impact_theory_cache.parquet'
data = load_data(data_path)
# The content cache is currently disabled: load_content_cache(cache_path)
cache = None
# Credentials: read them from Streamlit secrets (secrets.toml) when available,
# otherwise fall back to environment variables (load_dotenv above fills them
# from the 'env' file).
# Both paths bind the SAME names: main() reads `Wapi_key`, so the fallback
# must define it too (it used to define only `api_key`, causing a NameError).
try:
    _secrets = st.secrets['secrets']
    Wapi_key = _secrets['WEAVIATE_API_KEY']
    url = _secrets['WEAVIATE_ENDPOINT']
    openai_api_key = _secrets['OPENAI_API_KEY']
    hf_token = _secrets['LLAMA2_ENDPOINT_HF_TOKEN_chris']
    hf_endpoint = _secrets['LLAMA2_ENDPOINT_UPLIMIT']
except Exception:
    # Missing secrets file or missing key. `Exception` (not a bare except)
    # lets SystemExit / KeyboardInterrupt propagate.
    st.write("Loading secrets from environment variables")
    Wapi_key = os.environ['WEAVIATE_API_KEY']
    url = os.environ['WEAVIATE_ENDPOINT']
    openai_api_key = os.environ['OPENAI_API_KEY']
    hf_token = os.environ['LLAMA2_ENDPOINT_HF_TOKEN_chris']
    hf_endpoint = os.environ['LLAMA2_ENDPOINT_UPLIMIT']
# Alias kept for code that used the name bound by the former fallback branch.
api_key = Wapi_key
#%% | |
# Embedding models selectable in the sidebar.
available_models = ['sentence-transformers/all-mpnet-base-v2',
                    'sentence-transformers/all-MiniLM-L6-v2',
                    'models/finetuned-all-mpnet-base-v2-300']

# Default model: the finetuned one when running locally, the stock
# sentence-transformers model when running online.
if we_are_not_online:
    model_default = 'models/finetuned-all-mpnet-base-v2-300'
else:
    model_default = 'sentence-transformers/all-mpnet-base-v2'
#%%
# Google Drive folder of each model that is not hosted on Hugging Face.
models_urls = {'models/finetuned-all-mpnet-base-v2-300': "https://drive.google.com/drive/folders/1asJ37-AUv5nytLtH6hp6_bVV3_cZOXfj"}
def download_model_from_Gdrive(model_name_or_path, model_full_path):
    """Download the Google Drive folder registered for a model.

    Args:
        model_name_or_path (str): key looked up in ``models_urls``.
        model_full_path (str): output location passed to ``gdown.download_folder``.

    Raises:
        AssertionError: if ``model_name_or_path`` is not a key of ``models_urls``.
    """
    announcement = "Downloading model from Google Drive"
    print(announcement)
    st.write(announcement)

    assert model_name_or_path in models_urls, f"Model {model_name_or_path} not found in models_urls"
    gdown.download_folder(models_urls[model_name_or_path],
                          output=model_full_path,
                          quiet=False,
                          use_cookies=False)
    print("Model downloaded and saved to models folder")
def download_model(model_name_or_path, model_full_path):
    """Fetch a model into ``model_full_path`` according to its name prefix.

    * ``models/...``: downloaded from Google Drive.
    * ``sentence-transformers/...``: loaded with ``SentenceTransformer`` and
      saved to ``model_full_path``.

    Any other prefix is silently ignored. In both handled cases the entry of
    ``models_urls`` for the model is set to ``model_full_path``.
    """
    if model_name_or_path.startswith("models/"):
        download_model_from_Gdrive(model_name_or_path, model_full_path)
        print(f"Model {model_full_path} downloaded")
        # NOTE(review): this replaces the Google Drive URL with a local path, so a
        # later Drive download of the same model in this process would get a path
        # instead of a URL - confirm this is intended.
        models_urls[model_name_or_path] = model_full_path
        return

    if model_name_or_path.startswith("sentence-transformers/"):
        st.sidebar.write(f"Downloading Sentence Transformer model {model_name_or_path}")
        # Hugging Face resolves the name through its own models folder/path
        st_model = SentenceTransformer(model_name_or_path)
        models_urls[model_name_or_path] = model_full_path
        st_model.save(model_full_path)
# if 'modelspath' not in st.session_state: | |
# st.session_state['modelspath'] = None | |
# if st.session_state.modelspath is None: | |
# # let's create a temp folder on the first run | |
# persistent_dir = pathlib.Path("path/to/persistent_dir") | |
# persistent_dir.mkdir(parents=True, exist_ok=True) | |
# with tempfile.TemporaryDirectory() as temp_dir: | |
# st.session_state.modelspath = temp_dir | |
# print(f"Temporary directory created at {temp_dir}") | |
# # the temp folder disappears with the context, but not the one we've created manually | |
# else: | |
# temp_dir = st.session_state.modelspath | |
# print(f"Temporary directory already exists at {temp_dir}") | |
# # st.write(os.listdir(temp_dir)) | |
#%% | |
# for streamlit online, we must download the model from google drive | |
# because github LFS doesn't work on forked repos | |
def check_model(model_name_or_path):
    """Make sure the model is available locally and return its local path.

    The local path is ``models/<model_name_or_path>`` with the
    ``sentence-transformers/`` prefix replaced by ``models/``, so every model
    ends up under ``models/models``. When running online, the other model
    folders are removed because disk space is limited there.

    Args:
        model_name_or_path (str): entry selected from the available models.

    Returns:
        str: local path of the model.
    """
    nested_path = pathlib.Path("models") / pathlib.Path(model_name_or_path)
    # all models are saved in the same models/models folder
    model_full_path = str(nested_path).replace("sentence-transformers/", "models/")

    if not pathlib.Path(model_full_path).exists():
        if we_are_online:
            # limited space on streamlit online: leave nothing behind
            # and download the model again every time
            print("Deleting models/models folder")
            if pathlib.Path('models/models').exists():
                shutil.rmtree("models/models")  # make room, if other models are there
        download_model(model_name_or_path, model_full_path)
        return model_full_path

    # the model is already there, use it
    print(f"Model {model_full_path} already exists")
    if we_are_online:
        # streamlit online shuts the app down when the disk is full,
        # so remove every other model folder
        other_entries = os.listdir("models/models")
        # listdir gives folder names only, not full paths
        other_entries.remove(model_full_path.split('/')[-1])
        for entry in other_entries:
            candidate = pathlib.Path("models/models") / entry
            if candidate.is_dir():
                shutil.rmtree(candidate)
    return model_full_path
#%% instantiate Weaviate client | |
def get_weaviate_client(api_key, url, model_name_or_path, openai_api_key):
    """Create the Weaviate client and list the classes it exposes.

    'summary' is appended to the client's ``display_properties`` and the
    sorted class list is logged.

    Returns:
        tuple: ``(client, available_classes)`` where ``available_classes`` is
        the sorted result of ``client.show_classes()``.
    """
    weaviate_client = WeaviateClient(
        api_key,
        url,
        model_name_or_path=model_name_or_path,
        openai_api_key=openai_api_key,
    )
    weaviate_client.display_properties.append('summary')

    class_names = sorted(weaviate_client.show_classes())
    logger.info(class_names)
    return weaviate_client, class_names
############## | |
# data = load_data(data_path) | |
# Unique guest names, alphabetically sorted, for the guest selection box
guest_list = sorted({episode['guest'] for episode in data})
def _stream_rag_answer(llm, prompt, temperature, response_box, chat_container):
    """Stream the chat completion for ``prompt`` into ``response_box``.

    Each received text fragment is appended to ``chat_container`` and the
    placeholder is re-rendered with the text accumulated so far.
    """
    for resp in llm.get_chat_completion(prompt=prompt,
                                        temperature=temperature,
                                        max_tokens=350,  # expand for more verbose answers
                                        show_response=True,
                                        stream=True):
        if not resp.choices:
            # skip chunks that carry no choice
            continue
        content = resp.choices[0].delta.content
        if content:
            with response_box:
                chat_container.append(content)
                result = "".join(chat_container).strip()
                response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")


def main():
    """Render the app.

    The sidebar holds the search settings (alpha, number of hits, guest
    filter, reranker, RAG, embedding model, Weaviate class, metrics on the
    golden set). The main area asks for a guest and a question, rewrites the
    question so that it names the guest, runs the hybrid search (optionally
    reranked), streams the LLM answer and lists the search results.
    """
    with st.sidebar:
        # the guest selection box was moved to the main area
        _, center, _ = st.columns([3, 5, 3])
        with center:
            st.text("Search Lab")

        _, center, _ = st.columns([2, 5, 3])
        with center:
            if we_are_online:
                st.text("Running ONLINE")
                st.text("(UNSTABLE)")
            else:
                st.text("Running OFFLINE")
        st.write("----------")

        alpha_input = st.slider(label='Alpha', min_value=0.00, max_value=1.00, value=0.40, step=0.05)
        retrieval_limit = st.slider(label='Hybrid Search Results', min_value=10, max_value=300, value=10, step=10)
        hybrid_filter = st.toggle('Filter Guest', True)  # i.e. look only at guests' data

        rerank = st.toggle('Use Reranker', True)
        if rerank:
            reranker_topk = st.slider(label='Reranker Top K', min_value=1, max_value=5, value=3, step=1)
        else:
            # needed to not fill the LLM with too many responses (> context size)
            # we could make it dependent on the model
            reranker_topk = 3

        rag_it = st.toggle('RAG it', True)
        if rag_it:
            st.sidebar.write(f"Using LLM '{model_nameGPT}'")
            llm_temperature = st.slider(label='LLM T˚', min_value=0.0, max_value=2.0, value=0.01, step=0.10)

        model_name_or_path = st.selectbox(label='Model Name:', options=available_models,
                                          index=available_models.index(model_default),
                                          placeholder='Select Model')

        st.write("Experimental and time limited 2'")
        finetune_model = st.toggle('Finetune on Modal A100 GPU', False)
        if finetune_model:
            from finetune_backend import finetune
            if 'finetuned' in model_name_or_path:
                st.write("Model already finetuned")
            elif model_name_or_path.startswith("models/"):
                st.write("Sentence Transformers models only!")
            else:
                # `finetune_model` is the boolean toggle: the model name is
                # `model_name_or_path` (the former code tested the toggle itself,
                # which always raised and skipped the finetuning).
                try:
                    model_path = finetune(model_name_or_path, savemodel=True, outpath='models')
                except Exception:
                    logger.exception("Finetuning failed")
                    st.write("Model not found on HF or error")
                else:
                    if model_path is not None:
                        if model_name_or_path.split('/')[-1] not in model_path:
                            st.write(model_path)  # a warning from finetuning in this case
                        elif model_path not in available_models:
                            # finetuning generated a model, let's add it
                            available_models.append(model_path)
                            st.write("Model saved!")

        model_name_or_path = check_model(model_name_or_path)
        client, available_classes = get_weaviate_client(Wapi_key, url, model_name_or_path, openai_api_key)

        start_class = 'Impact_theory_all_mpnet_base_v2_finetuned'
        # fall back to the first class when the preferred one is not on the server
        start_index = available_classes.index(start_class) if start_class in available_classes else 0
        class_name = st.selectbox(
            label='Class Name:',
            options=available_classes,
            index=start_index,
            placeholder='Select Class Name'
        )
        st.write("----------")

        c1, c2 = st.columns([8, 1])
        with c1:
            show_metrics = st.toggle('Show Metrics on Golden set', False)
        if show_metrics:
            with c2:
                with st.spinner(''):
                    metrics = execute_evaluation(golden_dataset, class_name, client, alpha=alpha_input)
            st.text(f"KW hit rate: {metrics['kw_hit_rate']}")
            st.text(f"Vector hit rate: {metrics['vector_hit_rate']}")
            st.text(f"Hybrid hit rate: {metrics['hybrid_hit_rate']}")
            # the value shown is the vector MRR (it used to be labelled "Hybrid MRR")
            st.text(f"Vector MRR: {metrics['vector_mrr']}")
            st.text(f"Total misses: {metrics['total_misses']}")
            st.write("----------")

    st.title("Chat with the Impact Theory podcasts!")
    st.image('assets/it_tom_bilyeu.png', use_column_width=True)
    st.write('\n')
    st.write("\u21D0 Open the sidebar to change Search settings \n ")  # https://home.unicode.org also 21E0, 21B0 B2 D0

    guest = st.selectbox('Select A Guest',
                         options=guest_list,
                         index=None,
                         placeholder='Select Guest')

    # at start the query is empty: 'error' status and no hits prevent the display below
    reworded_query = {'changed': False, 'status': 'error'}
    valid_response = []

    col1, col2 = st.columns([7, 3])
    with col1:
        if guest is None:
            msg = 'Select a guest before asking your question:'
        else:
            msg = f'Enter your question about {guest}:'

        textbox = st.empty()
        # a keyed text_input is the way found to change the text inside the box afterwards
        query = textbox.text_input(msg,
                                   value="",
                                   placeholder="You can refer to the guest with pronoun or drop the question mark",
                                   key=st.session_state.key)
        st.write('\n\n\n\n\n')

        if query:
            if guest is None:
                # a new key gives a fresh, empty input box showing the warning
                st.session_state.key += 1
                query = textbox.text_input(msg,
                                           value="",
                                           placeholder="YOU MUST SELECT A GUEST BEFORE ASKING A QUESTION",
                                           key=st.session_state.key)
                st.stop()
            else:
                with col2:
                    # pulse bar displayed while the response is being generated
                    with hc.HyLoader('', hc.Loaders.pulse_bars, primary_color='red', height=50):  # "#0e404d" for image green
                        with col1:
                            # Llama2 rewrites the question so that it names the guest
                            reworded_query = reword_query(query, guest,
                                                          model_name='llama2-13b-chat')
                            # keep the user's own wording if no rewritten question came back
                            query = reworded_query.get('rewritten_question') or query

                            # we can arrive here only if a guest was selected
                            where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
                                if hybrid_filter else None

                            hybrid_response = client.hybrid_search(query,
                                                                   class_name,
                                                                   alpha=alpha_input,
                                                                   display_properties=client.display_properties,
                                                                   where_filter=where_filter,
                                                                   limit=retrieval_limit)
                            response = hybrid_response

                            if rerank:
                                # rerank results with cross encoder
                                ranked_response = reranker.rerank(response, query,
                                                                  apply_sigmoid=True,  # score between 0 and 1
                                                                  top_k=reranker_topk)
                                logger.info(ranked_response)
                                expanded_response = expand_content(ranked_response, cache,
                                                                   content_key='doc_id',
                                                                   create_new_list=True)
                                response = expanded_response

                            # make sure token count < threshold
                            token_threshold = 8000 if model_nameGPT == model_ids[0] else 3500
                            valid_response = validate_token_threshold(response,
                                                                      question_answering_prompt_series,
                                                                      query=query,
                                                                      tokenizer=encoding,  # variable from ENCODING
                                                                      token_threshold=token_threshold,
                                                                      verbose=True)

    # we are out of col1 here to use the full page width, so the query is tested again
    if query is not None and reworded_query['status'] != 'error':
        show_query = st.toggle('Show rewritten query', False)
        if show_query:
            st.write(f"Rewritten query: {query}")

        # container for the LLM response, positioned above the search results
        chat_container, response_box = [], st.empty()

        # RAG time: execute the chat call to the LLM
        if rag_it:
            prompt = generate_prompt_series(query=query, results=valid_response)
            # module-level key: it also covers the environment-variable fallback
            GPTllm = GPT_Turbo(model=model_nameGPT,
                               api_key=openai_api_key)
            try:
                _stream_rag_answer(GPTllm, prompt, llm_temperature, response_box, chat_container)
            except BadRequestError:
                logger.info('Making request with smaller context')
                valid_response = validate_token_threshold(response,
                                                          question_answering_prompt_series,
                                                          query=query,
                                                          tokenizer=encoding,
                                                          token_threshold=3500,
                                                          verbose=True)
                # if reranker is off, we may receive a LOT of responses
                # so we must reduce the context size manually
                if not rerank:
                    valid_response = valid_response[:reranker_topk]
                prompt = generate_prompt_series(query=query, results=valid_response)
                _stream_rag_answer(GPTllm, prompt, llm_temperature, response_box, chat_container)

        st.markdown("----")
        st.subheader("Search Results")
        for i, hit in enumerate(valid_response):
            col1, col2 = st.columns([7, 3], gap='large')
            image = hit['thumbnail_url']
            episode_url = hit['episode_url']
            title = hit["title"]
            show_length = hit["length"]
            time_string = str(timedelta(seconds=show_length))  # readable episode duration

            with col1:
                st.write(search_result(i=i,
                                       url=episode_url,
                                       guest=hit['guest'],
                                       title=title,
                                       content='',
                                       length=time_string),
                         unsafe_allow_html=True)
                st.write('\n\n')
            with col2:
                st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
            # let's use all width for the content
            st.write(hit['content'])
def get_answer(query, valid_response, GPTllm):
    """Return the non-streamed chat completion of ``GPTllm`` for the query.

    The prompt is generated from ``query`` and the search results in
    ``valid_response``; the result of ``GPTllm.get_chat_completion`` is
    returned unchanged.
    """
    completion_settings = {
        'system_message': 'answer this question based on the podcast material',
        'temperature': 0,
        'max_tokens': 500,
        'stream': False,
        'show_response': False,
    }
    # generate LLM prompt
    llm_prompt = generate_prompt_series(query=query, results=valid_response)
    return GPTllm.get_chat_completion(prompt=llm_prompt, **completion_settings)
def reword_query(query, guest, model_name='llama2-13b-chat', response_processing=True):
    """Ask an LLM to rewrite the query when the guest name is missing.

    With ``model_name='llama2-13b-chat'`` the Llama2 endpoint is called; on any
    failure (network error, non-JSON body, error payload, no usable rewritten
    question) the function falls back to OpenAI. Any other ``model_name``
    uses OpenAI 'gpt-3.5-turbo-0613' directly.

    Args:
        query (str): user query
        guest (str): guest name
        model_name (str, optional): name of a LLM model to be used
        response_processing (bool, optional): Llama2 only. When False, the
            decoded endpoint response is returned as is (unless it is an
            error payload, which triggers the OpenAI fallback).

    Returns:
        dict: contains at least 'rewritten_question' and 'status'. 'status'
        is 'success', or 'not success' when OpenAI failed, in which case
        'rewritten_question' is the original query.
        With ``response_processing=False`` the raw endpoint response instead.
    """
    fallback_model = 'gpt-3.5-turbo-0613'

    # NOTE(review): in "question" the user query is inserted where the text talks
    # about the guest name ("ie {query}") - wording left untouched, confirm intent.
    prompt_fields = {
        "you_are":f"You are an expert in linguistics and semantics, analyzing the question asked by a user to a vector search system, \
and making sure that the question is well formulated and that the system can understand it.",
        "your_task":f"Your task is to detect if the name of the guest ({guest}) is mentioned in the user's question, \
and if that is not the case, rewrite the question using the guest name, \
without changing the meaning of the question. \
Most of the time, the user will have used a pronoun to designate the guest, in which case, \
simply replace the pronoun with the guest name.",
        "question":f"If the user mentions the guest name, ie {query}, just return his question as is. \
If the user does not mention the guest name, rewrite the question using the guest name.",
        "final_instruction":f"Only regerate the requested rewritten question or the original, WITHOUT ANY COMMENT OR REPHRASING. \
Your answer must be as close as possible to the original question, \
and exactly identical, word for word, if the user mentions the guest name, i.e. {guest}.",
    }

    # prompt created by chatGPT :-)
    # and Llama still outputs the original question and precedes the answer with 'rewritten question'
    prompt_fields2 = {
        "you_are": (
            "You are an expert in linguistics and semantics. Your role is to analyze questions asked to a vector search system."
        ),
        "your_task": (
            f"Detect if the guest's FULL name, {guest}, is mentioned in the user's question. "
            "If not, rewrite the question by replacing pronouns or indirect references with the guest's name." \
            "If yes, return the original question as is, without any change at all, not even punctuation,"
            "except a question mark that you MUST add if it's missing."
        ),
        "question": (
            f"Original question: '{query}'. "
            "Rewrite this question to include the guest's FULL name if it's not already mentioned."
            "The Only thing you can and MUST add is a question mark if it's missing."
        ),
        "final_instruction": (
            "Create a rewritten question or keep the original question as is. "
            "Do not include any labels, titles, or additional text before or after the question."
            "The Only thing you can and MUST add is a question mark if it's missing."
            "Return a json object, with the key 'original_question' for the original question, \
and 'rewritten_question' for the rewritten question \
and 'changed' being True if you changed the answer, otherwise False."
        ),
    }

    if model_name == 'llama2-13b-chat':
        # special tags are used:
        # `<s>` - start prompt tag
        # `[INST], [/INST]` - Opening and closing model instruction tags
        # `<<SYS>>, <</SYS>>` - Opening and closing system prompt tags
        llama_prompt = """
<s>[INST] <<SYS>>
{you_are}
<</SYS>>
{your_task}\n
```
\n\n
Question: {question}\n
{final_instruction} [/INST]
Answer:
"""
        prompt = llama_prompt.format(**prompt_fields2)

        # hf_token / hf_endpoint are the module-level credentials, which also
        # cover the environment-variable fallback
        headers = {"Authorization": f"Bearer {hf_token}",
                   "Content-Type": "application/json",}
        json_body = {
            "inputs": prompt,
            "parameters": {"max_new_tokens":400,
                           "repetition_penalty": 1.0,
                           "temperature":0.01}
        }
        try:
            # the timeout prevents the app from hanging on an unresponsive endpoint
            http_response = requests.request("POST", hf_endpoint, headers=headers,
                                             data=json.dumps(json_body), timeout=120)
            response = json.loads(http_response.content.decode("utf-8"))
        except (requests.RequestException, ValueError) as exc:
            # network failure, or a body that is not valid UTF-8 JSON
            print(f"Llama2 endpoint call failed: {exc}")
            return reword_query(query, guest, model_name=fallback_model)

        if isinstance(response, dict) and 'error' in response:
            print("Found error")
            print(response)
            return reword_query(query, guest, model_name=fallback_model)

        if not response_processing:
            return response

        if (isinstance(response, list) and response
                and isinstance(response[0], dict) and 'generated_text' in response[0]):
            print("Found generated text")
            response0 = response[0]['generated_text']
            # json.loads cannot process the badly formatted generated text as a
            # whole, so the "key": value pairs are extracted one by one
            pattern = r'\"(\w+)\":\s*(\".*?\"|\w+)'
            matches = re.findall(pattern, response0)
            try:
                result = {key: json.loads(value) if value.startswith("\"") else value
                          for key, value in matches}
            except ValueError:
                # a quoted value that is not a valid JSON string
                result = {}
            # callers read 'rewritten_question': only report success when it is there
            if result.get('rewritten_question'):
                return result | {'status': 'success'}

        print("Found no answer")
        return reword_query(query, guest, model_name=fallback_model)

    else:
        # assume openai
        model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
        model_name = model_ids[1]
        GPTllm = GPT_Turbo(model=model_name,
                           api_key=openai_api_key)

        openai_prompt = """
{your_task}\n
```
\n\n
Question: {question}\n
{final_instruction}
Answer:
"""
        prompt = openai_prompt.format(**prompt_fields)
        not_rewritten = {'rewritten_question': query, 'changed': False, 'status': 'not success'}

        try:
            # send the formatted prompt (the raw template used to be sent)
            resp = GPTllm.get_chat_completion(prompt=prompt,
                                              system_message=prompt_fields['you_are'],
                                              temperature=0.01,
                                              max_tokens=1500,  # it's a question...
                                              show_response=True,
                                              stream=False)
            # a non-streamed completion carries its text in `message`, not `delta`
            rewritten = resp.choices[0].message.content
        except Exception:
            logger.exception("OpenAI query rewriting failed")
            return not_rewritten

        if not rewritten or not rewritten.strip():
            return not_rewritten
        return {'rewritten_question': rewritten.strip(),
                'changed': True, 'status': 'success'}
# Run the app when the file is executed as a script
if __name__ == "__main__":
    main()
# %% | |