import shutil
from typing import List

import streamlit as st

from haystack import Pipeline, BaseComponent, Answer
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes.reader import FARMReader
from haystack.nodes.retriever import EmbeddingRetriever, MultiModalRetriever

from utils.config import INDEX_DIR



class AnswerToQuery(BaseComponent):
    """Custom node that turns the reader's best answer into a new query,
    so the image retriever downstream can search with the extracted answer."""

    outgoing_edges = 1

    def run(self, query: str, answers: List[Answer]):
        # Forward the top extracted answer as the query for the next node
        return {"query": answers[0].answer}, "output_1"

    def run_batch(self):
        raise NotImplementedError()

# cached so that the index and the models are loaded only once, at startup
@st.cache(
    hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
    """
    load document store, retriever, entailment checker and create pipeline
    """
    shutil.copy(f"{INDEX_DIR}/text.db", ".")
    shutil.copy(f"{INDEX_DIR}/images.db", ".")

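    # Two FAISS document stores: one for text passages, one for images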
    document_store_text = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/text.faiss",
        faiss_config_path=f"{INDEX_DIR}/text.json",
    )

    document_store_images = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/images.faiss",
        faiss_config_path=f"{INDEX_DIR}/images.json",
    )
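
    # Dense retriever over the text index, using a sentence-transformers QA embedding model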
    retriever_text = EmbeddingRetriever(
        document_store=document_store_text,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
    )

    # Extractive QA reader: pulls a short answer span from the retrieved passages
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

    # Text-to-image retriever: embeds the textual query and the images into the
    # same CLIP vector space
    retriever_images = MultiModalRetriever(
        document_store=document_store_images,
        query_embedding_model="sentence-transformers/clip-ViT-B-32",
        query_type="text",
        document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"},
    )

    answer_to_query = AnswerToQuery()

    pipe = Pipeline()

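    # Flow: Query -> text_retriever -> text_reader -> answer2query -> image_retriever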
    pipe.add_node(retriever_text, name="text_retriever", inputs=["Query"])
    pipe.add_node(reader, name="text_reader", inputs=["text_retriever"])
    pipe.add_node(answer_to_query, name="answer2query", inputs=["text_reader"])
    pipe.add_node(retriever_images, name="image_retriever", inputs=["answer2query"])

    return pipe


pipe = start_haystack()

@st.cache(allow_output_mutation=True)
def query(statement: str, text_retriever_top_k: int = 5, image_retriever_top_k: int = 1):
    """Run the statement through the text QA + image retrieval pipeline"""
    params = {
        "text_retriever": {"top_k": text_retriever_top_k},
        "image_retriever": {"top_k": image_retriever_top_k},
    }
    results = pipe.run(query=statement, params=params)
    return results
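

# Example usage (sketch, not part of the original app): the image retriever is
# the last node, so the pipeline output is expected to carry the retrieved
# image documents under the "documents" key, e.g.:
#
#   results = query("Who painted the Mona Lisa?")  # hypothetical query
#   for doc in results["documents"]:
#       st.image(doc.content)  # for image documents, content typically holds the file path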