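"""Batch/interactive driver for a ChatPDF-style QA chain.

In batch mode, questions are read one per line from the file named by
QUESTIONS_FILE_PATH; pass "chat" as the first command-line argument for
interactive mode. Each question is run through the conversational QA
chain, and timing plus token-generation statistics are printed at the end.
"""
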
import os
import sys
from timeit import default_timer as timer
from typing import Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

from app_modules.init import app_init
from app_modules.utils import print_llm_response

# app_init loads the configured LLM and builds the QA chain used below.
llm_loader, qa_chain = app_init()


class MyCustomHandler(BaseCallbackHandler):
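    """Callback handler that records every LLM generation.

    In a conversational retrieval chain, the first LLM call condenses the
    chat history and the new query into a standalone question, so the
    first recorded text is that standalone question.
    """
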
    def __init__(self):
        self.reset()

    def reset(self):
        self.texts = []

    def get_standalone_question(self) -> Optional[str]:
        return self.texts[0].strip() if self.texts else None

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Run when the LLM call ends; record the generated text."""
        print("\non_llm_end - response:")
        print(response)
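        # generations[0][0] is the first candidate generation for the
        # first prompt in this call.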
        self.texts.append(response.generations[0][0].text)


# "chat" as the first argument switches to interactive mode; an optional
# second argument supplies a chat id passed through to the chain.
chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
chat_id = sys.argv[2] if len(sys.argv) > 2 else None
questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED", "true")

custom_handler = MyCustomHandler()

# Chatbot loop
chat_history = []
print("Welcome to the ChatPDF! Type 'exit' to stop.")

# In batch mode, read the questions to replay from the questions file;
# in chat mode, questions come from stdin and the file is not needed.
queue = []
if not chatting:
    with open(questions_file_path, "r") as file:
        queue = [line.strip() for line in file]

# Sentinel so a batch run terminates after the last question.
queue.append("exit")

chat_start = timer()

# Main loop: pull the next question (stdin in chat mode, otherwise the
# queue read from the questions file) until "exit" is seen.
while True:
    if chatting:
        query = input("Please enter your question: ")
    else:
        query = queue.pop(0)

    query = query.strip()
    if query.lower() == "exit":
        break

    print("\nQuestion: " + query)
    custom_handler.reset()

    start = timer()
    inputs = {"question": query, "chat_history": chat_history, "chat_id": chat_id}
    result = qa_chain.call_chain(
        inputs,
        custom_handler,
        None,
        True,
    )
    end = timer()
    print(f"Completed in {end - start:.3f}s")

    # print_llm_response(result)

    # On the first turn there is no history to condense, so the query is
    # already standalone; afterwards, use the condensed question that the
    # handler captured from the chain's first LLM call.
    if len(chat_history) == 0:
        standalone_question = query
    else:
        standalone_question = custom_handler.get_standalone_question()

    if standalone_question is not None:
        # Re-run retrieval with the standalone question so the matched
        # documents can be inspected independently of the chain's answer.
        print(f"Loading relevant documents for standalone question: {standalone_question}")
        start = timer()
        qa = qa_chain.get_chain(inputs)
        docs = qa.retriever.get_relevant_documents(standalone_question)
        end = timer()
        print(f"Completed in {end - start:.3f}s")

        if chatting:
            print(docs)

    if chat_history_enabled == "true":
        chat_history.append((query, result["answer"]))

chat_end = timer()
total_time = chat_end - chat_start
print(f"Total time used: {total_time:.3f} s")
print(f"Number of tokens generated: {llm_loader.streamer.total_tokens}")
print(
    f"Average generation speed: {llm_loader.streamer.total_tokens / total_time:.3f} tokens/s"
)
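
# Example invocations (script name and file path are illustrative):
#   QUESTIONS_FILE_PATH=./questions.txt python test_chain.py
#   python test_chain.py chat
#   python test_chain.py chat some-chat-id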