find-the-animal / utils /haystack.py
Tuana's picture
Initial code
75128dd
raw
history blame
2.64 kB
import shutil
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes.retriever import EmbeddingRetriever, MultiModalRetriever
from haystack.nodes.reader import FARMReader
from haystack import Pipeline
from utils.config import (INDEX_DIR)
from typing import List
from haystack import BaseComponent, Answer
import streamlit as st
class AnswerToQuery(BaseComponent):
outgoing_edges = 1
def run(self, query: str, answers: List[Answer]):
return {"query": answers[0].answer}, "output_1"
def run_batch(self):
raise NotImplementedError()
# cached to make index and models load only at start
@st.cache(
hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
"""
load document store, retriever, entailment checker and create pipeline
"""
shutil.copy(f"{INDEX_DIR}/text.db", ".")
shutil.copy(f"{INDEX_DIR}/images.db", ".")
document_store_text = FAISSDocumentStore(
faiss_index_path=f"{INDEX_DIR}/text.faiss",
faiss_config_path=f"{INDEX_DIR}/text.json",
)
document_store_images = FAISSDocumentStore(
faiss_index_path=f"{INDEX_DIR}/images.faiss",
faiss_config_path=f"{INDEX_DIR}/images.json",
)
retriever_text = EmbeddingRetriever(
document_store=document_store_text,
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
model_format="sentence_transformers",
)
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
retriever_images = MultiModalRetriever(
document_store=document_store_images,
query_embedding_model = "sentence-transformers/clip-ViT-B-32",
query_type="text",
document_embedding_models = {
"image": "sentence-transformers/clip-ViT-B-32"
}
)
answer_to_query = AnswerToQuery()
pipe = Pipeline()
pipe.add_node(retriever_text, name="text_retriever", inputs=["Query"])
pipe.add_node(reader, name="text_reader", inputs=["text_retriever"])
pipe.add_node(answer_to_query, name="answer2query", inputs=["text_reader"])
pipe.add_node(retriever_images, name="image_retriever", inputs=["answer2query"])
return pipe
pipe = start_haystack()
@st.cache(allow_output_mutation=True)
def query(statement: str, text_retriever_top_k: int = 5, image_retriever_top_k = 1):
"""Run query and verify statement"""
params = {"image_retriever": {"top_k": image_retriever_top_k},"text_retriever": {"top_k": text_retriever_top_k} }
results = pipe.run(statement, params=params)
return results