# docwhiz/utils/haystack.py
# Author: Anirudh Madhigiri Gopinath
# Source commit: fa2034d ("pusing"), 5.51 kB
import streamlit as st
from utils.config import document_store_configs, model_configs
from haystack import Pipeline
from haystack.schema import Answer
from haystack.document_stores import BaseDocumentStore
from haystack.document_stores import InMemoryDocumentStore, OpenSearchDocumentStore, WeaviateDocumentStore
from haystack.nodes import EmbeddingRetriever, FARMReader, PromptNode, PreProcessor
#from haystack.nodes import TextConverter, FileTypeClassifier, PDFToTextConverter
from milvus_haystack import MilvusDocumentStore
# Use this file to set up the Haystack pipelines and to run queries against them.
@st.cache_resource(show_spinner=False)
def start_preprocessor_node():
    """Build the Haystack PreProcessor used to clean and split documents.

    Wrapped in ``st.cache_resource`` so the node is constructed once and then
    reused on later calls instead of being rebuilt on every Streamlit rerun.

    Cleaning: empty-line, whitespace and header/footer cleaning are enabled.
    Splitting: text is cut into 100-word chunks while respecting sentence
    boundaries.

    Returns:
        The configured ``PreProcessor`` instance.
    """
    print('initializing preprocessor node')
    # Chunking settings grouped together so they are easy to tune in one place.
    chunking = dict(
        split_by="word",
        split_length=100,
        split_respect_sentence_boundary=True,
    )
    return PreProcessor(
        clean_empty_lines=True,
        clean_whitespace=True,
        clean_header_footer=True,
        **chunking,
    )
@st.cache_resource(show_spinner=False)
def start_document_store(type: str):
    """Create the document store backend selected by ``type``.

    Args:
        type: Backend name, one of ``'inmemory'``, ``'opensearch'``,
            ``'weaviate'`` or ``'milvus'``. The name shadows the builtin
            ``type`` but is kept so existing keyword callers keep working.

    Returns:
        The configured document store. Connection settings for the
        OpenSearch, Weaviate and Milvus backends are read from
        ``document_store_configs``.

    Raises:
        ValueError: If ``type`` is not a supported backend. Previously this
            case fell through every branch and surfaced as an
            ``UnboundLocalError`` on the ``return`` line.
    """
    print('initializing document store')
    if type == 'inmemory':
        # BM25 enabled so keyword retrieval also works. The dimension 384 should
        # match the embedding model's output size; TODO confirm against
        # model_configs['EMBEDDING_MODEL'].
        document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=384)
    elif type == 'opensearch':
        document_store = OpenSearchDocumentStore(
            scheme=document_store_configs['OPENSEARCH_SCHEME'],
            username=document_store_configs['OPENSEARCH_USERNAME'],
            password=document_store_configs['OPENSEARCH_PASSWORD'],
            host=document_store_configs['OPENSEARCH_HOST'],
            port=document_store_configs['OPENSEARCH_PORT'],
            index=document_store_configs['OPENSEARCH_INDEX'],
            embedding_dim=document_store_configs['OPENSEARCH_EMBEDDING_DIM'],
        )
    elif type == 'weaviate':
        document_store = WeaviateDocumentStore(
            host=document_store_configs['WEAVIATE_HOST'],
            port=document_store_configs['WEAVIATE_PORT'],
            index=document_store_configs['WEAVIATE_INDEX'],
            embedding_dim=document_store_configs['WEAVIATE_EMBEDDING_DIM'],
        )
    elif type == 'milvus':
        document_store = MilvusDocumentStore(
            uri=document_store_configs['MILVUS_URI'],
            index=document_store_configs['MILVUS_INDEX'],
            embedding_dim=document_store_configs['MILVUS_EMBEDDING_DIM'],
            return_embedding=True,
        )
    else:
        raise ValueError(
            f"Unknown document store type {type!r}; expected one of "
            "'inmemory', 'opensearch', 'weaviate', 'milvus'"
        )
    return document_store
# Cached so the embedding model is loaded only once, at start-up.
@st.cache_resource(show_spinner=False)
def start_retriever(_document_store: BaseDocumentStore):
    """Create the EmbeddingRetriever bound to ``_document_store``.

    The leading underscore on the parameter makes Streamlit skip it when
    hashing the call for ``st.cache_resource``.

    The embedding model name is taken from ``model_configs['EMBEDDING_MODEL']``.
    Document embeddings are not computed here; presumably that happens at
    indexing time elsewhere (TODO confirm).

    Returns:
        The configured ``EmbeddingRetriever``.
    """
    print('initializing retriever')
    model_name = model_configs['EMBEDDING_MODEL']
    return EmbeddingRetriever(
        document_store=_document_store,
        embedding_model=model_name,
        top_k=5,  # documents returned per query to the downstream node
    )
