"""
LLM chain retrieval
"""

import json
import gradio as gr

from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate


# System prompt template for the RAG application
PROMPT_TEMPLATE = """
You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end. 
If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
Question: {question} 
Context: {context} 
Helpful Answer:
"""


# Initialize langchain LLM chain
def initialize_llmchain(
    llm_model,
    huggingfacehub_api_token,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    progress=gr.Progress(),
):
    """Initialize Langchain LLM chain"""

    progress(0.1, desc="Initializing HF endpoint...")
    # HuggingFaceEndpoint calls the Hugging Face Inference API
    progress(0.5, desc="Initializing HF Hub...")
    # Note: passing trust_remote_code requires model_kwargs, which hits a known
    # langchain issue: https://github.com/langchain-ai/langchain/issues/6080

    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        task="text-generation",
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        huggingfacehub_api_token=huggingfacehub_api_token,
    )

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
    retriever = vector_db.as_retriever()

    progress(0.8, desc="Defining retrieval chain...")
    # Load the system prompt template from prompt_template.json
    with open("prompt_template.json", "r") as file:
        system_prompt = json.load(file)
    prompt_template = system_prompt["prompt"]
    rag_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
        # return_generated_question=False,
        verbose=False,
    )
    progress(0.9, desc="Done!")

    return qa_chain


def format_chat_history(message, chat_history):
    """Format chat history for llm chain"""

    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
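
# For illustration (hypothetical values):
#   format_chat_history("next question", [("Hi", "Hello!")])
#   -> ["User: Hi", "Assistant: Hello!"]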


def invoke_qa_chain(qa_chain, message, history):
    """Invoke question-answering chain"""

    formatted_chat_history = format_chat_history(message, history)
    # print("formatted_chat_history",formatted_chat_history)

    # Generate response using QA chain
    response = qa_chain.invoke(
        {"question": message, "chat_history": formatted_chat_history}
    )

    response_sources = response["source_documents"]

    response_answer = response["answer"]
    # If the model echoed the prompt, keep only the text after the last
    # "Helpful Answer:" marker
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]

    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]

    # print ('chat response: ', response_answer)
    # print('DB source', response_sources)

    return qa_chain, new_history, response_sources
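

# Minimal usage sketch (illustrative, with assumed values): it presumes a
# vector store built elsewhere that exposes .as_retriever() (e.g. Chroma or
# FAISS), a valid Hugging Face API token, and an example model repo id.
#
# qa_chain = initialize_llmchain(
#     llm_model="mistralai/Mistral-7B-Instruct-v0.2",  # assumed example model
#     huggingfacehub_api_token="hf_...",               # your HF token
#     temperature=0.7,
#     max_tokens=1024,
#     top_k=3,
#     vector_db=vector_db,                             # pre-built vector store
# )
# qa_chain, chat_history, sources = invoke_qa_chain(qa_chain, "What is RAG?", [])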