"""Extractive question answering over a Pinecone-backed Haystack pipeline.

Defines the ``DocumentQueries`` interface, a Pinecone implementation of it,
and a small helper that posts log entries to an Elasticsearch index.
"""

from abc import ABC, abstractmethod
from base64 import b64encode
import datetime

import certifi
import requests
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.document_stores import PineconeDocumentStore
from haystack.nodes import BM25Retriever, FARMReader
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import ExtractiveQAPipeline

# Path to certifi's CA bundle; kept at module level for importers of this file.
ca_certs = certifi.where()


class DocumentQueries(ABC):
    """Interface for objects that answer a free-text query."""

    @abstractmethod
    def search_by_query(
        self,
        query: str,
        retriever_top_k: int,
        reader_top_k: int,
        es_index: str,
    ):
        """Run ``query`` and return the answers found.

        Args:
            query: The question to answer.
            retriever_top_k: Number of documents the retriever should return.
            reader_top_k: Number of answers the reader should return.
            es_index: Name of the index to search.
        """


class PinecodeProposalQueries(DocumentQueries):
    """Answer queries with an embedding retriever over Pinecone and a FARM reader."""

    def __init__(
        self,
        es_host: str,
        es_index: str,
        es_user,
        es_password,
        reader_name_or_path: str,
        use_gpu=True,
    ) -> None:
        """Load the reader model and build the question-answering pipeline.

        Args:
            es_host: Stored on the instance; not used to connect to Pinecone.
            es_index: Name of the Pinecone index to query.
            es_user: Stored on the instance; not used to connect to Pinecone.
            es_password: Passed to Pinecone as the API key.
            reader_name_or_path: Model name or path handed to ``FARMReader``.
            use_gpu: Whether the reader should use a GPU.
        """
        reader = FARMReader(
            model_name_or_path=reader_name_or_path,
            use_gpu=use_gpu,
            num_processes=1,
            context_window_size=200,
        )
        self._initialize_pipeline(
            es_host, es_index, es_user, es_password, reader=reader
        )
        # NOTE: query logging through ``Log`` is currently disabled.

    def _initialize_pipeline(
        self, es_host, es_index, es_user, es_password, reader=None
    ):
        """(Re)build the document store, retriever and pipeline.

        Args:
            es_host: Stored as ``self.es_host``.
            es_index: Name of the Pinecone index to open.
            es_user: Stored as ``self.es_user``.
            es_password: Stored as ``self.es_password`` and used as the
                Pinecone API key.
            reader: Reader to install as ``self.reader``. When ``None`` the
                reader already set on the instance is reused.
        """
        if reader is not None:
            self.reader = reader
        self.es_host = es_host
        self.es_user = es_user
        self.es_password = es_password
        self.document_store = PineconeDocumentStore(
            api_key=es_password,
            environment="us-east1-gcp",
            index=es_index,
            similarity="cosine",
            embedding_dim=768,
        )
        # NOTE(review): the model name suggests dot-product scoring while the
        # store is configured for cosine similarity - confirm this is intended.
        self.retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model="multi-qa-distilbert-dot-v1",
            model_format="sentence_transformers",
        )
        # Embeddings are refreshed every time the pipeline is (re)initialized.
        self.document_store.update_embeddings(self.retriever, batch_size=16)
        self.pipe = ExtractiveQAPipeline(self.reader, self.retriever)

    def search_by_query(
        self,
        query: str,
        retriever_top_k: int,
        reader_top_k: int,
        es_index: str = None,
    ):
        """Run ``query`` through the pipeline and return its answers.

        Args:
            query: The question to answer.
            retriever_top_k: Number of documents requested from the retriever.
            reader_top_k: Accepted for interface compatibility; currently not
                forwarded to the reader.
            es_index: Accepted for interface compatibility; currently ignored,
                the index chosen at construction time is always used.

        Returns:
            The ``"answers"`` entry of the pipeline result.
        """
        # NOTE: only the retriever's top_k is forwarded; per-call index
        # switching and query logging are disabled.
        params = {"Retriever": {"top_k": retriever_top_k}}
        prediction = self.pipe.run(query=query, params=params)
        return prediction["answers"]


class Log():
    """Post log entries as documents to an Elasticsearch index over HTTPS."""

    def __init__(self, es_host: str, es_index: str, es_user, es_password) -> None:
        """Prepare the endpoint URL and the HTTP Basic authorization header.

        Args:
            es_host: Elasticsearch host name (port 443 is always used).
            es_index: Index that receives the log documents.
            es_user: User name for HTTP Basic authentication.
            es_password: Password for HTTP Basic authentication.
        """
        self.elastic_endpoint = f"https://{es_host}:443/{es_index}/_doc"
        # Build the credentials from the arguments; secrets must never be
        # embedded in the source.
        raw_credentials = f"{es_user}:{es_password}".encode("utf-8")
        self.credentials = b64encode(raw_credentials).decode("ascii")
        self.auth_header = {"Authorization": f"Basic {self.credentials}"}

    def write_log(self, message: str, source: str, timeout: float = 10.0) -> None:
        """Post one log entry and print the server's response body.

        Args:
            message: Text stored in the entry's ``message`` field.
            source: Value stored in the entry's ``source`` field.
            timeout: Seconds to wait for the server before giving up.

        Raises:
            requests.RequestException: If the request fails or times out.
        """
        # The trailing "Z" denotes UTC, so the timestamp must be taken in UTC
        # rather than in the machine's local time zone.
        now_utc = datetime.datetime.now(datetime.timezone.utc)
        created_date = now_utc.strftime('%Y-%m-%dT%H:%M:%SZ')
        post_data = {
            "message": message,
            "createdDate": {"date": created_date},
            "source": source,
        }
        r = requests.post(
            self.elastic_endpoint,
            json=post_data,
            headers=self.auth_header,
            timeout=timeout,
        )
        print(r.text)