import os |
import json |
import pandas as pd |
import time |
import phoenix as px |
from phoenix.trace.langchain import OpenInferenceTracer, LangChainInstrumentor |
from langchain.embeddings import HuggingFaceEmbeddings |
from langchain.chains.question_answering import load_qa_chain |
from langchain import HuggingFaceHub |
from langchain.prompts import PromptTemplate |
from langchain.chains import RetrievalQA |
from langchain.callbacks import StdOutCallbackHandler |
from langchain.storage import LocalFileStore |
from langchain.embeddings import CacheBackedEmbeddings |
from langchain.vectorstores import FAISS |
from langchain.document_loaders import WebBaseLoader |
from langchain.text_splitter import RecursiveCharacterTextSplitter |
import numpy as np |
import streamlit as st |
import pandas as pd |
# NOTE(review): a `global` statement at module scope has no effect; `trace_df`
# only becomes a module-level name when it is assigned further down in this script.
global trace_df
@st.cache_resource
def tracer_config():
    """Start the Phoenix app and hook an OpenInference tracer into LangChain.

    Decorated with ``st.cache_resource`` so the body runs once per cache
    lifetime instead of on every Streamlit rerun. The Phoenix UI URL is printed
    to stdout; nothing is returned.
    """
    phoenix_session = px.launch_app()
    LangChainInstrumentor(OpenInferenceTracer()).instrument()
    print(phoenix_session.url)
# Instrument LangChain with Phoenix before any chain is built (tracer_config is
# cached with st.cache_resource, so this only does real work once).
tracer_config()
# Two UI tabs: RAG question answering, and a standalone hallucination check.
# NOTE(review): the "π" prefixes look like mis-encoded emoji — confirm the intended labels.
tab1, tab2 = st.tabs(["π RAG", "π FactVsHallucinate" ])
# Embedding model used to build the FAISS index in initialize_vectorstore().
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Shared generator LLM for the RetrievalQA chains.
# NOTE(review): temperature=1 makes answers non-deterministic and max_length=1000000
# looks far larger than flan-t5 generation needs — confirm these values are intended.
llm=HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":1, "max_length":1000000})
# Logs chain events to stdout; passed as a callback to the RetrievalQA chains.
handler = StdOutCallbackHandler()
class HallucinatePromptContext:
    """Prompt text and input-variable names for the LLM hallucination judge.

    Attributes (set in ``__init__``):
        variables_list: names of the placeholders in ``base_template``
            ("query", "answer", "context").
        base_template: prompt with ``{query}``, ``{context}`` and ``{answer}``
            placeholders that asks for a one-word verdict, either
            "factual" or "hallucinated".
    """

    def __init__(self):
        self.variables_list = ["query","answer","context"]
        # The objective sentence must ask about the *answer*: the reference text is
        # the ground truth the answer is checked against, not the thing being judged.
        self.base_template = """In this task, you will be presented with a query, a reference text and an answer. The answer is
generated to the question based on the reference text. The answer may contain false information, you
must use the reference text to determine if the answer to the question contains false information,
if the answer is a hallucination of facts. Your objective is to determine whether the answer
contains factual information and is not a hallucination. A 'hallucination' in this context refers to
an answer that is not based on the reference text or assumes information that is not available in
the reference text. Your response should be a single word: either "factual" or "hallucinated", and
it should not include any other text or characters. "hallucinated" indicates that the answer
provides factually inaccurate information to the query based on the reference text. "factual"
indicates that the answer to the question is correct relative to the reference text, and does not
contain made up information. Please read the query and reference text carefully before determining
your response.
# Query: {query}
# Reference text: {context}
# Answer: {answer}
Is the answer above factual or hallucinated based on the query and reference text?"""
class HallucinatonEvaluater:
    """LLM-as-judge check of whether an answer is supported by its reference context.

    The (misspelled) class name is kept unchanged because other code refers to it.
    """

    def __init__(self, item, llm=None):
        """Store the evaluation inputs and the judge LLM.

        Args:
            item: mapping with "question", "answer" and "context" keys.
            llm: optional callable LLM (prompt ``str`` -> output) used as the
                judge. When omitted, a new HuggingFaceHub client for
                google/flan-t5-xxl is created with the same settings as before;
                passing one lets callers reuse an existing client.

        Raises:
            KeyError: if ``item`` lacks one of the required keys.
        """
        self.question = item["question"]
        self.answer = item["answer"]
        self.context = item["context"]
        if llm is None:
            llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":1, "max_length":1000000})
        self.llm = llm

    def get_prompt_template(self):
        """Return a PromptTemplate built from HallucinatePromptContext."""
        prompt_context = HallucinatePromptContext()
        variables = prompt_context.variables_list
        return PromptTemplate(input_variables=variables, template=prompt_context.base_template)

    def evaluate(self):
        """Fill the judge prompt with this item and return the LLM's raw output.

        The prompt asks for "factual" or "hallucinated"; the output is returned
        as-is (not normalised or validated).
        """
        # The template calls the question "query" and the context "context".
        prompt = self.get_prompt_template().format(query=self.question, answer=self.answer, context=self.context)
        score = self.llm(prompt)
        return score
