Spaces:
Runtime error
Runtime error
import os | |
import json | |
import pandas as pd | |
import time | |
import phoenix as px | |
from phoenix.trace.langchain import OpenInferenceTracer, LangChainInstrumentor | |
#from hallucinator import HallucinatonEvaluater | |
from langchain.embeddings import HuggingFaceEmbeddings #for using HugginFace models | |
from langchain.chains.question_answering import load_qa_chain | |
from langchain import HuggingFaceHub | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import RetrievalQA | |
from langchain.callbacks import StdOutCallbackHandler | |
#from langchain.retrievers import KNNRetriever | |
from langchain.storage import LocalFileStore | |
from langchain.embeddings import CacheBackedEmbeddings | |
from langchain.vectorstores import FAISS | |
from langchain.document_loaders import WebBaseLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
import numpy as np | |
import streamlit as st | |
import pandas as pd | |
# from sklearn import datasets | |
# from sklearn.ensemble import RandomForestClassifier | |
from PIL import Image | |
# Page config. st.set_page_config has to run before any other Streamlit call,
# so it stays first. (The former module-level `global` statements were no-ops
# at module scope and have been dropped.)
st.set_page_config(page_title="RAG PoC", layout="wide")
# Open the logo in a context manager so the underlying file handle is released
# once st.sidebar.image has consumed it.
with Image.open("./test-logo.png") as logo:
    st.sidebar.image(logo, use_column_width=True)
def tracer_config():
    """Launch Phoenix and instrument LangChain, reusing an already-active session.

    Streamlit re-executes the whole script on every user interaction. Without a
    guard, each rerun called ``px.launch_app()`` again, re-applied
    ``LangChainInstrumentor(...).instrument()`` (which may register duplicate
    tracers) and slept for 3 seconds. When ``px.active_session()`` already
    returns a session, that session is reused as-is.

    Returns:
        The active Phoenix session (the original returned None; callers ignored it).
    """
    session = px.active_session()
    if session is not None:
        return session
    session = px.launch_app()
    # If no exporter is specified, the tracer will export to the locally running Phoenix server.
    tracer = OpenInferenceTracer()
    LangChainInstrumentor(tracer).instrument()
    # NOTE(review): presumably gives the freshly launched Phoenix server time to start — confirm.
    time.sleep(3)
    print(session.url)
    return session
tracer_config()
tab1, tab2 = st.tabs(["📈 **RAG**", "🗃 FactVsHallucinate"])
# The Hugging Face Hub token must come from the environment (e.g. a Space secret);
# it is never hard-coded in source.
# NOTE(review): a token used to be committed on this line — treat it as leaked and
# revoke it on huggingface.co.
if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
    st.error("HUGGINGFACEHUB_API_TOKEN is not set. Add it as an environment variable / Space secret.")
    st.stop()
# Sentence-transformers model used to embed the document chunks and the queries.
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Generator LLM used by the RetrievalQA chain.
llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 1, "max_length": 1000000})
# Echoes chain events to stdout.
handler = StdOutCallbackHandler()
class HallucinatePromptContext:
    """Prompt text and variable names for the LLM-as-judge hallucination check.

    Attributes (set in ``__init__``):
        variables_list: placeholder names that ``base_template`` expects to be
            filled: ``query``, ``answer`` and ``context``.
        base_template: ``str.format``/PromptTemplate-style prompt asking the
            model to reply with the single word "factual" or "hallucinated".
    """

    def __init__(self):
        self.variables_list = ["query", "answer", "context"]
        # The objective sentence previously asked whether "the reference text"
        # was factual. The thing being judged is the *answer* (as the rest of
        # the prompt states), so it now says "the answer".
        self.base_template = (
            "In this task, you will be presented with a query, a reference text and an answer. The answer is\n"
            "generated to the question based on the reference text. The answer may contain false information, you\n"
            "must use the reference text to determine if the answer to the question contains false information,\n"
            "if the answer is a hallucination of facts. Your objective is to determine whether the answer\n"
            "contains factual information and is not a hallucination. A 'hallucination' in this context refers to\n"
            "an answer that is not based on the reference text or assumes information that is not available in\n"
            'the reference text. Your response should be a single word: either "factual" or "hallucinated", and\n'
            'it should not include any other text or characters. "hallucinated" indicates that the answer\n'
            'provides factually inaccurate information to the query based on the reference text. "factual"\n'
            "indicates that the answer to the question is correct relative to the reference text, and does not\n"
            "contain made up information. Please read the query and reference text carefully before determining\n"
            "your response.\n"
            "# Query: {query}\n"
            "# Reference text: {context}\n"
            "# Answer: {answer}\n"
            "Is the answer above factual or hallucinated based on the query and reference text?"
        )
