import asyncio
import logging
from typing import List, Optional, Sequence

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever

from langchain.chains.llm import LLMChain

logger = logging.getLogger(__name__)


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of lines."""

    def parse(self, text: str) -> List[str]:
        """Split LLM output into one entry per non-blank line.

        Blank (or whitespace-only) lines are dropped so that an empty string
        is never forwarded to the retriever as a query.
        """
        return [line for line in text.strip().splitlines() if line.strip()]


# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template=(
        "You are an AI language model assistant. Your task is to generate 3 "
        "different versions of the given user question to retrieve relevant "
        "documents from a vector database. By generating multiple perspectives "
        "on the user question, your goal is to help the user overcome some of "
        "the limitations of distance-based similarity search. Provide these "
        "alternative questions separated by newlines. Original question: "
        "{question}"
    ),
)


def _unique_documents(documents: Sequence[Document]) -> List[Document]:
    """Return *documents* with duplicates removed, preserving first-seen order.

    Documents are compared with ``==``. They are not required to be hashable,
    so membership is tested against the list of documents kept so far.
    """
    unique: List[Document] = []
    for doc in documents:
        if doc not in unique:
            unique.append(doc)
    return unique


class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query. Return the unique union of all retrieved docs.
    """

    retriever: BaseRetriever
    llm_chain: LLMChain
    verbose: bool = True
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from an LLM, a query-generation prompt and a base retriever.

        Args:
            retriever: retriever to query documents from
            llm: llm used to generate the alternative queries
            prompt: prompt used for query generation; it is invoked with a
                single ``question`` input. Defaults to DEFAULT_QUERY_PROMPT.
            parser_key: DEPRECATED and ignored; accepted only for backward
                compatibility.
            include_original: Whether to include the original query in the list
                of generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query (async).

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            # Build a new list instead of appending so the list produced by
            # the query-generation step is not mutated in place.
            queries = [*queries, query]
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input (async).

        Args:
            question: user query
            run_manager: callback manager whose child receives the LLM callbacks

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = await self.llm_chain.acall(
            inputs={"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: query list
            run_manager: callback manager whose child receives retriever callbacks

        Returns:
            List of retrieved Documents, concatenated in the order of *queries*
        """
        # asyncio.gather returns results in the order the awaitables were
        # passed, so the flattened list follows the order of `queries`.
        document_lists = await asyncio.gather(
            *(
                self.retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            # Build a new list instead of appending so the list produced by
            # the query-generation step is not mutated in place.
            queries = [*queries, query]
        documents = self.retrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child receives the LLM callbacks

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = self.llm_chain(
            {"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries sequentially.

        Args:
            queries: query list
            run_manager: callback manager whose child receives retriever callbacks

        Returns:
            List of retrieved Documents, concatenated in the order of *queries*
        """
        documents = [
            doc
            for query in queries
            for doc in self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
        ]
        logger.debug(
            "Retrieved %d documents for %d queries", len(documents), len(queries)
        )
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents, in first-seen order
        """
        logger.debug("Deduplicating %d retrieved documents", len(documents))
        return _unique_documents(documents)