@st.cache_resource
def initialize_vectorstore():
    """Index the sample case-study web page in FAISS and return a retriever.

    Cached with ``st.cache_resource``, so the page is fetched, chunked and
    embedded only when the cached body actually runs. When it does run, it also
    binds the module globals ``vectorstore`` and ``retriever``.
    NOTE(review): on a cache hit those globals are not re-bound by this
    function — confirm nothing relies on ``vectorstore`` after a rerun.
    """
    global vectorstore
    global retriever
    page_url = "https://www.tredence.com/case-studies/forecasting-app-installs-for-a-large-retailer-in-the-us"
    chunks = _text_splitter(WebBaseLoader(page_url).load())
    vectorstore = FAISS.from_documents(chunks, embedder)
    print("vector store initialized with sample doc")
    retriever = vectorstore.as_retriever()
    return retriever
def _text_splitter(doc, chunk_size=600, chunk_overlap=50):
    """Split loaded documents into overlapping character-based chunks.

    Args:
        doc: documents as returned by a LangChain loader's ``load()``.
        chunk_size: target maximum chunk length, measured with ``len``.
        chunk_overlap: amount of text (by ``len``) shared between consecutive
            chunks, so context is not lost at chunk boundaries.

    Returns:
        The chunked documents produced by ``transform_documents``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    return splitter.transform_documents(doc)
def _load_docs(path: str):
    """Fetch the web page at ``path`` and return it split into chunks."""
    return _text_splitter(WebBaseLoader(path).load())
def rag_response(response):
    """Render a RetrievalQA result: the query, the generated answer and its sources.

    Args:
        response: dict returned by a RetrievalQA chain built with
            ``return_source_documents=True``; the "query", "result" and
            "source_documents" keys are read.
    """
    st.markdown("""<hr style="height:10px;border:none;color:#333;background-color:#333;" /> """, unsafe_allow_html=True)
    st.subheader('RAG response')
    st.text_area(label="user query", value=response["query"], height=30)
    st.text_area(label="RAG output", value=response["result"])
    # Show the retrieved text itself rather than the repr() of a list of Document objects.
    sources = "\n\n---\n\n".join(
        getattr(doc, "page_content", str(doc)) for doc in response["source_documents"]
    )
    st.text_area(label="Augmented knowledge", value=sources)
def _create_hallucination_scenario(item):
    """Return the LLM judge's verdict for a question/answer/context mapping."""
    return HallucinatonEvaluater(item).evaluate()
def hallu_eval(question: str, answer: str, context: str):
    """Ask the LLM judge whether ``answer`` is supported by ``context`` and display it.

    The raw judge output is shown in a "Hallucinated?" text area; progress is
    logged to stdout before and after the judge call.
    """
    print("in hallu eval")
    item = dict(question=question, answer=answer, context=context)
    verdict = _create_hallucination_scenario(item)
    print("got hallu score")
    st.text_area(label="Hallucinated?", value=verdict, height=30)
