# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Includes reformulation of the user question as the first step of the chain.
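#
# Pipeline: "query" -> reformulation chain -> "question" + "language"
#           "question" (+ "audience", "language") -> retrieval QA -> "answer" + "source_documents"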
import json

from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

from climateqa.prompts import answer_prompt, reformulation_prompt, audience_prompts


def load_reformulation_chain(llm):
    """Build a chain that reformulates the user query into a standalone question and detects its language."""
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    reformulation_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    # Parse the JSON emitted by the reformulation LLM; fall back to the raw
    # query and to English when the expected keys are missing.
    def parse_output(output):
        query = output["query"]
        json_output = json.loads(output["json"])
        question = json_output.get("question", query)
        language = json_output.get("language", "English")
        return {
            "question": question,
            "language": language,
        }
    
    transform_chain = TransformChain(
        input_variables=["json"], output_variables=["question", "language"], transform=parse_output
    )

    reformulation_chain = SequentialChain(
        chains=[reformulation_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
    return reformulation_chain



def load_answer_chain(retriever, llm):
    """Build a retrieval QA chain that answers the question and returns the source documents it used."""
    prompt = PromptTemplate(template=answer_prompt, input_variables=["summaries", "question", "audience", "language"])
    qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)

    # This could be improved by providing a document prompt to avoid modifying page_content in the docs
    # See here https://github.com/langchain-ai/langchain/issues/3523
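    #
    # A minimal sketch of that idea (an assumption, not part of this module):
    #
    #     doc_prompt = PromptTemplate(
    #         template="{page_content} - {source}",
    #         input_variables=["page_content", "source"],
    #     )
    #     qa_chain = load_qa_with_sources_chain(
    #         llm, chain_type="stuff", prompt=prompt, document_prompt=doc_prompt
    #     )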

    answer_chain = RetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,
        retriever=retriever,
        return_source_documents=True,
    )
    return answer_chain


def load_climateqa_chain(retriever, llm):
    """Assemble the full ClimateQA chain: query reformulation followed by retrieval-augmented answering."""
    reformulation_chain = load_reformulation_chain(llm)
    answer_chain = load_answer_chain(retriever, llm)

    climateqa_chain = SequentialChain(
        chains=[reformulation_chain, answer_chain],
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
    )
    return climateqa_chain
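

# Example usage (a minimal sketch, assuming an OpenAI chat model and a
# pre-built vectorstore retriever; any LangChain-compatible llm/retriever
# pair should work the same way, and the audience string is illustrative):
#
#     from langchain.chat_models import ChatOpenAI
#
#     llm = ChatOpenAI(temperature=0)
#     chain = load_climateqa_chain(retriever, llm)
#     result = chain({"query": "Is climate change human-caused?", "audience": "the general public"})
#     print(result["question"], result["language"])
#     print(result["answer"])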