Spaces:
Sleeping
Sleeping
""" | |
/************************************************************************* | |
* | |
* CONFIDENTIAL | |
* __________________ | |
* | |
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC | |
* All Rights Reserved | |
* | |
* Author : Theekshana Samaradiwakara | |
* Description :Python Backend API to chat with private data | |
* CreatedDate : 14/11/2023 | |
* LastModifiedDate : 21/03/2024 | |
*************************************************************************/ | |
""" | |
import asyncio | |
import logging | |
from typing import List, Optional, Sequence | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForRetrieverRun, | |
CallbackManagerForRetrieverRun, | |
) | |
from langchain_core.documents import Document | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.output_parsers import BaseOutputParser | |
from langchain_core.prompts.prompt import PromptTemplate | |
from langchain_core.retrievers import BaseRetriever | |
from langchain.chains.llm import LLMChain | |
import numpy as np | |
import pandas as pd | |
logger = logging.getLogger(__name__) | |
from reggpt.prompts.multi_query import MULTY_QUERY_PROMPT | |
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that turns LLM text into a list of non-blank lines."""

    def parse(self, text: str) -> List[str]:
        """Split ``text`` into one entry per non-blank line.

        Args:
            text: Raw LLM completion, expected to contain one generated query
                per line.

        Returns:
            The stripped, non-empty lines in their original order. An empty or
            whitespace-only ``text`` yields an empty list.
        """
        # Blank separator lines (and "\r\n" endings) must not turn into empty
        # queries that would later be sent to the retriever.
        return [line.strip() for line in text.splitlines() if line.strip()]
# Default prompt | |
# DEFAULT_QUERY_PROMPT = PromptTemplate( | |
# input_variables=["question"], | |
# template="""You are an AI language model assistant. Your task is | |
# to generate 3 different versions of the given user | |
# question to retrieve relevant documents from a vector database. | |
# By generating multiple perspectives on the user question, | |
# your goal is to help the user overcome some of the limitations | |
# of distance-based similarity search. Provide these alternative | |
# questions separated by newlines. Original question: {question}""", | |
# ) | |
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
    """Drop every document that equals (``==``) one appearing earlier.

    The first occurrence of each document is kept and input order is
    preserved. Membership is checked with equality against the preceding
    slice rather than via hashing (presumably because documents may not be
    hashable), so the cost is quadratic in the number of documents.
    """
    kept = []
    for position, candidate in enumerate(documents):
        earlier = documents[:position]
        if candidate in earlier:
            continue
        kept.append(candidate)
    return kept
class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query, take the unique union of all retrieved docs,
    order that union newest-first by ``metadata[date_key]`` when any document
    carries a date, and return at most ``top_k`` documents. The sync and async
    entry points apply the same ordering and truncation.
    """

    retriever: BaseRetriever
    llm_chain: LLMChain
    # When True, the generated queries are logged at INFO level.
    verbose: bool = True
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""
    # Metadata key read from each document to order results newest-first.
    date_key: str = "year"
    # Maximum number of documents returned after de-duplication and sorting.
    top_k: int = 4

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: PromptTemplate = MULTY_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from an LLM and a query-generation prompt.

        Args:
            retriever: retriever to query documents from
            llm: llm used to generate the alternative queries
            prompt: prompt template for query generation; it receives the user
                query as ``question`` (defaults to ``MULTY_QUERY_PROMPT``)
            parser_key: DEPRECATED and ignored.
            include_original: Whether to include the original query in the list
                of generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query (async).

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Up to ``top_k`` unique documents from all generated queries,
            newest-first when dates are available.
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            queries = [*queries, query]
        documents = await self.aretrieve_documents(queries, run_manager)
        return self._rank_and_truncate(self.unique_union(documents))

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input (async).

        Args:
            question: user query
            run_manager: callback manager whose child receives the LLM callbacks

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        # Chain.ainvoke takes the inputs positionally; callbacks go via config.
        response = await self.llm_chain.ainvoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: query list
            run_manager: callback manager whose child receives retriever callbacks

        Returns:
            List of retrieved Documents (duplicates included), in query order
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.ainvoke(
                    query, config={"callbacks": run_manager.get_child()}
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Up to ``top_k`` unique documents from all generated queries,
            newest-first when dates are available.
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            queries = [*queries, query]
        documents = self.retrieve_documents(queries, run_manager)
        return self._rank_and_truncate(self.unique_union(documents))

    def _rank_and_truncate(self, documents: List[Document]) -> List[Document]:
        """Order ``documents`` newest-first by ``metadata[date_key]``; keep ``top_k``.

        Documents whose date is missing, ``None`` or unparsable are placed after
        dated ones; ties keep their retrieval order (stable sort). If no
        document carries a date, the input order is kept. An empty input
        returns an empty list.
        """
        raw_dates = [doc.metadata.get(self.date_key) for doc in documents]
        if any(value is not None for value in raw_dates):
            # NOTE: bare integers (e.g. 2023) are read by pd.to_datetime as
            # nanosecond offsets from the epoch; their relative order is still
            # preserved, which is all this sort relies on.
            dates = pd.Series(pd.to_datetime(raw_dates, errors="coerce"))
            order = dates.sort_values(
                ascending=False, na_position="last", kind="stable"
            ).index
            documents = [documents[idx] for idx in order]
            logger.info("Documents sorted by %s", self.date_key)
        return documents[: self.top_k]

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child receives the LLM callbacks

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        # Chain.invoke reads callbacks from config, not from keyword arguments.
        response = self.llm_chain.invoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries sequentially.

        Args:
            queries: query list
            run_manager: callback manager whose child receives retriever callbacks

        Returns:
            List of retrieved Documents (duplicates included), in query order
        """
        documents = []
        for query in queries:
            logger.info("MQ Retriever question: %s", query)
            docs = self.retriever.invoke(
                query, config={"callbacks": run_manager.get_child()}
            )
            documents.extend(docs)
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents, first occurrences in input order
        """
        return _unique_documents(documents)