Spaces:
Running
Running
import streamlit as st | |
from chat_client import chat | |
import time | |
import pandas as pd | |
import os | |
from dotenv import load_dotenv | |
from search_client import SearchClient | |
import math | |
import numpy as np | |
from sentence_transformers import CrossEncoder | |
# Load environment variables (python-dotenv) before the credentials below are
# read from the process environment.
load_dotenv()

# Search API credentials. os.getenv returns None when a variable is unset, so
# a missing key only surfaces when the corresponding search client is used.
GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Per-1000-token rate used by the cost displays in the sidebar and under each
# response.
# NOTE(review): the name says INR, yet the call sites multiply it by 80 / 82
# (which looks like a USD->INR rate) — confirm the intended unit.
COST_PER_1000_TOKENS_INR = 0.139

# Display name shown in the sidebar -> model identifier handed to chat().
CHAT_BOTS = {
    "Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Mistral 7B v0.1": "mistralai/Mistral-7B-Instruct-v0.1",
}
# Default prompt-engineering texts. They seed st.session_state in init_state()
# and pre-fill the widgets of the prompt engineering dashboard.
#   SYSTEM_INSTRUCTION / SYSTEM_RESPONSE: the first [instruction, response]
#       pair of the chat history.
#   PRE_CONTEXT / POST_CONTEXT: text placed before / after the retrieved web
#       context in the augmented prompt.
#   PRE_PROMPT / POST_PROMPT: text placed before / after the user's prompt.
# NOTE(review): these strings are sent to the model verbatim. They appear to
# contain typos ("anempathetic", "but offers", "RETRIEVED FROM THE GENERATE");
# confirm before editing, since any change alters model input.
INITIAL_PROMPT_ENGINEERING = {
    "SYSTEM_INSTRUCTION": """ You are a knowledgeable author on medical conditions, with a deep expertise in Huntington's disease.
You provide extensive, clear information on complex medical topics, treatments, new research and developments.
You avoid giving personal medical advice or diagnoses but offers general advice and underscores the importance of consulting healthcare professionals.
Your goal is to inform, engage and enlighten users that enquire about Huntington's disease, offering factual data and real-life perspectives with anempathetic tone.
You use every search available including web search together with articles and information from
* Journal of Huntington's disease,
* Movement Disorders,
* Neurology,
* Journal of Neurology,
* Neurosurgery & Psychiatry,
* HDBuzz,
* PubMed,
* Huntington's disease Society of America (HDSA),
* Huntington Study Group (HSG),
* Nature Reviews Neurology
* ScienceDirect
The information you provide should be understandable to laypersons, well-organized, and include credible sources, citations, and an empathetic tone.
It should educate on the scientific aspects and personal challenges of living with Huntington's Disease.""",
    "SYSTEM_RESPONSE": """Hello! I'm an assistant trained to provide detailed and accurate information on medical conditions, including Huntington's Disease.
I'm here to help answer your questions and provide resources to help you better understand this disease and its impact on individuals and their families.
If you have any questions about HD or related topics, feel free to ask!""",
    "PRE_CONTEXT": """NOW YOU ARE SEARCHING THE WEB, AND HERE ARE THE CHUNKS RETRIEVED FROM THE WEB.""",
    "POST_CONTEXT": """ """,  # EMPTY
    "PRE_PROMPT": """NOW ACCORDING TO THE CONTEXT RETRIEVED FROM THE GENERATE THE CONTENT FOR THE FOLLOWING SUBJECT""",
    "POST_PROMPT": """
Do not repeat yourself
""",
}
# One search client per supported vendor; gen_augmented_prompt_via_websearch
# chooses between them by vendor name.
googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)

# Page configuration is issued before any other Streamlit call in this script,
# including the cached model load below.
st.set_page_config(
    page_title="Mixtral Playground",
    page_icon="📚",
)


@st.cache_resource
def _load_reranker():
    """Build the cross-encoder once and share it across script reruns.

    Streamlit re-executes the script on every interaction; without caching the
    model would be constructed again on each rerun.
    """
    return CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


