# insuranec-chat / app.py
# Author: ethanrom
# Commit e511c40: "Update app.py"
import os
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from huggingface_hub import InferenceClient
import gradio as gr
# Hugging Face API token taken from the "apiToken" environment variable.
# os.getenv returns None when it is unset, and None is then passed as the token.
HF_token = os.getenv("apiToken")

# Retrieval stack: HuggingFaceEmbeddings with its default settings, plus a FAISS
# index loaded from the local "faiss_index" directory.
# NOTE(review): load_local deserializes the stored docstore (pickle-based in
# LangChain) -- only point it at index files you built yourself.
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("faiss_index", embeddings)

# Plain similarity search returning the 2 best-matching chunks per query.
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 2},
)

# Model that writes the user-facing answers.
chat_client = InferenceClient(model="HuggingFaceH4/zephyr-7b-alpha", token=HF_token)

# Model that rewrites follow-up messages into standalone retrieval queries.
reform_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1", token=HF_token)
def format_prompt(message, history, docs=None):
    """Build a Zephyr-style chat prompt grounded in retrieved website content.

    The prompt follows the Zephyr chat template: the system turn comes first
    (holding the instructions and the retrieved context), then the earlier
    conversation turns, then the new user message and an open assistant turn.

    Args:
        message: The user message to answer.
        history: Previous turns as ``(user_prompt, bot_response)`` pairs.
        docs: Optional already-retrieved documents (objects exposing
            ``page_content``). When ``None``, documents are fetched for
            ``message`` through the module-level ``retriever``. Passing them
            avoids a second lookup.

    Returns:
        The complete prompt string, ending with an open ``<|assistant|>`` turn
        so the model continues as the assistant.
    """
    if docs is None:
        docs = retriever.get_relevant_documents(message)
    context = "\n".join(doc.page_content for doc in docs)
    # The system turn must come before any user/assistant turns. The retrieved
    # context goes inside it so that it is not left outside every role tag.
    parts = [
        "<|system|>\nYou are a helpful virtual assistant for People's Insurance PLC "
        "that answer user's questions using website content.\n"
        f"website content:\n{context}</s>\n"
    ]
    for user_prompt, bot_response in history:
        # `or ''` keeps a missing (None) turn from being rendered as "None".
        parts.append(f"<|user|>\n{user_prompt or ''}</s>\n")
        parts.append(f"<|assistant|>\n{bot_response or ''}</s>\n")
    parts.append(f"<|user|>\n{message}</s>\n<|assistant|>\n")
    return "".join(parts)
def query_reform(message, history):
    """Build the Mistral-Instruct prompt that rewrites a follow-up message.

    Walks ``history`` from the newest turn backwards to find the most recent
    non-empty user input and the most recent non-empty AI response. These may
    come from different turns. Both are embedded, together with ``message``,
    in a single ``[INST] ... [/INST]`` instruction.

    Args:
        message: The new (follow-up) user message.
        history: Previous turns as ``(user_input, ai_response)`` pairs.

    Returns:
        The prompt string to send to the reformulation model.
    """
    previous_user_input = ""
    previous_ai_response = ""
    for user_input, ai_response in reversed(history):
        # Truthiness also skips None entries, which a `!= ""` test would accept
        # and later render as the literal text "None".
        if not previous_user_input and user_input:
            previous_user_input = user_input
        if not previous_ai_response and ai_response:
            previous_ai_response = ai_response
        if previous_user_input and previous_ai_response:
            break
    print("Original Question:", message)
    print("Last interaction:\nuser:", previous_user_input, "\nAI:", previous_ai_response)
    # The entire instruction, including the chat history and follow-up, must sit
    # inside [INST] ... [/INST]. A [/INST] in mid-instruction would hand the
    # history and follow-up to the model as part of its own answer turn.
    new_prompt = (
        "[INST] Given the following conversation and a follow-up message, "
        "rephrase the follow-up user message to be a standalone message. "
        "If the follow-up message is not a question, keep it unchanged.\n"
        f"Chat History:\nUser: {previous_user_input}\nAI: {previous_ai_response}\n"
        f"Follow-up User message: {message}\n"
        "Rewritten User message: [/INST]"
    )
    return new_prompt
def reformulate_query(question, history):
    """Rewrite ``question`` into a standalone query using recent chat history.

    Args:
        question: The latest user message.
        history: Previous turns as ``(user_input, ai_response)`` pairs.

    Returns:
        The reformulated query with surrounding whitespace stripped.
        - With no history there is nothing to resolve the question against, so
          ``question`` is returned unchanged and the model is not called.
        - An empty model output also falls back to ``question``.
    """
    if not history:
        return question
    reformulated_query = reform_client.text_generation(
        query_reform(question, history),
        temperature=0.1,
        max_new_tokens=50,
        top_p=0.9,
        repetition_penalty=1.0,
    ).strip()
    print("Reformulated Query:", reformulated_query)
    # Never replace the user's question with an empty string.
    return reformulated_query or question
def generate(
    prompt, history, temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0,
):
    """Stream an answer to ``prompt`` for ``gr.ChatInterface``.

    Pipeline:
    1. Reformulate the message into a standalone query.
    2. Retrieve documents for that query so their sources can be listed.
    3. Build the chat prompt and stream the model's tokens.
    4. Append the sources to the answer.

    Args:
        prompt: The latest user message.
        history: Previous turns as ``(user_input, bot_response)`` pairs.
        temperature: Sampling temperature; clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty passed to the model.

    Yields:
        The accumulated answer text after each streamed token. The final value
        is the answer followed by a "Sources:" section, which is present only
        when at least one retrieved document has a ``source`` entry.
    """
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    reformed_query = reformulate_query(prompt, history)
    docs = retriever.get_relevant_documents(reformed_query)
    # Several chunks can come from the same page: list each source once, in
    # retrieval order.
    source_list = list(dict.fromkeys(
        doc.metadata["source"] for doc in docs if "source" in doc.metadata
    ))
    sources = ""
    if source_list:
        sources = "\nSources:\n" + "".join(f"{src}\n" for src in source_list)
    # NOTE: format_prompt performs its own retrieval for the same query.
    formatted_prompt = format_prompt(reformed_query, history)
    output = ""
    stream = chat_client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    for response in stream:
        # Skip special tokens (e.g. the closing "</s>" EOS), which would
        # otherwise show up at the end of the visible answer.
        if response.token.special:
            continue
        output += response.token.text
        yield output
    yield output + sources
# Chat UI: gr.ChatInterface calls generate(message, history) and shows each
# string it yields in the chat window.
demo = gr.ChatInterface(
    generate,
    title="People's Insurance PLC chatbot",
    theme="Monochrome",
    # Example prompts offered to the user.
    examples=[
        ["What is the contact number of Peoples insurance"],
        ["What are the available plans"],
    ],
)
# queue() enables Gradio's request queue, which older Gradio versions require
# for generator (streaming) functions. debug=True keeps launch() blocking in
# the main thread.
demo.queue().launch(debug=True)