Spaces:
Runtime error
Runtime error
from haystack.nodes.retriever import EmbeddingRetriever | |
from haystack.nodes import TableReader, FARMReader, RouteDocuments, JoinAnswers | |
from haystack import Pipeline | |
# Short option key -> extractive text-reader model id (Hugging Face hub names).
# NOTE(review): "distilroberta" actually points at a *tinyroberta* checkpoint;
# the key name looks stale — confirm before relying on it.
# NOTE(review): the "gpt3" value is not a model id but a TODO placeholder;
# passing it to a reader as a model name would fail to load.
text_reader_types = dict(
    [
        ("minilm", "deepset/minilm-uncased-squad2"),
        ("distilroberta", "deepset/tinyroberta-squad2"),
        ("electra-base", "deepset/electra-base-squad2"),
        ("bert-base", "deepset/bert-base-cased-squad2"),
        ("deberta-large", "deepset/deberta-v3-large-squad2"),
        ("gpt3", "implement openai answer generator"),
    ]
)
# Short option key -> table-reader model id (Hugging Face hub names).
# NOTE(review): the "text" value is a TODO placeholder (table-to-text is not
# implemented), not a loadable model id.
table_reader_types = dict(
    [
        ("tapas", "deepset/tapas-large-nq-hn-reader"),
        ("text", "implement changing tables to text"),
    ]
)
def create_retriever(document_store, embedding_model="deepset/all-mpnet-base-v2-table"):
    """Create an embedding retriever over *document_store* and embed its documents.

    Args:
        document_store: Haystack document store to retrieve from. Its
            ``update_embeddings`` is called with the new retriever, so the
            stored documents get embeddings from the same model used at
            query time.
        embedding_model: Model name or path handed to ``EmbeddingRetriever``.
            Defaults to the table-aware model this function previously
            hard-coded, so existing callers see no change.

    Returns:
        Tuple ``(document_store, retriever)``. ``document_store`` is the same
        object that was passed in.
    """
    retriever = EmbeddingRetriever(document_store=document_store, embedding_model=embedding_model)
    # Store embeddings with this retriever's model so query and document
    # vectors live in the same embedding space.
    document_store.update_embeddings(retriever=retriever)
    return document_store, retriever
def create_readers_and_pipeline(
    retriever,
    text_reader_type="deepset/roberta-base-squad2",
    table_reader_type="deepset/tapas-large-nq-hn-reader",
    use_table=True,
    use_text=True,
):
    """Assemble a Haystack query pipeline on top of ``retriever``.

    The retriever is always the first node (fed by "Query"). Downstream:

    * text only  -> a ``FARMReader`` node named "TextReader";
    * table only -> a ``TableReader`` node named "TableReader";
    * both       -> a ``RouteDocuments`` node feeds "TextReader" from
      output_1 and "TableReader" from output_2, and a ``JoinAnswers`` node
      merges the answers of both readers.

    Args:
        retriever: Retriever component, added as the "EmbeddingRetriever" node.
        text_reader_type: Model name or path for the ``FARMReader``.
        table_reader_type: Model name or path for the ``TableReader``.
        use_table: Whether to load and wire the table reader.
        use_text: Whether to load and wire the text reader.

    Returns:
        The assembled ``Pipeline``.
    """
    wants_hybrid = use_table and use_text

    # Load only the readers that were requested.
    if use_text:
        print("Initializing Text reader..")
        text_reader = FARMReader(text_reader_type)
    if use_table:
        print("Initializing table reader..")
        table_reader = TableReader(table_reader_type)

    if wants_hybrid:
        router = RouteDocuments()
        merger = JoinAnswers()

    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name="EmbeddingRetriever", inputs=["Query"])

    if wants_hybrid:
        # NOTE(review): per Haystack docs, RouteDocuments' default split
        # (by content type) sends text docs to output_1 and table docs to
        # output_2 — confirm against the pinned Haystack version.
        pipeline.add_node(component=router, name="RouteDocuments", inputs=["EmbeddingRetriever"])
        pipeline.add_node(component=text_reader, name="TextReader", inputs=["RouteDocuments.output_1"])
        pipeline.add_node(component=table_reader, name="TableReader", inputs=["RouteDocuments.output_2"])
        pipeline.add_node(component=merger, name="JoinAnswers", inputs=["TextReader", "TableReader"])
    elif use_table:
        pipeline.add_node(component=table_reader, name="TableReader", inputs=["EmbeddingRetriever"])
    elif use_text:
        pipeline.add_node(component=text_reader, name="TextReader", inputs=["EmbeddingRetriever"])
    # NOTE(review): when use_table and use_text are both falsy, no reader is
    # created and the returned pipeline contains only the retriever.
    return pipeline