Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod | |
from haystack.nodes import BM25Retriever, FARMReader | |
from haystack.document_stores import ElasticsearchDocumentStore | |
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline | |
from haystack.document_stores import PineconeDocumentStore | |
from haystack.nodes import EmbeddingRetriever | |
import json | |
import logging | |
import os | |
import shutil | |
import sys | |
import uuid | |
from json import JSONDecodeError | |
from pathlib import Path | |
from typing import List, Optional | |
import pandas as pd | |
import pinecone | |
import streamlit as st | |
from haystack import BaseComponent, Document | |
from haystack.document_stores import PineconeDocumentStore | |
from haystack.nodes import ( | |
EmbeddingRetriever, | |
FARMReader | |
) | |
from haystack.pipelines import ExtractiveQAPipeline, Pipeline | |
from sentence_transformers import SentenceTransformer | |
import certifi | |
import datetime | |
import requests | |
from base64 import b64encode | |
# Filesystem path of certifi's bundled CA certificate file.
ca_certs = certifi.where()
class PineconeRetriever(BaseComponent):
    """Haystack node that retrieves documents directly from a Pinecone index.

    The query is embedded with a SentenceTransformer model and sent to
    Pinecone. Each match's metadata (``content``, ``title``, ``page``,
    ``source``) is converted into a Haystack ``Document``.
    """

    outgoing_edges = 1

    def __init__(
        self,
        sentence_transformer_name: str,
        api_key: str,
        environment: str,
        index_name: str,
        top_k: int = 10,
    ):
        """Load the embedding model and connect to the Pinecone index.

        :param sentence_transformer_name: SentenceTransformer model name or path
            used to embed queries; its vector size should match the index.
        :param api_key: Pinecone API key.
        :param environment: Pinecone environment, e.g. ``"us-east1-gcp"``.
        :param index_name: Name of an existing Pinecone index.
        :param top_k: Number of matches fetched when ``run``/``run_batch`` is
            called without an explicit ``top_k``.
        """
        self.sts_model = SentenceTransformer(sentence_transformer_name)
        self.top_k = top_k
        pinecone.init(api_key=api_key, environment=environment)
        self.index = pinecone.Index(index_name)

    def run(self, query: str, top_k: Optional[int] = None):
        """Retrieve the most similar documents for a single query.

        :param query: Free-text query.
        :param top_k: Number of matches to fetch; defaults to ``self.top_k``.
            A default is required because Haystack only passes ``top_k``
            when the caller supplies it in ``params``.
        :return: ``({"documents": [...], "query": query}, "output_1")``.
        """
        documents = self._retrieve(query, top_k)
        return {"documents": documents, "query": query}, "output_1"

    def run_batch(self, queries: List[str], top_k: Optional[int] = None):
        """Retrieve documents for several queries, one Pinecone request each.

        :param queries: Free-text queries.
        :param top_k: Number of matches per query; defaults to ``self.top_k``.
        :return: ``({"documents": [[...], ...], "queries": queries}, "output_1")``
            where the i-th document list belongs to the i-th query.
        """
        documents = [self._retrieve(query, top_k) for query in queries]
        return {"documents": documents, "queries": queries}, "output_1"

    def _retrieve(self, query: str, top_k: Optional[int]) -> List[Document]:
        """Embed *query*, query Pinecone and convert the matches to Documents."""
        if top_k is None:
            top_k = self.top_k
        embedding = self.sts_model.encode(query).tolist()
        # The single-vector form of Index.query returns matches at the top level
        # of the response (response["matches"]); passing a list of vectors
        # selects the legacy batch form, which nests them under "results".
        response = self.index.query(vector=embedding, top_k=top_k, include_metadata=True)
        return [self._to_document(match) for match in response["matches"]]

    @staticmethod
    def _to_document(match) -> Document:
        """Build a Document from one Pinecone match; absent meta fields become None."""
        metadata = match["metadata"]
        return Document(
            content=metadata["content"],
            meta={key: metadata.get(key) for key in ("title", "page", "source")},
        )
class DocumentQueries(ABC):
    """Interface for objects that answer free-text queries over a document index."""

    @abstractmethod
    def search_by_query(
        self,
        query: str,
        retriever_top_k: int,
        reader_top_k: int,
        es_index: Optional[str] = None,
    ):
        """Run *query* and return the resulting answers.

        :param query: Free-text question.
        :param retriever_top_k: Number of documents the retrieval stage should return.
        :param reader_top_k: Number of answers the reading stage should return.
        :param es_index: Optional index to search; implementations may ignore it.
        :return: Implementation-defined collection of answers.
        """