@st.cache_resource(show_spinner=False)
def start_reader():
    """Load the extractive FARMReader named by ``model_configs['EXTRACTIVE_MODEL']``.

    Cached with ``st.cache_resource`` so the model is loaded once and reused.

    Returns:
        The loaded ``FARMReader``.
    """
    print('initializing reader')
    return FARMReader(model_name_or_path=model_configs['EXTRACTIVE_MODEL'])
# Cached so the pipeline is assembled only once, at start-up.
@st.cache_resource(show_spinner=False)
def start_haystack_extractive(_document_store: BaseDocumentStore, _retriever: EmbeddingRetriever, _reader: FARMReader):
    """Assemble the extractive question-answering pipeline.

    The pipeline graph is Query -> Retriever -> Reader.

    Every parameter is underscore-prefixed, so Streamlit hashes none of them.
    Per Streamlit's caching rules, the first pipeline built is therefore
    reused on later calls whatever arguments are passed.

    ``_document_store`` is not used by this function. The retriever is
    expected to carry its own store reference (see ``start_retriever``).

    Returns:
        The assembled ``Pipeline``.
    """
    print('initializing pipeline')
    pipeline = Pipeline()
    wiring = (
        (_retriever, "Retriever", "Query"),
        (_reader, "Reader", "Retriever"),
    )
    for component, node_name, upstream in wiring:
        pipeline.add_node(component=component, name=node_name, inputs=[upstream])
    return pipeline
@st.cache_resource(show_spinner=False)
def start_haystack_rag(_document_store: BaseDocumentStore, _retriever: EmbeddingRetriever, openai_key):
    """Assemble the retrieval-augmented generation pipeline.

    The pipeline graph is Query -> Retriever -> PromptNode. The PromptNode uses
    the ``deepset/question-answering`` prompt template and the model named by
    ``model_configs['GENERATIVE_MODEL']``.

    Args:
        _document_store: Not used by this function.
        _retriever: Retriever whose documents are fed to the prompt node.
        openai_key: API key handed to the PromptNode. It has no leading
            underscore, so it is part of Streamlit's cache key; a different
            key yields a separately cached pipeline.

    Returns:
        The assembled ``Pipeline``.
    """
    generator = PromptNode(
        model_name_or_path=model_configs['GENERATIVE_MODEL'],
        default_prompt_template="deepset/question-answering",
        api_key=openai_key,
        max_length=500,
    )
    pipeline = Pipeline()
    for component, node_name, upstream in (
        (_retriever, "Retriever", "Query"),
        (generator, "PromptNode", "Retriever"),
    ):
        pipeline.add_node(component=component, name=node_name, inputs=[upstream])
    return pipeline
def query(_pipeline, question):
    """Run ``question`` through ``_pipeline`` and return its raw result.

    The question is passed as the pipeline's first positional argument,
    together with an empty ``params`` dict, so no per-node overrides are
    applied. Results are not cached.
    """
    no_overrides = dict()
    return _pipeline.run(question, params=no_overrides)
def initialize_pipeline(task, document_store, retriever, reader, openai_key=""):
    """Build the Haystack pipeline for ``task``, or fetch it from Streamlit's cache.

    Args:
        task: ``'extractive'`` for a Retriever -> FARMReader pipeline, or
            ``'rag'`` for a Retriever -> PromptNode pipeline.
        document_store: Forwarded to the pipeline builder.
        retriever: Retriever placed first in either pipeline.
        reader: Used only when ``task == 'extractive'``.
        openai_key: Used only when ``task == 'rag'``.

    Returns:
        The assembled pipeline.

    Raises:
        ValueError: If ``task`` is not ``'extractive'`` or ``'rag'``.
            Previously the function silently returned ``None``, so the error
            only appeared later as an ``AttributeError`` when ``query`` called
            ``.run`` on it.
    """
    if task == 'extractive':
        return start_haystack_extractive(document_store, retriever, reader)
    if task == 'rag':
        return start_haystack_rag(document_store, retriever, openai_key)
    raise ValueError(f"Unknown task {task!r}; expected 'extractive' or 'rag'")