File size: 5,986 Bytes
a6c26b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import sys
import argparse
import pandas as pd
import time
from typing import Any, Dict, Optional
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain.prompts import load_prompt
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoTokenizer

current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))

sys.path.append(kit_dir)
sys.path.append(repo_dir)

from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain

class TimedRetrievalQAChain(RetrievalQAChain):
    """RetrievalQAChain variant whose ``_call`` also reports timing checkpoints.

    Besides the usual ``answer`` and ``source_documents`` keys, the response
    carries three wall-clock timestamps — ``start_time``,
    ``end_preprocessing_time`` and ``end_llm_time`` — so callers can split
    latency into retrieval/formatting vs. LLM generation.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        question = inputs["question"]
        chain = self.qa_prompt | self.llm | StrOutputParser()

        t_start = time.time()
        retrieved = self.retriever.invoke(question)
        if self.rerank:
            retrieved = self.rerank_docs(question, retrieved, self.final_k_retrieved_documents)
        context = self._format_docs(retrieved)
        t_preprocessing_done = time.time()

        answer = chain.invoke({"question": question, "context": context})
        t_llm_done = time.time()

        return {
            "answer": answer,
            "source_documents": retrieved,
            "start_time": t_start,
            "end_preprocessing_time": t_preprocessing_done,
            "end_llm_time": t_llm_done,
        }

def analyze_times(answer, start_time, end_preprocessing_time, end_llm_time, tokenizer):
    """Compute latency and throughput metrics for one QA invocation.

    Args:
        answer: generated answer text.
        start_time: timestamp taken before retrieval began.
        end_preprocessing_time: timestamp taken after retrieval/rerank/formatting.
        end_llm_time: timestamp taken after the LLM finished generating.
        tokenizer: any object with an ``encode(str) -> sequence`` method,
            used to count the tokens of ``answer``.

    Returns:
        dict with ``preprocessing_time``, ``llm_time``, ``token_count`` and
        ``tokens_per_second`` keys.
    """
    preprocessing_time = end_preprocessing_time - start_time
    llm_time = end_llm_time - end_preprocessing_time
    token_count = len(tokenizer.encode(answer))
    # Guard against a zero-length LLM interval (e.g. an instant/cached reply)
    # instead of raising ZeroDivisionError.
    tokens_per_second = token_count / llm_time if llm_time > 0 else 0.0
    return {
        "preprocessing_time": preprocessing_time,
        "llm_time": llm_time,
        "token_count": token_count,
        "tokens_per_second": tokens_per_second,
    }

def generate(qa_chain, question, tokenizer):
    """Run one question through the chain; return (answer, sources, timings)."""
    response = qa_chain.invoke({"question": question})
    answer = response.get("answer")
    # De-duplicated filenames of the retrieved source documents.
    sources = {f'{sd.metadata["filename"]}' for sd in response["source_documents"]}
    times = analyze_times(
        answer,
        response.get("start_time"),
        response.get("end_preprocessing_time"),
        response.get("end_llm_time"),
        tokenizer,
    )
    return answer, sources, times

def process_bulk_QA(vectordb_path, questions_file_path):
    """Answer every question in an xlsx file using a stored vector database.

    Loads the vector store at ``vectordb_path``, builds a
    ``TimedRetrievalQAChain``, then fills the ``Answer``/``Sources``/timing
    columns of the spreadsheet at ``questions_file_path`` (which must contain
    a ``Questions`` column), saving progress to ``*_output.xlsx`` after each
    processed row so a crash never loses completed work.

    Returns:
        Path of the output xlsx file.

    Raises:
        FileNotFoundError: if the vector db path or the questions file is missing.
    """
    # Fail fast with real exceptions — the original `raise f"..."` would only
    # produce "TypeError: exceptions must derive from BaseException".
    if not os.path.exists(vectordb_path):
        raise FileNotFoundError(f"vector db path {vectordb_path} does not exist")
    if not os.path.exists(questions_file_path):
        raise FileNotFoundError(f"questions file path {questions_file_path} does not exist")

    documentRetrieval = DocumentRetrieval()
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")

    # Load the vectorstore and initialize the retriever.
    embeddings = documentRetrieval.load_embedding_model()
    vectorstore = documentRetrieval.load_vdb(vectordb_path, embeddings)
    print("Database loaded")
    documentRetrieval.init_retriever(vectorstore)
    print("retriever initialized")

    # Build the timing-aware QA chain.
    qa_chain = TimedRetrievalQAChain(
        retriever=documentRetrieval.retriever,
        llm=documentRetrieval.llm,
        qa_prompt=load_prompt(os.path.join(kit_dir, documentRetrieval.prompts["qa_prompt"])),
        rerank=documentRetrieval.retrieval_info["rerank"],
        final_k_retrieved_documents=documentRetrieval.retrieval_info["final_k_retrieved_documents"],
    )

    df = pd.read_excel(questions_file_path)
    print(df)
    output_file_path = questions_file_path.replace('.xlsx', '_output.xlsx')
    if 'Answer' not in df.columns:
        for col in ('Answer', 'Sources', 'preprocessing_time', 'llm_time',
                    'token_count', 'tokens_per_second'):
            df[col] = ''

    for index, row in df.iterrows():
        # Only process rows whose Answer cell is empty. read_excel yields NaN
        # (a float) for blank cells, so coerce before calling .strip() — the
        # original crashed with AttributeError on such rows.
        existing = row['Answer']
        if pd.isna(existing) or str(existing).strip() == '':
            try:
                print(f"Generating answer for row {index}")
                answer, sources, times = generate(qa_chain, row['Questions'], tokenizer)
                df.at[index, 'Answer'] = answer
                df.at[index, 'Sources'] = sources
                df.at[index, 'preprocessing_time'] = times.get("preprocessing_time")
                df.at[index, 'llm_time'] = times.get("llm_time")
                df.at[index, 'token_count'] = times.get("token_count")
                df.at[index, 'tokens_per_second'] = times.get("tokens_per_second")
            except Exception as e:
                # Best-effort: log the failure and continue with remaining rows.
                print(f"Error processing row {index}: {e}")
            # Save the file after each iteration to avoid data loss.
            df.to_excel(output_file_path, index=False)
        else:
            print(f"Skipping row {index} because 'Answer' is already in the document")
    return output_file_path
                                                      
if __name__ == "__main__":
    # CLI entry point: bulk-answer every question in the given spreadsheet.
    arg_parser = argparse.ArgumentParser(description='use a vectordb and an excel file with questions in the first column and generate answers for all the questions')
    arg_parser.add_argument('vectordb_path', type=str, help='vector db path with stored documents for RAG')
    arg_parser.add_argument('questions_path', type=str, help='xlsx file containing questions in a column named Questions')
    cli_args = arg_parser.parse_args()
    # Process in bulk and report where the answers were written.
    result_path = process_bulk_QA(cli_args.vectordb_path, cli_args.questions_path)
    print(f"Finished, responses in: {result_path}")