class PinecodeProposalQueries(DocumentQueries):
    """Extractive question answering over documents stored in a Pinecone index.

    Pipeline: ``Query -> PineconeRetriever ("Retriever") -> FARMReader ("Reader")``.

    NOTE(review): the constructor keeps Elasticsearch-style parameter names for
    interface compatibility, but only ``es_password`` is used. It is passed as
    the Pinecone API key. ``es_host``, ``es_index`` and ``es_user`` are
    accepted and ignored.
    """

    # Pinecone settings and the query-embedding model. EMBEDDING_DIM is the
    # vector size the index is declared with and must agree with the size of
    # RETRIEVER_MODEL's embeddings.
    PINECONE_ENVIRONMENT = "us-east1-gcp"
    INDEX_NAME = "semantic-text-search"
    RETRIEVER_MODEL = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
    EMBEDDING_DIM = 384

    def __init__(self, es_host: str, es_index: str, es_user, es_password, reader_name_or_path: str, use_gpu=True) -> None:
        """Load the reader model and build the retrieval/reading pipeline.

        :param es_password: Pinecone API key (see class note).
        :param reader_name_or_path: FARMReader model name or local path.
        :param use_gpu: Whether the reader should run on GPU.
        """
        reader = FARMReader(
            model_name_or_path=reader_name_or_path,
            use_gpu=use_gpu,
            num_processes=1,
            context_window_size=200,
        )
        self._initialize_pipeline(es_host, es_index, es_user, es_password, reader=reader)

    def _initialize_pipeline(self, es_host, es_index, es_user, es_password, reader=None):
        """(Re)build ``self.document_store`` and ``self.pipe``.

        :param reader: Reader node to use. If None, the previously stored
            ``self.reader`` is reused.
        :raises ValueError: if no reader was given and none was stored before.
        """
        if reader is not None:
            self.reader = reader
        elif getattr(self, "reader", None) is None:
            raise ValueError("A reader is required the first time the pipeline is initialized.")

        # Not used by the pipeline below (the custom retriever talks to Pinecone
        # directly); kept because callers may read ``self.document_store``.
        self.document_store = PineconeDocumentStore(
            api_key=es_password,
            environment=self.PINECONE_ENVIRONMENT,
            index=self.INDEX_NAME,
            similarity="cosine",
            embedding_dim=self.EMBEDDING_DIM,
        )

        retriever = PineconeRetriever(
            self.RETRIEVER_MODEL,
            es_password,
            self.PINECONE_ENVIRONMENT,
            self.INDEX_NAME,
        )
        self.pipe = Pipeline()
        self.pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
        self.pipe.add_node(component=self.reader, name="Reader", inputs=["Retriever"])

    def search_by_query(self, query: str, retriever_top_k: int, reader_top_k: int, es_index: Optional[str] = None):
        """Answer *query* with the retrieval + reading pipeline.

        :param retriever_top_k: Documents fetched from Pinecone.
        :param reader_top_k: Answers returned by the reader.
        :param es_index: Accepted for interface compatibility; ignored.
        :return: The ``"answers"`` entry of the pipeline result.
        """
        params = {"Retriever": {"top_k": retriever_top_k}, "Reader": {"top_k": reader_top_k}}
        prediction = self.pipe.run(query=query, params=params)
        return prediction["answers"]
class Log:
    """Write log entries as documents into an Elasticsearch index via its REST API."""

    def __init__(self, es_host: str, es_index: str, es_user, es_password, timeout: float = 10.0) -> None:
        """
        :param es_host: Host name; requests go to ``https://{es_host}:443``.
        :param es_index: Index that receives the log documents.
        :param es_user: Basic-auth user name.
        :param es_password: Basic-auth password.
        :param timeout: Seconds to wait for the HTTP request in ``write_log``.
        """
        self.elastic_endpoint = f"https://{es_host}:443/{es_index}/_doc"
        # Credentials come from the arguments. NOTE: an earlier revision embedded
        # a literal user:password here; that secret should be rotated.
        self.credentials = b64encode(f"{es_user}:{es_password}".encode("utf-8")).decode("ascii")
        self.auth_header = {"Authorization": f"Basic {self.credentials}"}
        self.timeout = timeout

    def write_log(self, message: str, source: str) -> None:
        """POST one entry ``{message, createdDate: {date}, source}`` and print the response body.

        ``date`` is the current time in UTC, formatted ``YYYY-MM-DDTHH:MM:SSZ``.
        """
        post_data = {
            "message": message,
            "createdDate": {
                "date": self._format_timestamp(datetime.datetime.now(datetime.timezone.utc)),
            },
            "source": source,
        }
        response = requests.post(
            self.elastic_endpoint,
            json=post_data,
            headers=self.auth_header,
            timeout=self.timeout,  # without a timeout a stalled server blocks forever
        )
        print(response.text)

    @staticmethod
    def _format_timestamp(moment: datetime.datetime) -> str:
        """Format *moment* as an ISO-8601 UTC string ending in ``Z``.

        Aware datetimes are converted to UTC first, so the ``Z`` suffix is
        truthful; naive datetimes are interpreted as local time (``astimezone``).
        """
        return moment.astimezone(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")