# rag-test-venkat / Index.py
# (HuggingFace Spaces upload; web-UI header lines removed so the file parses as Python)
from fastapi import FastAPI
import os
import phoenix as px
from phoenix.trace.langchain import OpenInferenceTracer, LangChainInstrumentor
from langchain.embeddings import HuggingFaceEmbeddings #for using HugginFace models
from langchain.chains.question_answering import load_qa_chain
from langchain import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.callbacks import StdOutCallbackHandler
#from langchain.retrievers import KNNRetriever
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain import HuggingFaceHub
# from langchain.prompts import PromptTemplate
# from langchain.chains import LLMChain
# from txtai.embeddings import Embeddings
# from txtai.pipeline import Extractor
# import pandas as pd
# import sqlite3
# import os
# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
app = FastAPI(docs_url="/")

# --- Phoenix tracing setup ---
# Launch a local Phoenix server to collect and visualize LLM traces.
session = px.launch_app()
# If no exporter is specified, the tracer will export to the locally running Phoenix server
tracer = OpenInferenceTracer()
# If no tracer is specified, a tracer is constructed for you
LangChainInstrumentor(tracer).instrument()
print(session.url)

# SECURITY FIX: the HuggingFace token was previously hard-coded here and
# committed to source. That token must be revoked. Supply the token via the
# HUGGINGFACEHUB_API_TOKEN environment variable (e.g. a Spaces secret).
if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
    print("WARNING: HUGGINGFACEHUB_API_TOKEN is not set; HuggingFaceHub calls will fail.")

# Embedding cache: persists computed embedding bytes on local disk so
# previously-seen texts are not re-embedded.
store = LocalFileStore("./cache/")

# Embedder: sentence-transformers model wrapped in a cache-backed layer.
core_embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
embedder = CacheBackedEmbeddings.from_bytes_store(core_embeddings_model, store)

# LLM used for answer generation (hosted inference via HuggingFace Hub).
llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 1, "max_length": 1000000})
handler = StdOutCallbackHandler()

# Globals populated by initialize_vectorstore() and used by the endpoints.
vectorstore = None
retriever = None
def initialize_vectorstore():
    """Seed the global FAISS vector store with one sample web page and build a retriever."""
    global vectorstore
    global retriever
    sample_docs = WebBaseLoader(
        "https://www.tredence.com/case-studies/tredence-helped-a-global-retailer-providing-holistic-campaign-analytics-by-using-the-power-of-gcp"
    ).load()
    sample_chunks = _text_splitter(sample_docs)
    # Embed the chunks (through the cache-backed embedder) and index them in FAISS.
    vectorstore = FAISS.from_documents(sample_chunks, embedder)
    print("vector store initialized with sample doc")
    # Expose the store as a retriever for the /rag endpoint.
    retriever = vectorstore.as_retriever()
def _text_splitter(doc):
    """Split *doc* (a list of documents) into ~1000-char chunks with 50-char overlap."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=50,
        length_function=len,
    )
    return splitter.transform_documents(doc)
def _load_docs(path: str):
    """Fetch the web page at *path* and return it already split into chunks."""
    raw_docs = WebBaseLoader(path).load()
    return _text_splitter(raw_docs)
@app.get("/index/")
def get_domain_file_path(file_path: str):
    """Load the web page at *file_path*, chunk it, and add it to the vector store.

    BUG FIX: `_load_docs()` already splits the page into chunks, but the
    previous code ran `_text_splitter()` a second time on the already-split
    chunks. The redundant second split is removed.
    """
    print("file_path ", file_path)
    webpage_chunks = _load_docs(file_path)
    # Embed and index the new chunks alongside the existing documents.
    vectorstore.add_documents(webpage_chunks)
    return "document loaded to vector store successfully!!"
def _prompt(question):
return f"""Answer following question using only the context below. Say 'Could not find answer with provided context' when question can't be answered.
Question: {question}
Context: """
@app.get("/rag")
def rag(question: str):
    """Answer *question* with retrieval-augmented generation over the vector store."""
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        callbacks=[handler],
        return_source_documents=True,
    )
    # The chain retrieves context via `retriever` and feeds it to `llm`.
    response = qa_chain(_prompt(question))
    return {"question": question, "answer": response["result"]}
# Seed the vector store at import time so the endpoints are usable immediately.
initialize_vectorstore()

# --- ngrok tunnel setup ---
from pyngrok import ngrok, conf

# SECURITY FIX: the ngrok authtoken was previously hard-coded here and
# committed to source. That token must be revoked. Supply it via the
# NGROK_AUTHTOKEN environment variable instead.
_ngrok_token = os.environ.get("NGROK_AUTHTOKEN")
if _ngrok_token:
    conf.get_default().auth_token = _ngrok_token
else:
    print("WARNING: NGROK_AUTHTOKEN is not set; ngrok tunnel may be rate-limited or fail.")

port = 37689
# Open a ngrok tunnel to the HTTP server
public_url = ngrok.connect(port).public_url
print(" * ngrok tunnel \"{}\" -> \"http://127.0.0.1:{}\"".format(public_url, port))