class HallucinatonEvaluater:
    """Ask an LLM whether an answer is factual or hallucinated w.r.t. a context.

    Args:
        item: mapping with the keys ``"question"``, ``"answer"`` and ``"context"``.
        llm: optional callable LLM (``llm(prompt) -> str``). When omitted, the same
            flan-t5-xxl HuggingFaceHub client as before is constructed, so existing
            callers behave unchanged.
    """

    def __init__(self, item, llm=None):
        self.question = item["question"]
        self.answer = item["answer"]
        self.context = item["context"]
        if llm is None:
            # NOTE(review): temperature=1 makes the judge's verdict non-deterministic;
            # consider a lower value for a classification task.
            llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 1, "max_length": 1000000})
        self.llm = llm

    def get_prompt_template(self):
        """Return a PromptTemplate built from HallucinatePromptContext."""
        prompt = HallucinatePromptContext()
        return PromptTemplate(input_variables=prompt.variables_list, template=prompt.base_template)

    def evaluate(self):
        """Fill the prompt with this item's fields and return the LLM's raw reply."""
        prompt = self.get_prompt_template().format(
            query=self.question, answer=self.answer, context=self.context
        )
        return self.llm(prompt)
def initialize_vectorstore(url="https://www.tredence.com/case-studies/forecasting-app-installs-for-a-large-retailer-in-the-us"):
    """Build (or reuse) a FAISS vector store over the web page at ``url``.

    The page is loaded and chunked via ``_load_docs`` and embedded with the
    module-level ``embedder``. The store and its retriever are assigned to the
    module globals ``vectorstore``/``retriever`` and saved in
    ``st.session_state``.

    Streamlit reruns the script on every interaction. If this session already
    holds a store built from the same ``url``, it is reused instead of
    re-downloading and re-embedding the page.

    Args:
        url: page to index; defaults to the original hard-coded case study.

    Returns:
        The retriever for the vector store.
    """
    global vectorstore
    global retriever
    if st.session_state.get("vectorstore_url") == url and "retriever" in st.session_state:
        vectorstore = st.session_state["vectorstore"]
        retriever = st.session_state["retriever"]
        return retriever
    webpage_chunks = _load_docs(url)
    # store embeddings in vector store
    vectorstore = FAISS.from_documents(webpage_chunks, embedder)
    print("vector store initialized with sample doc")
    retriever = vectorstore.as_retriever()
    st.session_state['vectorstore'] = vectorstore
    st.session_state['retriever'] = retriever
    st.session_state['vectorstore_url'] = url
    return retriever
