Spaces:
Runtime error
Runtime error
File size: 4,030 Bytes
8df121c a831acd 70b2fc9 05feb2b 70b2fc9 05feb2b f98593f 617bb16 70b2fc9 617bb16 fecdd86 8df121c 617bb16 f98593f 617bb16 fecdd86 0b1473d 617bb16 e162fdc fecdd86 e162fdc 617bb16 fecdd86 05feb2b 617bb16 05feb2b 0b1473d e6afd23 0b1473d d727ca7 0b1473d d727ca7 8df121c 05feb2b 8df121c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import glob
import os
import logging
import sys
import streamlit as st
from haystack import Pipeline
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel
from haystack.nodes.retriever.web import WebRetriever
from haystack.schema import Document
# Route all log records (DEBUG and above) to stdout. ``force=True`` drops any
# handlers already attached to the root logger so this configuration wins.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    force=True,
    handlers=[_stdout_handler],
    format="%(levelname)s %(asctime)s %(name)s:%(message)s",
    level=logging.DEBUG,
)
def get_plain_pipeline():
    """Build a one-node pipeline that sends the query straight to the OpenAI model.

    The query is wrapped in the "plain_llm" prompt template; no retrieved
    context is added.

    Returns:
        Pipeline: a pipeline whose only node, ``prompt_node``, reads from ``Query``.
    """
    # The OpenAI key is read from the Streamlit secrets store.
    openai_model = PromptModel(
        model_name_or_path="text-davinci-003",
        api_key=st.secrets["OPENAI_API_KEY"],
    )
    template = PromptTemplate(
        name="plain_llm",
        prompt_text="Answer the following question: $query",
    )
    plain_pipeline = Pipeline()
    plain_pipeline.add_node(
        component=PromptNode(openai_model, default_prompt_template=template, max_length=300),
        name="prompt_node",
        inputs=["Query"],
    )
    return plain_pipeline
def get_ret_aug_pipeline():
    """Build the retrieval-augmented pipeline backed by the local FAISS index.

    Flow: ``Query`` -> ``retriever`` (embedding retriever, ``top_k=2``) ->
    ``shaper`` (``join_documents``) -> ``prompt_node`` (question-answering
    prompt sent to ``text-davinci-003``).

    The document store is loaded from ``my_faiss_index.faiss`` and
    ``my_faiss_index.json``, which are resolved relative to the working
    directory.

    Returns:
        Pipeline: pipeline with nodes ``retriever``, ``shaper`` and ``prompt_node``.
    """
    # EmbeddingRetriever is not among the module-level imports; importing it
    # here prevents the NameError the function would otherwise raise.
    from haystack.nodes import EmbeddingRetriever

    ds = FAISSDocumentStore(
        faiss_index_path="my_faiss_index.faiss",
        faiss_config_path="my_faiss_index.json",
    )
    retriever = EmbeddingRetriever(
        document_store=ds,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )
    shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
    default_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: "
                    "$query; Answer:",
    )
    # Prompt node that answers the query from the shaped documents.
    node = PromptNode(
        "text-davinci-003",
        default_prompt_template=default_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )
    # Wire the nodes in sequence: retriever -> shaper -> prompt node.
    pipe = Pipeline()
    pipe.add_node(component=retriever, name='retriever', inputs=['Query'])
    pipe.add_node(component=shaper, name="shaper", inputs=["retriever"])
    pipe.add_node(component=node, name="prompt_node", inputs=["shaper"])
    return pipe
def get_web_ret_pipeline():
    """Build the retrieval-augmented pipeline that takes its context from web search.

    Flow: ``Query`` -> ``retriever`` (``WebRetriever`` using the SerperDev
    provider) -> ``shaper`` (``join_documents``) -> ``prompt_node``
    (question-answering prompt sent to ``text-davinci-003``).

    Returns:
        Pipeline: pipeline with nodes ``retriever``, ``shaper`` and ``prompt_node``.
    """
    # Both API keys are read from the Streamlit secrets store.
    web_search = WebRetriever(
        api_key=st.secrets["WEBRET_API_KEY"],
        search_engine_provider="SerperDev",
    )
    join_docs = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: "
                    "$query; Answer:",
    )
    answer_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    # Chain the nodes linearly: each one consumes the output of the previous.
    web_pipeline = Pipeline()
    upstream = "Query"
    stages = (
        ("retriever", web_search),
        ("shaper", join_docs),
        ("prompt_node", answer_node),
    )
    for stage_name, component in stages:
        web_pipeline.add_node(component=component, name=stage_name, inputs=[upstream])
        upstream = stage_name
    return web_pipeline
def app_init():
    """Export the OpenAI key to the environment and build all three pipelines.

    Returns:
        tuple: ``(plain, retrieval_augmented, web_retrieval)`` pipelines, in
        that order.
    """
    os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
    # Builders are invoked in this order; the result order matches.
    builders = (get_plain_pipeline, get_ret_aug_pipeline, get_web_ret_pipeline)
    return tuple(build() for build in builders)
def main():
    """Render the demo page and show two answers to the query side by side.

    The left column shows the answer from the plain LLM pipeline. The right
    column shows the answer from the retrieval-augmented pipeline chosen with
    the "Type" radio button (local index or web search).
    """
    # app_init() returns three pipelines; all three must be unpacked.
    plain_pipeline, ret_aug_pipeline, web_ret_pipeline = app_init()

    st.title("Haystack Demo")
    query = st.text_input("Query ...")
    web_label = "Retrieval Augmented with Web Search"
    query_type = st.radio("Type", ("Retrieval Augmented", web_label))

    # Do not call the models (or the web search) while the query box is empty.
    if not query or not query.strip():
        return

    # Honour the radio selection instead of always using the local index.
    augmented_pipeline = web_ret_pipeline if query_type == web_label else ret_aug_pipeline

    col_1, col_2 = st.columns(2)
    with col_1:
        st.text("PLAIN")
        answers = plain_pipeline.run(query)
        st.text(answers['results'][0])
    with col_2:
        st.write(query_type.upper())
        answers_2 = augmented_pipeline.run(query)
        st.text(answers_2['results'][0])
# Script entry point: start the app only when this file is executed directly.
if __name__ == "__main__":
    main()
|