def click_button(response):
    """Run the hallucination judge on a RetrievalQA response.

    The retrieved source documents are concatenated and used as the reference
    context, so the answer is checked against the text it was generated from.

    Args:
        response: dict from a RetrievalQA chain; "query" and "result" are
            required, "source_documents" is optional (an empty context is used
            when it is missing or empty).
    """
    source_docs = response.get("source_documents") or []
    context = "\n\n".join(getattr(doc, "page_content", str(doc)) for doc in source_docs)
    hallu_eval(response["query"], response["result"], context)
# Tab 1: retrieval-augmented QA over the indexed page, optionally followed by
# an LLM hallucination check of the generated answer.
with tab1:
    with st.form(" RAG with evaluation - scoring & hallucination "):
        # Cached with st.cache_resource, so the page is only fetched and indexed once.
        retriever = initialize_vectorstore()
        options = ["true", "false"]
        question = st.text_input(label="user question", value="", label_visibility="visible", disabled=False)
        evaluate = st.selectbox(label="Evaluation", options=options, index=0, placeholder="Choose an option", disabled=False, label_visibility="visible")
        if st.form_submit_button("RAG with evaluation"):
            print("retrie ,", retriever)
            if not question.strip():
                # Skip the remote LLM call when there is nothing to ask.
                st.warning("Please enter a question.")
            else:
                chain = RetrievalQA.from_chain_type(
                    llm=llm,
                    retriever=retriever,
                    callbacks=[handler],
                    return_source_documents=True,
                )
                response = chain(question)
                print(response["result"])
                rag_response(response)
                # Honour the "Evaluation" selector: only run the judge when it is enabled.
                if evaluate == "true":
                    click_button(response)
# Tab 2: judge a user-supplied (question, answer, context) triple directly.
with tab2:
    with st.form(" LLM-aasisted evaluation of Hallucination"):
        question = st.text_input(label="question", value="", label_visibility="visible", disabled=False)
        answer = st.text_input(label="answer", value="", label_visibility="visible", disabled=False)
        context = st.text_input(label="context", value="", label_visibility="visible", disabled=False)
        if st.form_submit_button("Evaluate"):
            hallu_eval(question, answer, context)
# Fetch the Phoenix trace spans once per run, reuse them for the log line, and
# share them through session_state. Guard against there being no active session.
_phoenix_session = px.active_session()
trace_df = _phoenix_session.get_spans_dataframe() if _phoenix_session is not None else None
print("activ session: ", trace_df)
st.session_state['trace_df'] = trace_df
def rag():
    """Standalone (non-form) RAG UI: ask a question, show the answer, optionally judge it.

    Streamlit reruns the whole script on every widget interaction, and
    ``st.button`` is True only on the rerun triggered by its own click. A
    button nested inside another button's branch can therefore never fire.
    The latest RAG response is kept in ``st.session_state["rag_response"]`` so
    the follow-up "see more" button can still reach it on the next rerun.

    Returns:
        The most recent RetrievalQA response dict for this session, or None if
        RAG has not been run yet.
    """
    print("in rag")
    options = ["true", "false"]
    question = st.text_input(label="user question", value="", label_visibility="visible", disabled=False)
    evaluate = st.selectbox(label="select evaluation", options=options, index=0, placeholder="Choose an option", disabled=False, label_visibility="visible")
    if st.button("do RAG"):
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            callbacks=[handler],
            return_source_documents=True,
        )
        response = chain(question)
        print(response["result"])
        st.session_state["rag_response"] = response
        # A fresh answer starts un-evaluated until the user asks for more.
        st.session_state["more_stuff"] = False

    response = st.session_state.get("rag_response")
    if response is None:
        return None
    rag_response(response)
    # Consistent with tab 1: only offer the hallucination judge when evaluation is enabled.
    if evaluate == "true":
        if st.button("Do you want to see more?"):
            st.session_state["more_stuff"] = True
        # .get(): the key may not exist yet in a brand-new session.
        if st.session_state.get("more_stuff", False):
            click_button(response)
    return response