def _text_splitter(doc):
    """Chunk LangChain documents for embedding.

    Chunks are at most 600 characters (measured with ``len``) and consecutive
    chunks overlap by 50 characters.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50, length_function=len)
    chunks = splitter.transform_documents(doc)
    return chunks
def _load_docs(path: str):
    """Download the web page at ``path`` and return it split into chunks."""
    return _text_splitter(WebBaseLoader(path).load())
def rag_response(response):
    """Render the question, generated answer and retrieved sources of a RAG call.

    Args:
        response: dict returned by the RetrievalQA chain; the keys read are
            ``"query"``, ``"result"`` and ``"source_documents"``.
    """
    # Each text area gets its own non-empty label, hidden because the HTML
    # heading above it is the visible title. With an identical empty label,
    # two areas showing the same text would collide on the same widget ID, and
    # Streamlit warns about empty labels.
    # NOTE(review): recent Streamlit releases reject st.text_area heights below
    # 68 px, hence 68 instead of the former 30.
    st.markdown('<h1 style="color:#100170;font-size:48px;text-align:center;">RAG Response</h1>', unsafe_allow_html=True)
    st.markdown('<h1 style="color:#100170;font-size:24px;">Question</h1>', unsafe_allow_html=True)
    st.text_area(label="Question", value=response["query"], height=68, label_visibility="collapsed")
    st.markdown('<h1 style="color:#100170;font-size:24px;">RAG Output</h1>', unsafe_allow_html=True)
    st.text_area(label="RAG Output", value=response["result"], label_visibility="collapsed")
    st.markdown('<h1 style="color:#100170;font-size:24px;">Augmented knowledge</h1>', unsafe_allow_html=True)
    st.text_area(label="Augmented knowledge", value=response["source_documents"], label_visibility="collapsed")
def _create_hallucination_scenario(item):
    """Evaluate one question/answer/context ``item`` and return the judge LLM's raw verdict."""
    evaluator = HallucinatonEvaluater(item)
    verdict = evaluator.evaluate()
    return verdict
def hallu_eval(question: str, answer: str, context: str):
    """Judge ``answer`` against ``context`` for hallucination and render the verdict.

    The verdict shown is the raw text returned by the judge LLM; its prompt asks
    for the single word "factual" or "hallucinated".
    """
    print("in hallu eval")
    hallucination_score = _create_hallucination_scenario({
        "question": question,
        "answer": answer,
        "context": context,
    })
    print("got hallu score")
    st.markdown('<h1 style="color:#100170;font-size:24px;">Hallucinated?</h1>', unsafe_allow_html=True)
    # Non-empty hidden label instead of " "; height 68 because recent Streamlit
    # releases reject st.text_area heights below 68 px (NOTE(review): confirm
    # against the deployed Streamlit version).
    st.text_area(label="Hallucination verdict", value=hallucination_score, height=68, label_visibility="collapsed")
def click_button(response):
    """Run the hallucination check for a RetrievalQA ``response``.

    The context that was actually sent to the LLM is recovered from the most
    recent ``LLMChain`` span recorded by Phoenix: its JSON
    ``attributes.input.value`` holds a ``"context"`` key. That context is passed
    to ``hallu_eval`` together with the response's query and result. If no
    usable span is found, a warning is shown instead of raising.
    """
    # Prefer a fresh snapshot. st.session_state['trace_df'] is only refreshed at
    # the end of a script run, so it does not yet include the spans of the chain
    # call that produced ``response``.
    session = px.active_session()
    df = session.get_spans_dataframe() if session is not None else None
    if df is None:
        df = st.session_state.get('trace_df')
    if df is None or df.empty:
        st.warning("No trace spans available yet; cannot run the hallucination check.")
        return
    llm_spans = df[df["name"] == "LLMChain"].sort_values(by='end_time', ascending=False)
    if llm_spans.empty:
        st.warning("No LLMChain span found in the trace; cannot run the hallucination check.")
        return
    # Positional .iloc[0] takes the newest span. The former `[...][0]` was a
    # *label* lookup, which after filtering/sorting raises KeyError or returns
    # an arbitrary row.
    # NOTE(review): if span export lags behind the chain call, the newest span
    # may belong to a previous query — confirm with the tracer's export timing.
    model_input = json.loads(llm_spans["attributes.input.value"].iloc[0])
    context = model_input["context"]
    print(context)
    hallu_eval(response["query"], response["result"], context)
