ppsingh committed on
Commit 723ac7e
1 Parent(s): d4a2dd9

Update app.py

Files changed (1): app.py +99 -8
app.py CHANGED
@@ -8,9 +8,16 @@ import re
  import json
  from auditqa.sample_questions import QUESTIONS
  from auditqa.reports import POSSIBLE_REPORTS
- from auditqa.engine.prompts import audience_prompts
+ from auditqa.engine.prompts import audience_prompts, answer_prompt_template
  from auditqa.doc_process import process_pdf
- process_pdf()
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain.llms import HuggingFaceEndpoint
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ HF_token = os.environ["HF_TOKEN"]
+ vectorstores = process_pdf()

  async def chat(query,history,audience,sources,reports):
      """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
@@ -21,6 +28,9 @@ async def chat(query,history,audience,sources,reports):
      print(f"audience:{audience}")
      print(f"sources:{sources}")
      print(f"reports:{reports}")
+     docs_html = ""
+     output_query = ""
+     output_language = "english"

      if audience == "Children":
          audience_prompt = audience_prompts["children"]
@@ -33,20 +43,101 @@

      # Prepare default values
      if len(sources) == 0:
-         sources = ["IPCC"]
+         sources = ["ABC"]

      if len(reports) == 0:
          reports = []

-     history = [tuple(x) for x in history]
-     docs_html = ""
-     output_query = ""
-     output_language = "ENG"
+     if sources == ["ABC"]:
+         vectorstore = vectorstores["ABC"]
+     else:
+         vectorstore = vectorstores["XYZ"]
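Because vectorstores is a dict keyed by source name, the hard-coded if/else above can also be written as a lookup with the same "XYZ" fallback; a sketch (generalizing to sources[0] is an assumption about the intent):

    # Equivalent routing: pick the store for the first selected source,
    # falling back to "XYZ" exactly as the else-branch does.
    vectorstore = vectorstores.get(sources[0], vectorstores["XYZ"])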
+
+     # get context
+     context_retrieved_lst = []
+     question_lst = [query]
+     for question in question_lst:
+         retriever = vectorstore.as_retriever(
+             search_type="similarity",
+             search_kwargs={"k": 1})
+
+         context_retrieved = retriever.get_relevant_documents(question)
+
+         def format_docs(docs):
+             return "\n\n".join(doc.page_content for doc in docs)
+
+         context_retrieved_formatted = format_docs(context_retrieved)
+         context_retrieved_lst.append(context_retrieved_formatted)
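as_retriever() wraps the chosen vectorstore in a retriever that returns the single most similar chunk (k=1) for each question. get_relevant_documents() matches the LangChain version imported here; newer releases make retrievers Runnables, so the equivalent call would be:

    retriever = vectorstore.as_retriever(search_type="similarity",
                                         search_kwargs={"k": 1})
    context_retrieved = retriever.invoke(question)  # newer-API equivalent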
+
+     # get prompt
+     prompt = ChatPromptTemplate.from_template(answer_prompt_template)
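answer_prompt_template comes from auditqa.engine.prompts and is not part of this commit; judging by the keys passed to chain.invoke() below, it must contain {context}, {question}, {audience} and {language} placeholders. A hypothetical stand-in:

    # Hypothetical template; the real one lives in auditqa.engine.prompts.
    answer_prompt_template = (
        "You are an assistant for audit reports. {audience}\n"
        "Answer in {language}, using only the context below.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n"
        "Answer:"
    )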
+
+     # get llm
+     llm_qa = HuggingFaceEndpoint(
+         endpoint_url="https://fesg9gjsfde5yfr4.us-east-1.aws.endpoints.huggingface.cloud",
+         task="text-generation",
+         huggingfacehub_api_token=HF_token,
+         model_kwargs={})
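HuggingFaceEndpoint is imported above from the legacy langchain.llms namespace; newer LangChain releases expose the same class from the community package, so an equivalent setup, assuming that package is installed, would be:

    from langchain_community.llms import HuggingFaceEndpoint  # newer import path

    llm_qa = HuggingFaceEndpoint(
        endpoint_url="https://fesg9gjsfde5yfr4.us-east-1.aws.endpoints.huggingface.cloud",
        task="text-generation",
        huggingfacehub_api_token=HF_token)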
+
+     # create rag chain
+     chain = prompt | llm_qa | StrOutputParser()
+     # get answers
+     answer_lst = []
+     for question, context in zip(question_lst, context_retrieved_lst):
+         answer = chain.invoke({"context": context, "question": question, "audience": audience_prompt, "language": "english"})
+         answer_lst.append(answer)
+     docs_html = []
+     for i, d in enumerate(context_retrieved, 1):
+         docs_html.append(make_html_source(d, i))
+     docs_html = "".join(docs_html)
+
+     previous_answer = history[-1][1]
+     previous_answer = previous_answer if previous_answer is not None else ""
+     answer_yet = previous_answer + answer_lst[0]
+     answer_yet = parse_output_llm_with_sources(answer_yet)
+     history[-1] = (query, answer_yet)
+
+     history = [tuple(x) for x in history]
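Gradio's Chatbot passes history in as a list of [user_message, bot_message] pairs with the bot slot of the newest pair still empty; the block above writes the answer into that slot and then normalizes every pair to a tuple. (parse_output_llm_with_sources is referenced but not added by this commit, so it presumably already exists elsewhere in app.py.) The shapes involved, with an invented message:

    history = [["What did the audit conclude?", None]]     # as Gradio hands it in
    history[-1] = ("What did the audit conclude?", "The audit concluded that ...")
    history = [tuple(x) for x in history]                  # lists become tuples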

      yield history,docs_html,output_query,output_language

+ def make_html_source(source, i):
+     meta = source.metadata
+     # content = source.page_content.split(":",1)[1].strip()
+     content = source.page_content.strip()
+
+     toc_levels = []
+     for j in range(2):
+         level = meta[f"toc_level{j}"]
+         if level != "N/A":
+             toc_levels.append(level)
+         else:
+             break
+     toc_levels = " > ".join(toc_levels)
+
+     if len(toc_levels) > 0:
+         name = f"<b>{toc_levels}</b><br/>{meta['name']}"
+     else:
+         name = meta['name']
+
+     if meta["chunk_type"] == "text":
+         card = f"""
+         <div class="card" id="doc{i}">
+             <div class="card-content">
+                 <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
+                 <p>{content}</p>
+             </div>
+             <div class="card-footer">
+                 <span>{name}</span>
+                 <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
+                     <span role="img" aria-label="Open PDF">🔗</span>
+                 </a>
+             </div>
+         </div>
+         """
+     return card
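make_html_source() renders one retrieved chunk as an HTML card and encodes a metadata contract: each document needs chunk_type, name, short_name, page_number, url, toc_level0 and toc_level1 keys. Note that only "text" chunks assign card, so any other chunk_type would make the final return raise UnboundLocalError. A hypothetical document that satisfies the contract (all field values invented):

    from langchain_core.documents import Document

    doc = Document(
        page_content="The audit found unreconciled balances in three accounts.",
        metadata={
            "chunk_type": "text",         # anything else leaves `card` unassigned
            "name": "Consolidated Fund Report 2022",
            "short_name": "CFR 2022",
            "page_number": 14,            # cast to int before display
            "url": "https://example.org/report.pdf",
            "toc_level0": "Part II",      # "N/A" levels are skipped
            "toc_level1": "Findings",
        })
    card_html = make_html_source(doc, 1)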
  # --------------------------------------------------------------------
  # Gradio
  # --------------------------------------------------------------------