# Cross-encoder used by rerank() to score (query, chunk) pairs.
reranker = _load_reranker()
def rerank(query, top_k, search_results):
    """Return up to ``top_k`` search results, most relevant to ``query`` first.

    Each result's text is split into chunks of at most 512 whitespace-separated
    words. Every (query, chunk) pair is scored with the module-level
    cross-encoder ``reranker`` and results are ordered by their best-scoring
    chunk.

    Args:
        query: The user query the chunks are scored against.
        top_k: Maximum number of results to return.
        search_results: Sequence of dicts with at least ``"text"`` and
            ``"link"`` keys (links are assumed hashable, e.g. URL strings).

    Returns:
        A list of the original result dicts, best first. Each link appears at
        most once, so a long page with several high-scoring chunks cannot fill
        the whole ``top_k``. Results whose text contains no words are dropped,
        and an empty list is returned when there is nothing to score.
    """
    chunk_size = 512

    # (link, chunk_text) pairs for every chunk of every result.
    chunks = []
    for result in search_results:
        words = result["text"].split()
        for start in range(0, len(words), chunk_size):
            chunk_text = " ".join(words[start:start + chunk_size])
            chunks.append((result["link"], chunk_text))

    # Nothing to score: do not call the model with an empty batch.
    if not chunks:
        return []

    similarity_scores = reranker.predict([[query, text] for _, text in chunks])

    # First result seen for each link, so a link resolves in O(1).
    result_by_link = {}
    for result in search_results:
        result_by_link.setdefault(result["link"], result)

    reranked_results = []
    seen_links = set()
    # argsort is ascending; walk it backwards for highest score first.
    for idx in np.argsort(similarity_scores)[::-1]:
        link = chunks[idx][0]
        if link in seen_links:
            continue
        seen_links.add(link)
        reranked_results.append(result_by_link[link])

    return reranked_results[:top_k]
def gen_augmented_prompt_via_websearch(
    prompt,
    vendor,
    n_crawl,
    top_k,
    pre_context,
    post_context,
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
):
    """Build a prompt augmented with reranked web-search results.

    Args:
        prompt: The user's prompt; also used as the search query.
        vendor: ``"Google"`` or ``"Bing"``. Any other value performs no search.
        n_crawl: Forwarded to the search client's ``search`` call.
        top_k: Maximum number of reranked results placed in the context.
        pre_context: Text placed before the retrieved context.
        post_context: Text placed after the retrieved context.
        pre_prompt: Text placed before the user's prompt.
        post_prompt: Text placed after the user's prompt.
        pass_prev: When true, the last response stored in
            ``st.session_state.history`` is appended at the end.

    Returns:
        A ``(generated_prompt, links)`` tuple, where ``links`` lists the link
        of every result whose text was included in the context.
    """
    if vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)
    else:
        search_results = []

    # An unknown vendor or an empty crawl leaves nothing to score, so the
    # reranker model is not invoked at all.
    reranked_results = rerank(prompt, top_k, search_results) if search_results else []

    links = [res["link"] for res in reranked_results]
    context = "".join(res["text"] + "\n\n" for res in reranked_results)

    # Debug trace of the selected results on stdout.
    print(reranked_results)

    prev_input = st.session_state.history[-1][1] if pass_prev else ""

    # Layout: pre-context, context, post-context, pre-prompt, prompt,
    # post-prompt, previous response — one section per line.
    generated_prompt = (
        f"\n{pre_context}\n{context}\n{post_context}\n{pre_prompt}\n"
        f"{prompt} \n\n\n{post_prompt}\n{prev_input}\n"
    )
    return generated_prompt, links