# ---- Tab 1: RAG over the indexed web page, with optional hallucination check ----
with tab1:
    with st.form(" RAG with evaluation - scoring & hallucination "):
        retriever = initialize_vectorstore()
        options = ["true", "false"]
        st.markdown('<h1 style="color:#100170;font-size:24px;">User Query</h1>', unsafe_allow_html=True)
        # The heading above is the visible title; the widget's own label is hidden.
        question = st.text_input(label="User Query", value="", placeholder="Type in question", label_visibility="collapsed", disabled=False)
        evaluate = st.selectbox(label="***Perform Evaluation?***", options=options, index=1, placeholder="Choose an option", disabled=False, label_visibility="visible")
        st.markdown("""
<style>
div.stButton > button:first-child {
background-color: #100170;
color:#ffffff;
}
div.stButton > button:hover {
background-color: #00ff00;
color:#ff0000;
}
</style>""", unsafe_allow_html=True)
        columns = st.columns([2, 1, 2])
        if columns[1].form_submit_button(" Start RAG "):
            st.markdown("""<hr style="height:10px;border:none;color:#333;background-color: #100170;" /> """, unsafe_allow_html=True)
            if not question.strip():
                st.warning("Please type in a question first.")
            else:
                chain = RetrievalQA.from_chain_type(
                    llm=llm,
                    retriever=retriever,
                    callbacks=[handler],
                    return_source_documents=True,
                )
                response = chain(question)
                print(response["result"])
                rag_response(response)
                # Honour the "Perform Evaluation?" selection (string options
                # "true"/"false"); previously the check ran regardless.
                if evaluate == "true":
                    click_button(response)

# ---- Tab 2: manual hallucination check on user-supplied question/answer/context ----
with tab2:
    with st.form(" LLM-aasisted evaluation of Hallucination"):
        question = st.text_input(label="**Question**", value="", label_visibility="visible", disabled=False)
        answer = st.text_input(label="**answer**", value="", label_visibility="visible", disabled=False)
        context = st.text_input(label="**context**", value="", label_visibility="visible", disabled=False)
        if st.form_submit_button("Evaluate"):
            hallu_eval(question, answer, context)

# Snapshot the Phoenix spans once per run (fetched a single time) so later code
# can fall back to them via st.session_state.
_active_session = px.active_session()
trace_df = _active_session.get_spans_dataframe() if _active_session is not None else None
print("activ session: ", trace_df)
st.session_state['trace_df'] = trace_df
def rag():
    """Render a standalone question box plus a "do RAG" button and run the chain.

    NOTE(review): not called anywhere in this file; the tab-based UI at module
    level is what actually runs.

    Returns:
        The RetrievalQA response dict when "do RAG" was clicked on this run,
        otherwise None.
    """
    print("in rag")
    options = ["true", "false"]
    question = st.text_input(label="user question", value="", label_visibility="visible", disabled=False)
    # The selection is currently not used to gate anything; the widget is kept for the UI.
    st.selectbox(label="select evaluation", options=options, index=0, placeholder="Choose an option", disabled=False, label_visibility="visible")
    if not st.button("do RAG"):
        return None
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        callbacks=[handler],
        return_source_documents=True,
    )
    response = chain(question)
    print(response["result"])
    rag_response(response)
    if st.button("Do you want to see more?"):
        st.session_state.more_stuff = True
    # .get(): `more_stuff` only exists after the button above has been clicked;
    # plain attribute access raised AttributeError before that.
    if st.session_state.get("more_stuff", False):
        click_button(response)
    return response
# Global CSS override for text areas (the return value of st.markdown was never
# used, so it is no longer bound).
# NOTE(review): Streamlit nests the <textarea> inside extra wrapper divs, so this
# direct-child selector may not match in current releases — verify in the browser.
st.markdown("""
<style>
div.stTextArea > textarea {
background-color: #0099ff;
height: 1400px;
width: 800px;
}
</style>""", unsafe_allow_html=True)