def init_state():
    """Seed ``st.session_state`` with a default for every key the app uses.

    Keys that already exist are left untouched, so values survive reruns. The
    defaults are rebuilt on each call so mutable values (``messages``,
    ``history``) are never shared between sessions.
    """
    defaults = {
        "messages": [],
        "tokens_used": 0,
        "tps": 0,
        "temp": 0.8,
        # History starts with the system instruction / response exchange.
        "history": [
            [
                INITIAL_PROMPT_ENGINEERING["SYSTEM_INSTRUCTION"],
                INITIAL_PROMPT_ENGINEERING["SYSTEM_RESPONSE"],
            ]
        ],
        "n_crawl": 5,
        # Key spelling matches the one the sidebar writes to.
        "repetion_penalty": 1,
        "rag_enabled": True,
        "chat_bot": "Mixtral 8x7B v0.1",
        "search_vendor": "Bing",
        "system_instruction": INITIAL_PROMPT_ENGINEERING["SYSTEM_INSTRUCTION"],
        # Previously this branch overwrote system_instruction and never set
        # system_response.
        "system_response": INITIAL_PROMPT_ENGINEERING["SYSTEM_RESPONSE"],
        "pre_context": INITIAL_PROMPT_ENGINEERING["PRE_CONTEXT"],
        "post_context": INITIAL_PROMPT_ENGINEERING["POST_CONTEXT"],
        "pre_prompt": INITIAL_PROMPT_ENGINEERING["PRE_PROMPT"],
        "post_prompt": INITIAL_PROMPT_ENGINEERING["POST_PROMPT"],
        "pass_prev": False,
        # Same starting values as the sidebar sliders, so these keys exist
        # even before the sidebar has been rendered.
        "top_k": 4,
        "max_tokens": 512,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def sidebar():
    """Draw the sidebar and copy each widget's value into ``st.session_state``.

    Inside ``st.sidebar`` three sections are rendered in order — web-retrieval
    controls, usage analytics and model settings — followed by a credit line.
    """

    def retrieval_settings():
        # Retrieval switch plus the controls that are greyed out while it is off.
        st.markdown("# Web Retrieval")
        rag_on = st.toggle("Activate Web Retrieval", value=True)
        st.session_state.rag_enabled = rag_on
        controls_off = not rag_on

        st.session_state.search_vendor = st.radio(
            "Select Search Vendor",
            ["Bing", "Google"],
            disabled=controls_off,
        )
        st.session_state.n_crawl = st.slider(
            label="Links to Crawl",
            key=1,
            min_value=1,
            max_value=10,
            value=4,
            disabled=controls_off,
        )
        st.session_state.top_k = st.slider(
            label="Rerank Factor",
            key=2,
            min_value=1,
            max_value=20,
            value=4,
            disabled=controls_off,
        )
        st.markdown("---")

    def model_analytics():
        # Running totals maintained by stream_handler.
        st.markdown("# Model Analytics")
        tokens_used = st.session_state["tokens_used"]
        speed = st.session_state["tps"]
        st.write("Total tokens used :", tokens_used)
        st.write("Speed :", speed, " tokens/sec")
        # NOTE(review): this uses a factor of 80 and the label "INR", while the
        # per-response figure uses 82 and "$" — confirm which is intended.
        total_cost = round(COST_PER_1000_TOKENS_INR * 80 * tokens_used / 1000, 3)
        st.write("Total cost incurred :", total_cost, "INR")
        st.markdown("---")

    def model_settings():
        # Model choice and generation parameters.
        st.markdown("# Model Settings")
        bot_names = list(CHAT_BOTS)
        st.session_state.chat_bot = st.sidebar.radio("Select one:", bot_names)
        st.session_state.temp = st.slider(
            label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
        )
        st.session_state.max_tokens = st.slider(
            label="New tokens to generate",
            min_value=64,
            max_value=2048,
            step=32,
            value=512,
        )
        st.session_state.repetion_penalty = st.slider(
            label="Repetion Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
        )

    credit_line = "\n> **Created by [Pragnesh Barik](https://barik.super.site) 🔗**\n"
    with st.sidebar:
        for render_section in (retrieval_settings, model_analytics, model_settings):
            render_section()
        st.markdown(credit_line)
def prompt_engineering_dashboard():
    """Render the expander that lets the user edit the prompt template parts.

    Every widget value is written to ``st.session_state`` under the key the
    prompt builder reads (``system_instruction``, ``system_response``,
    ``pre_context``, ``post_context``, ``pre_prompt``, ``post_prompt``,
    ``pass_prev``). Previously only the two system texts were stored; the
    other inputs and the toggle were discarded, so editing them had no effect.
    """

    def engineer_prompt():
        # Button callback: replace the seed exchange at the head of the chat
        # history with the current system instruction / response.
        st.session_state.history[0] = [
            st.session_state.system_instruction,
            st.session_state.system_response,
        ]

    template = (
        "\n"
        "[SYSTEM INSTRUCTION]\n"
        "[SYSTEM RESPONSE]\n"
        "[... LIST OF PREV INPUTS]\n"
        "[PRE CONTEXT]\n"
        "[CONTEXT RETRIEVED FROM THE WEB]\n"
        "[POST CONTEXT]\n"
        "[PRE PROMPT]\n"
        "[PROMPT]\n"
        "[POST PROMPT]\n"
        "[PREV GENERATED INPUT] # Only if Pass previous prompt set True\n"
    )

    with st.expander("Prompt Engineering Dashboard"):
        st.info(
            "**The input to the model follows this below template**",
        )
        st.code(template)

        st.session_state.system_instruction = st.text_area(
            label="SYSTEM INSTRUCTION",
            value=INITIAL_PROMPT_ENGINEERING["SYSTEM_INSTRUCTION"],
        )
        st.session_state.system_response = st.text_area(
            "SYSTEM RESPONSE", value=INITIAL_PROMPT_ENGINEERING["SYSTEM_RESPONSE"]
        )

        # The context inputs only matter while web retrieval is on.
        rag_off = not st.session_state.rag_enabled
        col1, col2 = st.columns(2)
        with col1:
            st.session_state.pre_context = st.text_input(
                "PRE CONTEXT",
                value=INITIAL_PROMPT_ENGINEERING["PRE_CONTEXT"],
                disabled=rag_off,
            )
            st.session_state.pre_prompt = st.text_input(
                "PRE PROMPT", value=INITIAL_PROMPT_ENGINEERING["PRE_PROMPT"]
            )
            st.button("Engineer Prompts", on_click=engineer_prompt)
        with col2:
            st.session_state.post_context = st.text_input(
                "POST CONTEXT",
                value=INITIAL_PROMPT_ENGINEERING["POST_CONTEXT"],
                disabled=rag_off,
            )
            st.session_state.post_prompt = st.text_input(
                "POST PROMPT", value=INITIAL_PROMPT_ENGINEERING["POST_PROMPT"]
            )
            st.session_state.pass_prev = st.toggle("Pass previous prompt")
def header():
    """Render the page title, the deployment info table and the prompt dashboard."""
    st.write("# Mixtral Playground")

    # One (attribute, information) pair per table row.
    rows = [
        ("LLM", "Mixtral-8x7B-Instruct-v0.1"),
        ("Text Vectorizer", "all-distilroberta-v1"),
        ("Vector Database", "Hosted Pinecone"),
        ("CPU", "2 vCPU"),
        ("System RAM", "16 GB"),
    ]
    info_table = pd.DataFrame(rows, columns=["Attribute", "Information"])
    st.table(info_table)

    prompt_engineering_dashboard()
def chat_box():
    """Replay every stored chat message in its role's chat bubble."""
    for entry in st.session_state.messages:
        bubble = st.chat_message(entry["role"])
        with bubble:
            st.markdown(entry["content"])
def generate_chat_stream(prompt):
    """Start a model response for ``prompt``, optionally augmented by web search.

    When ``st.session_state.rag_enabled`` is set, the prompt is first rewritten
    by ``gen_augmented_prompt_via_websearch`` using the retrieval and prompt
    settings held in session state.

    Args:
        prompt: The user's raw prompt.

    Returns:
        A ``(chat_stream, links)`` tuple: the value returned by ``chat`` and
        the source links used for augmentation (empty when retrieval is off).
    """
    links = []
    if st.session_state.rag_enabled:
        with st.spinner("Fetching relevent documents from Web...."):
            prompt, links = gen_augmented_prompt_via_websearch(
                prompt=prompt,
                pre_context=st.session_state.pre_context,
                post_context=st.session_state.post_context,
                pre_prompt=st.session_state.pre_prompt,
                post_prompt=st.session_state.post_prompt,
                vendor=st.session_state.search_vendor,
                top_k=st.session_state.top_k,
                n_crawl=st.session_state.n_crawl,
                # Previously never forwarded, so the flag could not take effect.
                pass_prev=st.session_state.get("pass_prev", False),
            )

    # NOTE(review): the repetition-penalty setting in session state is not
    # passed on here — confirm whether chat() accepts it.
    with st.spinner("Generating response..."):
        chat_stream = chat(
            prompt,
            st.session_state.history,
            chat_client=CHAT_BOTS[st.session_state.chat_bot],
            temperature=st.session_state.temp,
            max_new_tokens=st.session_state.max_tokens,
        )
    return chat_stream, links
def stream_handler(chat_stream, placeholder):
    """Render a streamed response incrementally and record usage statistics.

    Args:
        chat_stream: Iterable of chunks exposing ``chunk.token.text``.
        placeholder: Streamlit element whose ``markdown`` is rewritten as the
            text grows.

    Returns:
        The full response text, with ``"</s>"`` tokens left out.

    Side effects:
        Writes three metric columns to the page and updates
        ``st.session_state["tps"]`` and ``st.session_state["tokens_used"]``.
    """
    # Monotonic clock: unaffected by system clock adjustments.
    start_time = time.perf_counter()
    full_response = ""
    for chunk in chat_stream:
        if chunk.token.text != "</s>":
            full_response += chunk.token.text
            # Trailing block character acts as a typing cursor.
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)
    elapsed_time = time.perf_counter() - start_time

    # Whitespace-separated words stand in for tokens.
    response_words = len(full_response.split())
    # A zero elapsed time would make the floor division raise; report 0 instead.
    tokens_per_second = response_words // elapsed_time if elapsed_time > 0 else 0

    # `prompt` is the module-level name bound by the chat input at the bottom
    # of the script (the raw user prompt, not the augmented one). The 1.25
    # factor scales the word count to an estimated token count.
    len_response = (len(prompt.split()) + response_words) * 1.25

    col1, col2, col3 = st.columns(3)
    with col1:
        st.write(f"**{tokens_per_second} tokens/second**")
    with col2:
        st.write(f"**{int(len_response)} tokens generated**")
    with col3:
        # NOTE(review): this uses a factor of 82 and a "$" label, while the
        # sidebar total uses 80 and "INR" — confirm the intended unit and rate.
        st.write(
            f"**$ {round(len_response * COST_PER_1000_TOKENS_INR * 82 / 1000, 5)} cost incurred**"
        )

    st.session_state["tps"] = tokens_per_second
    st.session_state["tokens_used"] = len_response + st.session_state["tokens_used"]
    return full_response
def show_source(links):
    """List each source link in a collapsible "Show source" panel."""
    source_panel = st.expander("Show source")
    with source_panel:
        for url in links:
            st.info(f"{url}")
# Page body. The calls run in this order on each execution of the script:
# session defaults first, then sidebar, header and the stored conversation.
init_state()
sidebar()
header()
chat_box()

# `prompt` is bound at module level here; stream_handler reads it as a global
# when estimating token usage.
if prompt := st.chat_input("Generate Ebook"):
    # Echo the user's message and store it so chat_box() can replay it later.
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat_stream, links = generate_chat_stream(prompt)

    with st.chat_message("assistant"):
        # Placeholder that stream_handler rewrites as text arrives.
        placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder)
        if st.session_state.rag_enabled:
            show_source(links)

    # history stores [prompt, response] pairs and is passed to chat();
    # messages stores role/content dicts and is what chat_box() renders.
    st.session_state.history.append([prompt, full_response])
    st.session_state.messages.append({"role": "assistant", "content": full_response})