Add content recommandation

#17
by timeki - opened
Files changed (47) hide show
  1. .gitignore +6 -0
  2. README.md +1 -1
  3. app.py +402 -403
  4. climateqa/constants.py +22 -1
  5. climateqa/engine/chains/__init__.py +0 -0
  6. climateqa/engine/chains/answer_ai_impact.py +46 -0
  7. climateqa/engine/chains/answer_chitchat.py +56 -0
  8. climateqa/engine/{rag.py → chains/answer_rag.py} +36 -58
  9. climateqa/engine/chains/chitchat_categorization.py +43 -0
  10. climateqa/engine/chains/graph_retriever.py +128 -0
  11. climateqa/engine/chains/intent_categorization.py +90 -0
  12. climateqa/engine/chains/keywords_extraction.py +40 -0
  13. climateqa/engine/{prompts.py → chains/prompts.py} +26 -3
  14. climateqa/engine/chains/query_transformation.py +201 -0
  15. climateqa/engine/{reformulation.py → chains/reformulation.py} +1 -1
  16. climateqa/engine/chains/retrieve_documents.py +303 -0
  17. climateqa/engine/chains/retrieve_papers.py +95 -0
  18. climateqa/engine/chains/retriever.py +126 -0
  19. climateqa/engine/chains/sample_router.py +66 -0
  20. climateqa/engine/chains/set_defaults.py +13 -0
  21. climateqa/engine/chains/translation.py +42 -0
  22. climateqa/engine/embeddings.py +6 -3
  23. climateqa/engine/graph.py +190 -0
  24. climateqa/engine/graph_retriever.py +88 -0
  25. climateqa/engine/keywords.py +3 -1
  26. climateqa/engine/llm/__init__.py +3 -0
  27. climateqa/engine/llm/ollama.py +6 -0
  28. climateqa/engine/llm/openai.py +1 -1
  29. climateqa/engine/reranker.py +50 -0
  30. climateqa/engine/retriever.py +0 -163
  31. climateqa/engine/utils.py +17 -0
  32. climateqa/engine/vectorstore.py +8 -2
  33. climateqa/event_handler.py +123 -0
  34. climateqa/knowledge/__init__.py +0 -0
  35. climateqa/{papers → knowledge}/openalex.py +68 -15
  36. climateqa/knowledge/retriever.py +102 -0
  37. climateqa/papers/__init__.py +0 -43
  38. climateqa/utils.py +13 -0
  39. front/__init__.py +0 -0
  40. front/callbacks.py +0 -0
  41. front/utils.py +335 -0
  42. requirements.txt +14 -6
  43. sandbox/20240310 - CQA - Semantic Routing 1.ipynb +0 -0
  44. sandbox/20240702 - CQA - Graph Functionality.ipynb +0 -0
  45. sandbox/20241104 - CQA - StepByStep CQA.ipynb +0 -0
  46. style.css +403 -66
  47. test.json +0 -0
.gitignore CHANGED
@@ -5,3 +5,9 @@ __pycache__/utils.cpython-38.pyc
5
 
6
  notebooks/
7
  *.pyc
 
 
 
 
 
 
 
5
 
6
  notebooks/
7
  *.pyc
8
+
9
+ **/.ipynb_checkpoints/
10
+ **/.flashrank_cache/
11
+
12
+ data/
13
+ sandbox/
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.19.1
8
  app_file: app.py
9
  fullWidth: true
10
  pinned: false
 
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.0.2
8
  app_file: app.py
9
  fullWidth: true
10
  pinned: false
app.py CHANGED
@@ -1,13 +1,12 @@
1
  from climateqa.engine.embeddings import get_embeddings_function
2
  embeddings_function = get_embeddings_function()
3
 
4
- from climateqa.papers.openalex import OpenAlex
5
  from sentence_transformers import CrossEncoder
6
 
7
- reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
8
- oa = OpenAlex()
9
 
10
  import gradio as gr
 
11
  import pandas as pd
12
  import numpy as np
13
  import os
@@ -15,6 +14,8 @@ import time
15
  import re
16
  import json
17
 
 
 
18
  # from gradio_modal import Modal
19
 
20
  from io import BytesIO
@@ -25,20 +26,29 @@ from azure.storage.fileshare import ShareServiceClient
25
 
26
  from utils import create_user_id
27
 
 
 
 
28
 
 
29
 
30
  # ClimateQ&A imports
31
  from climateqa.engine.llm import get_llm
32
- from climateqa.engine.rag import make_rag_chain
33
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
34
- from climateqa.engine.retriever import ClimateQARetriever
 
35
  from climateqa.engine.embeddings import get_embeddings_function
36
- from climateqa.engine.prompts import audience_prompts
37
  from climateqa.sample_questions import QUESTIONS
38
- from climateqa.constants import POSSIBLE_REPORTS
39
  from climateqa.utils import get_image_from_azure_blob_storage
40
- from climateqa.engine.keywords import make_keywords_chain
41
- from climateqa.engine.rag import make_rag_papers_chain
 
 
 
 
 
42
 
43
  # Load environment variables in local mode
44
  try:
@@ -47,6 +57,8 @@ try:
47
  except Exception as e:
48
  pass
49
 
 
 
50
  # Set up Gradio Theme
51
  theme = gr.themes.Base(
52
  primary_hue="blue",
@@ -80,134 +92,115 @@ share_client = service.get_share_client(file_share_name)
80
  user_id = create_user_id()
81
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- def parse_output_llm_with_sources(output):
85
- # Split the content into a list of text and "[Doc X]" references
86
- content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
87
- parts = []
88
- for part in content_parts:
89
- if part.startswith("Doc"):
90
- subparts = part.split(",")
91
- subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts]
92
- subparts = [f"""<a href="#doc{subpart}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{subpart}</sup></span></a>""" for subpart in subparts]
93
- parts.append("".join(subparts))
94
- else:
95
- parts.append(part)
96
- content_parts = "".join(parts)
97
- return content_parts
98
 
99
 
100
  # Create vectorstore and retriever
101
- vectorstore = get_pinecone_vectorstore(embeddings_function)
102
- llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
103
-
104
 
105
- def make_pairs(lst):
106
- """from a list of even lenght, make tupple pairs"""
107
- return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
108
-
109
-
110
- def serialize_docs(docs):
111
- new_docs = []
112
- for doc in docs:
113
- new_doc = {}
114
- new_doc["page_content"] = doc.page_content
115
- new_doc["metadata"] = doc.metadata
116
- new_docs.append(new_doc)
117
- return new_docs
118
 
 
119
 
 
 
 
120
 
121
- async def chat(query,history,audience,sources,reports):
122
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
123
  (messages in gradio format, messages in langchain format, source documents)"""
124
 
125
- print(f">> NEW QUESTION : {query}")
 
126
 
127
- if audience == "Children":
128
- audience_prompt = audience_prompts["children"]
129
- elif audience == "General public":
130
- audience_prompt = audience_prompts["general"]
131
- elif audience == "Experts":
132
- audience_prompt = audience_prompts["experts"]
133
- else:
134
- audience_prompt = audience_prompts["experts"]
135
 
136
  # Prepare default values
137
- if len(sources) == 0:
138
- sources = ["IPCC"]
139
 
140
- if len(reports) == 0:
141
  reports = []
142
-
143
- retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
144
- rag_chain = make_rag_chain(retriever,llm)
145
 
146
- inputs = {"query": query,"audience": audience_prompt}
147
- result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
148
- # result = rag_chain.stream(inputs)
149
 
150
- path_reformulation = "/logs/reformulation/final_output"
151
- path_keywords = "/logs/keywords/final_output"
152
- path_retriever = "/logs/find_documents/final_output"
153
- path_answer = "/logs/answer/streamed_output_str/-"
154
 
 
 
 
155
  docs_html = ""
156
  output_query = ""
157
  output_language = ""
158
  output_keywords = ""
159
- gallery = []
160
-
 
 
 
 
 
 
 
 
 
 
161
  try:
162
- async for op in result:
 
 
163
 
164
- op = op.ops[0]
165
-
166
- if op['path'] == path_reformulation: # reforulated question
167
- try:
168
- output_language = op['value']["language"] # str
169
- output_query = op["value"]["question"]
170
- except Exception as e:
171
- raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")
172
-
173
- if op["path"] == path_keywords:
174
- try:
175
- output_keywords = op['value']["keywords"] # str
176
- output_keywords = " AND ".join(output_keywords)
177
- except Exception as e:
178
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
-
181
- elif op['path'] == path_retriever: # documents
182
- try:
183
- docs = op['value']['docs'] # List[Document]
184
- docs_html = []
185
- for i, d in enumerate(docs, 1):
186
- docs_html.append(make_html_source(d, i))
187
- docs_html = "".join(docs_html)
188
- except TypeError:
189
- print("No documents found")
190
- print("op: ",op)
191
- continue
192
-
193
- elif op['path'] == path_answer: # final answer
194
- new_token = op['value'] # str
195
- # time.sleep(0.01)
196
- previous_answer = history[-1][1]
197
- previous_answer = previous_answer if previous_answer is not None else ""
198
- answer_yet = previous_answer + new_token
199
- answer_yet = parse_output_llm_with_sources(answer_yet)
200
- history[-1] = (query,answer_yet)
201
-
202
-
203
-
204
- else:
205
- continue
206
-
207
- history = [tuple(x) for x in history]
208
- yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
209
-
210
  except Exception as e:
 
211
  raise gr.Error(f"{e}")
212
 
213
 
@@ -216,7 +209,7 @@ async def chat(query,history,audience,sources,reports):
216
  if os.getenv("GRADIO_ENV") != "local":
217
  timestamp = str(datetime.now().timestamp())
218
  file = timestamp + ".json"
219
- prompt = history[-1][0]
220
  logs = {
221
  "user_id": str(user_id),
222
  "prompt": prompt,
@@ -224,7 +217,7 @@ async def chat(query,history,audience,sources,reports):
224
  "question":output_query,
225
  "sources":sources,
226
  "docs":serialize_docs(docs),
227
- "answer": history[-1][1],
228
  "time": timestamp,
229
  }
230
  log_on_azure(file, logs, share_client)
@@ -232,119 +225,7 @@ async def chat(query,history,audience,sources,reports):
232
  print(f"Error logging on Azure Blob Storage: {e}")
233
  raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
234
 
235
- image_dict = {}
236
- for i,doc in enumerate(docs):
237
-
238
- if doc.metadata["chunk_type"] == "image":
239
- try:
240
- key = f"Image {i+1}"
241
- image_path = doc.metadata["image_path"].split("documents/")[1]
242
- img = get_image_from_azure_blob_storage(image_path)
243
-
244
- # Convert the image to a byte buffer
245
- buffered = BytesIO()
246
- img.save(buffered, format="PNG")
247
- img_str = base64.b64encode(buffered.getvalue()).decode()
248
-
249
- # Embedding the base64 string in Markdown
250
- markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
251
- image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
252
- except Exception as e:
253
- print(f"Skipped adding image {i} because of {e}")
254
-
255
- if len(image_dict) > 0:
256
-
257
- gallery = [x["img"] for x in list(image_dict.values())]
258
- img = list(image_dict.values())[0]
259
- img_md = img["md"]
260
- img_caption = img["caption"]
261
- img_code = img["figure_code"]
262
- if img_code != "N/A":
263
- img_name = f"{img['key']} - {img['figure_code']}"
264
- else:
265
- img_name = f"{img['key']}"
266
-
267
- answer_yet = history[-1][1] + f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"
268
- history[-1] = (history[-1][0],answer_yet)
269
- history = [tuple(x) for x in history]
270
-
271
- # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])]
272
- # if len(gallery) > 0:
273
- # gallery = list(set("|".join(gallery).split("|")))
274
- # gallery = [get_image_from_azure_blob_storage(x) for x in gallery]
275
-
276
- yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
277
-
278
-
279
- def make_html_source(source,i):
280
- meta = source.metadata
281
- # content = source.page_content.split(":",1)[1].strip()
282
- content = source.page_content.strip()
283
-
284
- toc_levels = []
285
- for j in range(2):
286
- level = meta[f"toc_level{j}"]
287
- if level != "N/A":
288
- toc_levels.append(level)
289
- else:
290
- break
291
- toc_levels = " > ".join(toc_levels)
292
-
293
- if len(toc_levels) > 0:
294
- name = f"<b>{toc_levels}</b><br/>{meta['name']}"
295
- else:
296
- name = meta['name']
297
-
298
- if meta["chunk_type"] == "text":
299
-
300
- card = f"""
301
- <div class="card" id="doc{i}">
302
- <div class="card-content">
303
- <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
304
- <p>{content}</p>
305
- </div>
306
- <div class="card-footer">
307
- <span>{name}</span>
308
- <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
309
- <span role="img" aria-label="Open PDF">🔗</span>
310
- </a>
311
- </div>
312
- </div>
313
- """
314
-
315
- else:
316
-
317
- if meta["figure_code"] != "N/A":
318
- title = f"{meta['figure_code']} - {meta['short_name']}"
319
- else:
320
- title = f"{meta['short_name']}"
321
-
322
- card = f"""
323
- <div class="card card-image">
324
- <div class="card-content">
325
- <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
326
- <p>{content}</p>
327
- <p class='ai-generated'>AI-generated description</p>
328
- </div>
329
- <div class="card-footer">
330
- <span>{name}</span>
331
- <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
332
- <span role="img" aria-label="Open PDF">🔗</span>
333
- </a>
334
- </div>
335
- </div>
336
- """
337
-
338
- return card
339
-
340
-
341
-
342
- # else:
343
- # docs_string = "No relevant passages found in the climate science reports (IPCC and IPBES)"
344
- # complete_response = "**No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**"
345
- # messages.append({"role": "assistant", "content": complete_response})
346
- # gradio_format = make_pairs([a["content"] for a in messages[1:]])
347
- # yield gradio_format, messages, docs_string
348
 
349
 
350
  def save_feedback(feed: str, user_id):
@@ -368,77 +249,7 @@ def log_on_azure(file, logs, share_client):
368
  file_client.upload_file(logs)
369
 
370
 
371
- def generate_keywords(query):
372
- chain = make_keywords_chain(llm)
373
- keywords = chain.invoke(query)
374
- keywords = " AND ".join(keywords["keywords"])
375
- return keywords
376
-
377
-
378
-
379
- papers_cols_widths = {
380
- "doc":50,
381
- "id":100,
382
- "title":300,
383
- "doi":100,
384
- "publication_year":100,
385
- "abstract":500,
386
- "rerank_score":100,
387
- "is_oa":50,
388
- }
389
-
390
- papers_cols = list(papers_cols_widths.keys())
391
- papers_cols_widths = list(papers_cols_widths.values())
392
-
393
- async def find_papers(query, keywords,after):
394
-
395
- summary = ""
396
-
397
- df_works = oa.search(keywords,after = after)
398
- df_works = df_works.dropna(subset=["abstract"])
399
- df_works = oa.rerank(query,df_works,reranker)
400
- df_works = df_works.sort_values("rerank_score",ascending=False)
401
- G = oa.make_network(df_works)
402
-
403
- height = "750px"
404
- network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
405
- network_html = network.generate_html()
406
-
407
- network_html = network_html.replace("'", "\"")
408
- css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
409
- network_html = network_html + css_to_inject
410
-
411
-
412
- network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
413
- display-capture; encrypted-media;" sandbox="allow-modals allow-forms
414
- allow-scripts allow-same-origin allow-popups
415
- allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
416
- allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
417
-
418
-
419
- docs = df_works["content"].head(15).tolist()
420
-
421
- df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
422
- df_works["doc"] = df_works["doc"] + 1
423
- df_works = df_works[papers_cols]
424
-
425
- yield df_works,network_html,summary
426
 
427
- chain = make_rag_papers_chain(llm)
428
- result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
429
- path_answer = "/logs/StrOutputParser/streamed_output/-"
430
-
431
- async for op in result:
432
-
433
- op = op.ops[0]
434
-
435
- if op['path'] == path_answer: # reforulated question
436
- new_token = op['value'] # str
437
- summary += new_token
438
- else:
439
- continue
440
- yield df_works,network_html,summary
441
-
442
 
443
 
444
  # --------------------------------------------------------------------
@@ -457,6 +268,10 @@ Hello, I am ClimateQ&A, a conversational assistant designed to help you understa
457
  ⚠️ Limitations
458
  *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
459
 
 
 
 
 
460
  What do you want to learn ?
461
  """
462
 
@@ -467,21 +282,38 @@ def vote(data: gr.LikeData):
467
  else:
468
  print(data)
469
 
 
 
 
 
 
 
 
 
470
 
471
 
472
- with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
473
- # user_id_state = gr.State([user_id])
 
 
 
474
 
 
475
  with gr.Tab("ClimateQ&A"):
476
 
477
  with gr.Row(elem_id="chatbot-row"):
478
  with gr.Column(scale=2):
479
- # state = gr.State([system_template])
480
  chatbot = gr.Chatbot(
481
- value=[(None,init_prompt)],
482
- show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
 
 
 
 
483
  avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
484
- )#,avatar_images = ("assets/logo4.png",None))
 
 
485
 
486
  # bot.like(vote,None,None)
487
 
@@ -489,13 +321,16 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
489
 
490
  with gr.Row(elem_id = "input-message"):
491
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
492
- # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
493
-
 
 
 
494
 
495
- with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):
496
 
497
 
498
- with gr.Tabs() as tabs:
499
  with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
500
 
501
  examples_hidden = gr.Textbox(visible = False)
@@ -521,102 +356,293 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
521
  )
522
 
523
  samples.append(group_examples)
 
 
 
524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
- with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
527
- sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
528
- docs_textbox = gr.State("")
529
-
530
- # with Modal(visible = False) as config_modal:
531
- with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
532
-
533
- gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
534
-
535
 
536
- dropdown_sources = gr.CheckboxGroup(
537
- ["IPCC", "IPBES","IPOS"],
538
- label="Select source",
539
- value=["IPCC"],
540
- interactive=True,
541
- )
542
 
543
- dropdown_reports = gr.Dropdown(
544
- POSSIBLE_REPORTS,
545
- label="Or select specific reports",
546
- multiselect=True,
547
- value=None,
548
- interactive=True,
549
- )
550
 
551
- dropdown_audience = gr.Dropdown(
552
- ["Children","General public","Experts"],
553
- label="Select audience",
554
- value="Experts",
555
- interactive=True,
556
- )
557
 
558
- output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
559
- output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
 
 
 
561
 
562
 
 
 
 
 
 
 
563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
 
566
  #---------------------------------------------------------------------------------------
567
  # OTHER TABS
568
  #---------------------------------------------------------------------------------------
569
 
 
570
 
571
- with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
572
- gallery_component = gr.Gallery()
573
 
574
- with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
575
 
576
- with gr.Row():
577
- with gr.Column(scale=1):
578
- query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
579
- keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
580
- after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
581
- search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)
582
 
583
- with gr.Column(scale=7):
 
 
 
 
 
 
584
 
585
- with gr.Tab("Summary",elem_id="papers-summary-tab"):
586
- papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")
 
 
 
 
587
 
588
- with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
589
- papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)
590
 
591
- with gr.Tab("Citations network",elem_id="papers-network-tab"):
592
- citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")
593
 
594
 
595
-
596
  with gr.Tab("About",elem_classes = "max-height other-tabs"):
597
  with gr.Row():
598
  with gr.Column(scale=1):
599
- gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")
600
 
601
 
602
- def start_chat(query,history):
603
- history = history + [(query,None)]
604
- history = [tuple(x) for x in history]
605
- return (gr.update(interactive = False),gr.update(selected=1),history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
 
607
  def finish_chat():
608
- return (gr.update(interactive = True,value = ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  (textbox
611
- .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
612
- .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_textbox")
613
  .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
 
614
  )
615
 
616
  (examples_hidden
617
- .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
618
- .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_examples")
619
  .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
 
620
  )
621
 
622
 
@@ -627,51 +653,24 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
627
  return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
628
 
629
 
 
 
 
 
 
 
 
630
 
 
631
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
632
 
633
- query_papers.submit(generate_keywords,[query_papers], [keywords_papers])
634
- search_papers.click(find_papers,[query_papers,keywords_papers,after], [papers_dataframe,citations_network,papers_summary])
635
-
636
- # # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
637
- # (textbox
638
- # .submit(answer_user, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
639
- # .success(change_tab,None,tabs)
640
- # .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
641
- # .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue = True)
642
- # .success(lambda x : textbox,[textbox],[textbox])
643
- # )
644
-
645
- # (examples_hidden
646
- # .change(answer_user_example, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
647
- # .success(change_tab,None,tabs)
648
- # .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
649
- # .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue=True)
650
- # .success(lambda x : textbox,[textbox],[textbox])
651
- # )
652
- # submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
653
- # answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
654
- # )
655
 
656
-
657
- # with Modal(visible=True) as first_modal:
658
- # gr.Markdown("# Welcome to ClimateQ&A !")
659
-
660
- # gr.Markdown("### Examples")
661
-
662
- # examples = gr.Examples(
663
- # ["Yo ça roule","ça boume"],
664
- # [examples_hidden],
665
- # examples_per_page=8,
666
- # run_on_click=False,
667
- # elem_id="examples",
668
- # api_name="examples",
669
- # )
670
-
671
-
672
- # submit.click(lambda: Modal(visible=True), None, config_modal)
673
-
674
 
675
  demo.queue()
676
 
677
- demo.launch()
 
1
  from climateqa.engine.embeddings import get_embeddings_function
2
  embeddings_function = get_embeddings_function()
3
 
 
4
  from sentence_transformers import CrossEncoder
5
 
6
+ # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
 
7
 
8
  import gradio as gr
9
+ from gradio_modal import Modal
10
  import pandas as pd
11
  import numpy as np
12
  import os
 
14
  import re
15
  import json
16
 
17
+ from gradio import ChatMessage
18
+
19
  # from gradio_modal import Modal
20
 
21
  from io import BytesIO
 
26
 
27
  from utils import create_user_id
28
 
29
+ from gradio_modal import Modal
30
+
31
+ from PIL import Image
32
 
33
+ from langchain_core.runnables.schema import StreamEvent
34
 
35
  # ClimateQ&A imports
36
  from climateqa.engine.llm import get_llm
 
37
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
38
+ # from climateqa.knowledge.retriever import ClimateQARetriever
39
+ from climateqa.engine.reranker import get_reranker
40
  from climateqa.engine.embeddings import get_embeddings_function
41
+ from climateqa.engine.chains.prompts import audience_prompts
42
  from climateqa.sample_questions import QUESTIONS
43
+ from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
44
  from climateqa.utils import get_image_from_azure_blob_storage
45
+ from climateqa.engine.graph import make_graph_agent
46
+ from climateqa.engine.embeddings import get_embeddings_function
47
+ from climateqa.engine.chains.retrieve_papers import find_papers
48
+
49
+ from front.utils import serialize_docs,process_figures
50
+
51
+ from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
52
 
53
  # Load environment variables in local mode
54
  try:
 
57
  except Exception as e:
58
  pass
59
 
60
+ import requests
61
+
62
  # Set up Gradio Theme
63
  theme = gr.themes.Base(
64
  primary_hue="blue",
 
92
  user_id = create_user_id()
93
 
94
 
95
+ CITATION_LABEL = "BibTeX citation for ClimateQ&A"
96
+ CITATION_TEXT = r"""@misc{climateqa,
97
+ author={Théo Alves Da Costa, Timothée Bohe},
98
+ title={ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
99
+ year={2024},
100
+ howpublished= {\url{https://climateqa.com}},
101
+ }
102
+ @software{climateqa,
103
+ author = {Théo Alves Da Costa, Timothée Bohe},
104
+ publisher = {ClimateQ&A},
105
+ title = {ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
106
+ }
107
+ """
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
 
111
  # Create vectorstore and retriever
112
+ vectorstore = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX"))
113
+ vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
 
114
 
115
+ llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
116
+ reranker = get_reranker("nano")
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
119
 
120
+ def update_config_modal_visibility(config_open):
121
+ new_config_visibility_status = not config_open
122
+ return gr.update(visible=new_config_visibility_status), new_config_visibility_status
123
 
124
+ async def chat(query, history, audience, sources, reports, relevant_content_sources, search_only):
125
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
126
  (messages in gradio format, messages in langchain format, source documents)"""
127
 
128
+ date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
129
+ print(f">> NEW QUESTION ({date_now}) : {query}")
130
 
131
+ audience_prompt = init_audience(audience)
 
 
 
 
 
 
 
132
 
133
  # Prepare default values
134
+ if sources is None or len(sources) == 0:
135
+ sources = ["IPCC", "IPBES", "IPOS"]
136
 
137
+ if reports is None or len(reports) == 0:
138
  reports = []
 
 
 
139
 
140
+ inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources, "relevant_content_sources" : relevant_content_sources, "search_only": search_only}
141
+ result = agent.astream_events(inputs,version = "v1")
 
142
 
 
 
 
 
143
 
144
+ docs = []
145
+ used_figures=[]
146
+ related_contents = []
147
  docs_html = ""
148
  output_query = ""
149
  output_language = ""
150
  output_keywords = ""
151
+ start_streaming = False
152
+ graphs_html = ""
153
+ figures = '<div class="figures-container"><p></p> </div>'
154
+
155
+ steps_display = {
156
+ "categorize_intent":("🔄️ Analyzing user message",True),
157
+ "transform_query":("🔄️ Thinking step by step to answer the question",True),
158
+ "retrieve_documents":("🔄️ Searching in the knowledge base",False),
159
+ }
160
+
161
+ used_documents = []
162
+ answer_message_content = ""
163
  try:
164
+ async for event in result:
165
+ if "langgraph_node" in event["metadata"]:
166
+ node = event["metadata"]["langgraph_node"]
167
 
168
+ if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
169
+ docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
170
+
171
+ elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
172
+
173
+ intent = event["data"]["output"]["intent"]
174
+ if "language" in event["data"]["output"]:
175
+ output_language = event["data"]["output"]["language"]
176
+ else :
177
+ output_language = "English"
178
+ history[-1].content = f"Language identified : {output_language} \n Intent identified : {intent}"
179
+
180
+
181
+ elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
182
+ event_description, display_output = steps_display[node]
183
+ if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
184
+ history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
185
+
186
+ elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
187
+ history, start_streaming, answer_message_content = stream_answer(history, event, start_streaming, answer_message_content)
188
+
189
+ elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
190
+ graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
191
+
192
+
193
+ if event["name"] == "transform_query" and event["event"] =="on_chain_end":
194
+ if hasattr(history[-1],"content"):
195
+ history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join([q["question"] for q in event["data"]["output"]["remaining_questions"]])
196
+
197
+ if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
198
+ print("X")
199
 
200
+ yield history, docs_html, output_query, output_language, related_contents , graphs_html, #,output_query,output_keywords
201
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  except Exception as e:
203
+ print(event, "has failed")
204
  raise gr.Error(f"{e}")
205
 
206
 
 
209
  if os.getenv("GRADIO_ENV") != "local":
210
  timestamp = str(datetime.now().timestamp())
211
  file = timestamp + ".json"
212
+ prompt = history[1]["content"]
213
  logs = {
214
  "user_id": str(user_id),
215
  "prompt": prompt,
 
217
  "question":output_query,
218
  "sources":sources,
219
  "docs":serialize_docs(docs),
220
+ "answer": history[-1].content,
221
  "time": timestamp,
222
  }
223
  log_on_azure(file, logs, share_client)
 
225
  print(f"Error logging on Azure Blob Storage: {e}")
226
  raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
227
 
228
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
  def save_feedback(feed: str, user_id):
 
249
  file_client.upload_file(logs)
250
 
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
 
255
  # --------------------------------------------------------------------
 
268
  ⚠️ Limitations
269
  *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
270
 
271
+ 🛈 Information
272
+ Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.
273
+
274
+
275
  What do you want to learn ?
276
  """
277
 
 
282
  else:
283
  print(data)
284
 
285
+ def save_graph(saved_graphs_state, embedding, category):
286
+ print(f"\nCategory:\n{saved_graphs_state}\n")
287
+ if category not in saved_graphs_state:
288
+ saved_graphs_state[category] = []
289
+ if embedding not in saved_graphs_state[category]:
290
+ saved_graphs_state[category].append(embedding)
291
+ return saved_graphs_state, gr.Button("Graph Saved")
292
+
293
 
294
 
295
+ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo:
296
+ chat_completed_state = gr.State(0)
297
+ current_graphs = gr.State([])
298
+ saved_graphs = gr.State({})
299
+ config_open = gr.State(False)
300
 
301
+
302
  with gr.Tab("ClimateQ&A"):
303
 
304
  with gr.Row(elem_id="chatbot-row"):
305
  with gr.Column(scale=2):
 
306
  chatbot = gr.Chatbot(
307
+ value = [ChatMessage(role="assistant", content=init_prompt)],
308
+ type = "messages",
309
+ show_copy_button=True,
310
+ show_label = False,
311
+ elem_id="chatbot",
312
+ layout = "panel",
313
  avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
314
+ max_height="80vh",
315
+ height="100vh"
316
+ )
317
 
318
  # bot.like(vote,None,None)
319
 
 
321
 
322
  with gr.Row(elem_id = "input-message"):
323
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
324
+
325
+ config_button = gr.Button("",elem_id="config-button")
326
+ # config_checkbox_button = gr.Checkbox(label = '⚙️', value="show",visible=True, interactive=True, elem_id="checkbox-config")
327
+
328
+
329
 
330
+ with gr.Column(scale=2, variant="panel",elem_id = "right-panel"):
331
 
332
 
333
+ with gr.Tabs(elem_id = "right_panel_tab") as tabs:
334
  with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
335
 
336
  examples_hidden = gr.Textbox(visible = False)
 
356
  )
357
 
358
  samples.append(group_examples)
359
+
360
+ # with gr.Tab("Configuration", id = 10, ) as tab_config:
361
+ # # gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
362
 
363
+ # pass
364
+
365
+ # with gr.Row():
366
+
367
+ # dropdown_sources = gr.CheckboxGroup(
368
+ # ["IPCC", "IPBES","IPOS"],
369
+ # label="Select source",
370
+ # value=["IPCC"],
371
+ # interactive=True,
372
+ # )
373
+ # dropdown_external_sources = gr.CheckboxGroup(
374
+ # ["IPCC figures","OpenAlex", "OurWorldInData"],
375
+ # label="Select database to search for relevant content",
376
+ # value=["IPCC figures"],
377
+ # interactive=True,
378
+ # )
379
+
380
+ # dropdown_reports = gr.Dropdown(
381
+ # POSSIBLE_REPORTS,
382
+ # label="Or select specific reports",
383
+ # multiselect=True,
384
+ # value=None,
385
+ # interactive=True,
386
+ # )
387
+
388
+ # search_only = gr.Checkbox(label="Search only without chating", value=False, interactive=True, elem_id="checkbox-chat")
389
+
390
+
391
+ # dropdown_audience = gr.Dropdown(
392
+ # ["Children","General public","Experts"],
393
+ # label="Select audience",
394
+ # value="Experts",
395
+ # interactive=True,
396
+ # )
397
+
398
+
399
+ # after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers", visible=False)
400
+
401
 
402
+ # output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
403
+ # output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
 
 
 
 
 
 
 
404
 
 
 
 
 
 
 
405
 
406
+ # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after])
407
+ # # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after], visible=True)
 
 
 
 
 
408
 
 
 
 
 
 
 
409
 
410
+ with gr.Tab("Sources",elem_id = "tab-sources",id = 1) as tab_sources:
411
+ sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
412
+
413
+
414
+
415
+ with gr.Tab("Recommended content", elem_id="tab-recommended_content",id=2) as tab_recommended_content:
416
+ with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
417
+
418
+ with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
419
+ sources_raw = gr.State()
420
+
421
+ with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
422
+ gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
423
+
424
+ show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True)
425
+ show_full_size_figures.click(lambda : Modal(visible=True),None,figure_modal)
426
+
427
+ figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
428
+
429
+
430
+
431
+ with gr.Tab("Papers",elem_id = "tab-citations",id = 4) as tab_papers:
432
+ # btn_summary = gr.Button("Summary")
433
+ # Fenêtre simulée pour le Summary
434
+ with gr.Accordion(visible=True, elem_id="papers-summary-popup", label= "See summary of relevant papers", open= False) as summary_popup:
435
+ papers_summary = gr.Markdown("", visible=True, elem_id="papers-summary")
436
+
437
+ # btn_relevant_papers = gr.Button("Relevant papers")
438
+ # Fenêtre simulée pour les Relevant Papers
439
+ with gr.Accordion(visible=True, elem_id="papers-relevant-popup",label= "See relevant papers", open= False) as relevant_popup:
440
+ papers_html = gr.HTML(show_label=False, elem_id="papers-textbox")
441
+
442
+ btn_citations_network = gr.Button("Explore papers citations network")
443
+ # Fenêtre simulée pour le Citations Network
444
+ with Modal(visible=False) as papers_modal:
445
+ citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
446
+ btn_citations_network.click(lambda: Modal(visible=True), None, papers_modal)
447
+
448
+
449
+
450
+ with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
451
+
452
+ graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",elem_id="graphs-container")
453
+ current_graphs.change(lambda x : x, inputs=[current_graphs], outputs=[graphs_container])
454
+
455
+ with Modal(visible=False,elem_id="modal-config") as config_modal:
456
+ gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
457
+
458
+
459
+ # with gr.Row():
460
+
461
+ dropdown_sources = gr.CheckboxGroup(
462
+ ["IPCC", "IPBES","IPOS"],
463
+ label="Select source (by default search in all sources)",
464
+ value=["IPCC"],
465
+ interactive=True,
466
+ )
467
+
468
+ dropdown_reports = gr.Dropdown(
469
+ POSSIBLE_REPORTS,
470
+ label="Or select specific reports",
471
+ multiselect=True,
472
+ value=None,
473
+ interactive=True,
474
+ )
475
+
476
+ dropdown_external_sources = gr.CheckboxGroup(
477
+ ["IPCC figures","OpenAlex", "OurWorldInData"],
478
+ label="Select database to search for relevant content",
479
+ value=["IPCC figures"],
480
+ interactive=True,
481
+ )
482
+
483
+ search_only = gr.Checkbox(label="Search only for recommended content without chating", value=False, interactive=True, elem_id="checkbox-chat")
484
+
485
+
486
+ dropdown_audience = gr.Dropdown(
487
+ ["Children","General public","Experts"],
488
+ label="Select audience",
489
+ value="Experts",
490
+ interactive=True,
491
+ )
492
+
493
+
494
+ after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers", visible=False)
495
+
496
 
497
+ output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
498
+ output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
499
 
500
 
501
+ dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after])
502
+
503
+ close_config_modal = gr.Button("Validate and Close",elem_id="close-config-modal")
504
+ close_config_modal.click(fn=update_config_modal_visibility, inputs=[config_open], outputs=[config_modal, config_open])
505
+ # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after], visible=True)
506
+
507
 
508
+
509
+ config_button.click(fn=update_config_modal_visibility, inputs=[config_open], outputs=[config_modal, config_open])
510
+
511
+ # with gr.Tab("OECD",elem_id = "tab-oecd",id = 6):
512
+ # oecd_indicator = "RIVER_FLOOD_RP100_POP_SH"
513
+ # oecd_topic = "climate"
514
+ # oecd_latitude = "46.8332"
515
+ # oecd_longitude = "5.3725"
516
+ # oecd_zoom = "5.6442"
517
+ # # Create the HTML content with the iframe
518
+ # iframe_html = f"""
519
+ # <iframe src="https://localdataportal.oecd.org/maps.html?indicator={oecd_indicator}&topic={oecd_topic}&latitude={oecd_latitude}&longitude={oecd_longitude}&zoom={oecd_zoom}"
520
+ # width="100%" height="600" frameborder="0" style="border:0;" allowfullscreen></iframe>
521
+ # """
522
+ # oecd_textbox = gr.HTML(iframe_html, show_label=False, elem_id="oecd-textbox")
523
+
524
+
525
 
526
 
527
  #---------------------------------------------------------------------------------------
528
  # OTHER TABS
529
  #---------------------------------------------------------------------------------------
530
 
531
+ # with gr.Tab("Settings",elem_id = "tab-config",id = 2):
532
 
533
+ # gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
 
534
 
 
535
 
536
+ # dropdown_sources = gr.CheckboxGroup(
537
+ # ["IPCC", "IPBES","IPOS", "OpenAlex"],
538
+ # label="Select source",
539
+ # value=["IPCC"],
540
+ # interactive=True,
541
+ # )
542
 
543
+ # dropdown_reports = gr.Dropdown(
544
+ # POSSIBLE_REPORTS,
545
+ # label="Or select specific reports",
546
+ # multiselect=True,
547
+ # value=None,
548
+ # interactive=True,
549
+ # )
550
 
551
+ # dropdown_audience = gr.Dropdown(
552
+ # ["Children","General public","Experts"],
553
+ # label="Select audience",
554
+ # value="Experts",
555
+ # interactive=True,
556
+ # )
557
 
 
 
558
 
559
+ # output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
560
+ # output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
561
 
562
 
 
563
  with gr.Tab("About",elem_classes = "max-height other-tabs"):
564
  with gr.Row():
565
  with gr.Column(scale=1):
 
566
 
567
 
568
+
569
+
570
+ gr.Markdown(
571
+ """
572
+ ### More info
573
+ - See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
574
+ - Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
575
+
576
+ ### Citation
577
+ """
578
+ )
579
+ with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
580
+ # # Display citation label and text)
581
+ gr.Textbox(
582
+ value=CITATION_TEXT,
583
+ label="",
584
+ interactive=False,
585
+ show_copy_button=True,
586
+ lines=len(CITATION_TEXT.split('\n')),
587
+ )
588
+
589
+
590
+
591
+ def start_chat(query,history,search_only):
592
+ history = history + [ChatMessage(role="user", content=query)]
593
+ if search_only:
594
+ return (gr.update(interactive = False),gr.update(selected=1),history)
595
+ else:
596
+ return (gr.update(interactive = False),gr.update(selected=2),history)
597
 
598
  def finish_chat():
599
+ return gr.update(interactive = True,value = "")
600
+
601
+ # Initialize visibility states
602
+ summary_visible = False
603
+ relevant_visible = False
604
+
605
+ # Functions to toggle visibility
606
+ def toggle_summary_visibility():
607
+ global summary_visible
608
+ summary_visible = not summary_visible
609
+ return gr.update(visible=summary_visible)
610
+
611
+ def toggle_relevant_visibility():
612
+ global relevant_visible
613
+ relevant_visible = not relevant_visible
614
+ return gr.update(visible=relevant_visible)
615
+
616
 
617
+ def change_completion_status(current_state):
618
+ current_state = 1 - current_state
619
+ return current_state
620
+
621
+ def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
622
+ sources_number = sources_textbox.count("<h2>")
623
+ figures_number = figures_cards.count("<h2>")
624
+ graphs_number = current_graphs.count("<iframe")
625
+ papers_number = papers_html.count("<h2>")
626
+ sources_notif_label = f"Sources ({sources_number})"
627
+ figures_notif_label = f"Figures ({figures_number})"
628
+ graphs_notif_label = f"Graphs ({graphs_number})"
629
+ papers_notif_label = f"Papers ({papers_number})"
630
+ recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
631
+
632
+ return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
633
+
634
  (textbox
635
+ .submit(start_chat, [textbox,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
636
+ .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
637
  .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
638
+ # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
639
  )
640
 
641
  (examples_hidden
642
+ .change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
643
+ .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
644
  .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
645
+ # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
646
  )
647
 
648
 
 
653
  return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
654
 
655
 
656
+ sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
657
+
658
+ # update sources numbers
659
+ sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
660
+ figures_cards.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
661
+ current_graphs.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
662
+ papers_html.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
663
 
664
+ # other questions examples
665
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
666
 
667
+ # search for papers
668
+ textbox.submit(find_papers,[textbox,after, dropdown_external_sources], [papers_html,citations_network,papers_summary])
669
+ examples_hidden.change(find_papers,[examples_hidden,after,dropdown_external_sources], [papers_html,citations_network,papers_summary])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
670
 
671
+ # btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
672
+ # btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
673
 
674
  demo.queue()
675
 
676
+ demo.launch(ssr_mode=False)
climateqa/constants.py CHANGED
@@ -42,4 +42,25 @@ POSSIBLE_REPORTS = [
42
  "IPBES IAS A C5",
43
  "IPBES IAS A C6",
44
  "IPBES IAS A SPM"
45
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  "IPBES IAS A C5",
43
  "IPBES IAS A C6",
44
  "IPBES IAS A SPM"
45
+ ]
46
+
47
+ OWID_CATEGORIES = ['Access to Energy', 'Agricultural Production',
48
+ 'Agricultural Regulation & Policy', 'Air Pollution',
49
+ 'Animal Welfare', 'Antibiotics', 'Biodiversity', 'Biofuels',
50
+ 'Biological & Chemical Weapons', 'CO2 & Greenhouse Gas Emissions',
51
+ 'COVID-19', 'Clean Water', 'Clean Water & Sanitation',
52
+ 'Climate Change', 'Crop Yields', 'Diet Compositions',
53
+ 'Electricity', 'Electricity Mix', 'Energy', 'Energy Efficiency',
54
+ 'Energy Prices', 'Environmental Impacts of Food Production',
55
+ 'Environmental Protection & Regulation', 'Famines', 'Farm Size',
56
+ 'Fertilizers', 'Fish & Overfishing', 'Food Supply', 'Food Trade',
57
+ 'Food Waste', 'Food and Agriculture', 'Forests & Deforestation',
58
+ 'Fossil Fuels', 'Future Population Growth',
59
+ 'Hunger & Undernourishment', 'Indoor Air Pollution', 'Land Use',
60
+ 'Land Use & Yields in Agriculture', 'Lead Pollution',
61
+ 'Meat & Dairy Production', 'Metals & Minerals',
62
+ 'Natural Disasters', 'Nuclear Energy', 'Nuclear Weapons',
63
+ 'Oil Spills', 'Outdoor Air Pollution', 'Ozone Layer', 'Pandemics',
64
+ 'Pesticides', 'Plastic Pollution', 'Renewable Energy', 'Soil',
65
+ 'Transport', 'Urbanization', 'Waste Management', 'Water Pollution',
66
+ 'Water Use & Stress', 'Wildfires']
climateqa/engine/chains/__init__.py ADDED
File without changes
climateqa/engine/chains/answer_ai_impact.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import ChatPromptTemplate
2
+ from langchain_core.output_parsers import StrOutputParser
3
+
4
+
5
+ prompt_template = """
6
+ You are ClimateQ&A, an helpful AI Assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports.
7
+ Always stay true to climate and nature science and do not make up information.
8
+ If you do not know the answer, just say you do not know.
9
+
10
+ ## Guidelines
11
+ - Explain that the environmental impact of AI is not covered by the IPCC or IPBES reports, but you can recommend info based on the sources below
12
+ - Answer the question in the original language of the question
13
+
14
+ ## Sources
15
+ - You can propose to visit this page https://climateqa.com/docs/carbon-footprint/ to learn more about ClimateQ&A's own carbon footprint
16
+ - You can recommend to look at the work of the AI & climate expert scientist Sasha Luccioni with in in particular those papers
17
+ - Power Hungry Processing: Watts Driving the Cost of AI Deployment? - https://arxiv.org/abs/2311.16863 - about the carbon footprint at the inference stage of AI models
18
+ - Counting Carbon: A Survey of Factors Influencing the Emissions of Machine Learning - https://arxiv.org/abs/2302.08476
19
+ - Estimating the Carbon Footprint of BLOOM, a 176B Parameter Language Model - https://arxiv.org/abs/2211.02001 - about the carbon footprint of training a large language model
20
+ - You can also recommend the following tools to calculate the carbon footprint of AI models
21
+ - CodeCarbon - https://github.com/mlco2/codecarbon to measure the carbon footprint of your code
22
+ - Ecologits - https://ecologits.ai/ to measure the carbon footprint of using LLMs APIs such
23
+ """
24
+
25
+
26
+ def make_ai_impact_chain(llm):
27
+
28
+ prompt = ChatPromptTemplate.from_messages([
29
+ ("system", prompt_template),
30
+ ("user", "{question}")
31
+ ])
32
+
33
+ chain = prompt | llm | StrOutputParser()
34
+ chain = chain.with_config({"run_name":"ai_impact_chain"})
35
+
36
+ return chain
37
+
38
+ def make_ai_impact_node(llm):
39
+
40
+ ai_impact_chain = make_ai_impact_chain(llm)
41
+
42
+ async def answer_ai_impact(state,config):
43
+ answer = await ai_impact_chain.ainvoke({"question":state["user_input"]},config)
44
+ return {"answer":answer}
45
+
46
+ return answer_ai_impact
climateqa/engine/chains/answer_chitchat.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import ChatPromptTemplate
2
+ from langchain_core.output_parsers import StrOutputParser
3
+
4
+
5
+ chitchat_prompt_template = """
6
+ You are ClimateQ&A, an helpful AI Assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports.
7
+ Always stay true to climate and nature science and do not make up information.
8
+ If you do not know the answer, just say you do not know.
9
+
10
+ ## Guidelines
11
+ - If it's a conversational question, you can normally chat with the user
12
+ - If the question is not related to any topic about the environment, refuse to answer and politely ask the user to ask another question about the environment
13
+ - If the user asks whether you speak any language, you can say you speak all languages :)
14
+ - If the user asks about the bot itself "ClimateQ&A", you can explain that you are an AI assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports and suggest visiting the website https://climateqa.com/docs/intro/ for more information
15
+ - If the question is about ESG regulations, standards, or frameworks like the CSRD, TCFD, SASB, GRI, CDP, etc., you can explain that this is not a topic covered by the IPCC or IPBES reports.
16
+ - Make clear that you are specialized in finding trustworthy information in the scientific reports of the IPCC and IPBES and other scientific literature
17
+ - If relevant you can propose up to 3 examples of questions they could ask about the IPCC or IPBES reports from the examples below
18
+ - Always answer in the original language of the question
19
+
20
+ ## Examples of questions you can suggest (in the original language of the question)
21
+ "What evidence do we have of climate change?",
22
+ "Are human activities causing global warming?",
23
+ "What are the impacts of climate change?",
24
+ "Can climate change be reversed?",
25
+ "What is the difference between climate change and global warming?",
26
+ """
27
+
28
+
29
+ def make_chitchat_chain(llm):
30
+
31
+ prompt = ChatPromptTemplate.from_messages([
32
+ ("system", chitchat_prompt_template),
33
+ ("user", "{question}")
34
+ ])
35
+
36
+ chain = prompt | llm | StrOutputParser()
37
+ chain = chain.with_config({"run_name":"chitchat_chain"})
38
+
39
+ return chain
40
+
41
+
42
+
43
+ def make_chitchat_node(llm):
44
+
45
+ chitchat_chain = make_chitchat_chain(llm)
46
+
47
+ async def answer_chitchat(state,config):
48
+ print("---- Answer chitchat ----")
49
+
50
+ answer = await chitchat_chain.ainvoke({"question":state["user_input"]},config)
51
+ state["answer"] = answer
52
+ return state
53
+ # return {"answer":answer}
54
+
55
+ return answer_chitchat
56
+
climateqa/engine/{rag.py → chains/answer_rag.py} RENAMED
@@ -2,15 +2,14 @@ from operator import itemgetter
2
 
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_core.output_parsers import StrOutputParser
5
- from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
6
  from langchain_core.prompts.prompt import PromptTemplate
7
  from langchain_core.prompts.base import format_document
8
 
9
- from climateqa.engine.reformulation import make_reformulation_chain
10
- from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
11
- from climateqa.engine.prompts import papers_prompt_template
12
- from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
13
- from climateqa.engine.keywords import make_keywords_chain
14
 
15
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
16
 
@@ -40,72 +39,51 @@ def get_text_docs(x):
40
  def get_image_docs(x):
41
  return [doc for doc in x if doc.metadata["chunk_type"] == "image"]
42
 
43
-
44
- def make_rag_chain(retriever,llm):
45
-
46
- # Construct the prompt
47
  prompt = ChatPromptTemplate.from_template(answer_prompt_template)
48
- prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
49
-
50
- # ------- CHAIN 0 - Reformulation
51
- reformulation = make_reformulation_chain(llm)
52
- reformulation = prepare_chain(reformulation,"reformulation")
53
-
54
- # ------- Find all keywords from the reformulated query
55
- keywords = make_keywords_chain(llm)
56
- keywords = {"keywords":itemgetter("question") | keywords}
57
- keywords = prepare_chain(keywords,"keywords")
58
-
59
- # ------- CHAIN 1
60
- # Retrieved documents
61
- find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
62
- find_documents = prepare_chain(find_documents,"find_documents")
63
 
64
- # ------- CHAIN 2
65
- # Construct inputs for the llm
66
- input_documents = {
67
- "context":lambda x : _combine_documents(x["docs"]),
68
- **pass_values(["question","audience","language","keywords"])
69
- }
70
 
71
- # ------- CHAIN 3
72
- # Bot answer
73
- llm_final = rename_chain(llm,"answer")
74
 
75
- answer_with_docs = {
76
- "answer": input_documents | prompt | llm_final | StrOutputParser(),
77
- **pass_values(["question","audience","language","query","docs","keywords"]),
78
- }
79
 
80
- answer_without_docs = {
81
- "answer": prompt_without_docs | llm_final | StrOutputParser(),
82
- **pass_values(["question","audience","language","query","docs","keywords"]),
83
- }
84
 
85
- # def has_images(x):
86
- # image_docs = [doc for doc in x["docs"] if doc.metadata["chunk_type"]=="image"]
87
- # return len(image_docs) > 0
88
 
89
- def has_docs(x):
90
- return len(x["docs"]) > 0
91
-
92
- answer = RunnableBranch(
93
- (lambda x: has_docs(x), answer_with_docs),
94
- answer_without_docs,
95
- )
96
 
 
97
 
98
- # ------- FINAL CHAIN
99
- # Build the final chain
100
- rag_chain = reformulation | keywords | find_documents | answer
101
 
102
- return rag_chain
103
 
104
 
105
  def make_rag_papers_chain(llm):
106
 
107
  prompt = ChatPromptTemplate.from_template(papers_prompt_template)
108
-
109
  input_documents = {
110
  "context":lambda x : _combine_documents(x["docs"]),
111
  **pass_values(["question","language"])
@@ -131,4 +109,4 @@ def make_illustration_chain(llm):
131
  }
132
 
133
  illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
134
- return illustration_chain
 
2
 
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_core.output_parsers import StrOutputParser
 
5
  from langchain_core.prompts.prompt import PromptTemplate
6
  from langchain_core.prompts.base import format_document
7
 
8
+ from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
9
+ from climateqa.engine.chains.prompts import papers_prompt_template
10
+ import time
11
+ from ..utils import rename_chain, pass_values
12
+
13
 
14
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
15
 
 
39
  def get_image_docs(x):
40
  return [doc for doc in x if doc.metadata["chunk_type"] == "image"]
41
 
42
+ def make_rag_chain(llm):
 
 
 
43
  prompt = ChatPromptTemplate.from_template(answer_prompt_template)
44
+ chain = ({
45
+ "context":lambda x : _combine_documents(x["documents"]),
46
+ "context_length":lambda x : print("CONTEXT LENGTH : " , len(_combine_documents(x["documents"]))),
47
+ "query":itemgetter("query"),
48
+ "language":itemgetter("language"),
49
+ "audience":itemgetter("audience"),
50
+ } | prompt | llm | StrOutputParser())
51
+ return chain
 
 
 
 
 
 
 
52
 
53
+ def make_rag_chain_without_docs(llm):
54
+ prompt = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
55
+ chain = prompt | llm | StrOutputParser()
56
+ return chain
 
 
57
 
58
+ def make_rag_node(llm,with_docs = True):
 
 
59
 
60
+ if with_docs:
61
+ rag_chain = make_rag_chain(llm)
62
+ else:
63
+ rag_chain = make_rag_chain_without_docs(llm)
64
 
65
+ async def answer_rag(state,config):
66
+ print("---- Answer RAG ----")
67
+ start_time = time.time()
 
68
 
69
+ answer = await rag_chain.ainvoke(state,config)
 
 
70
 
71
+ end_time = time.time()
72
+ elapsed_time = end_time - start_time
73
+ print("RAG elapsed time: ", elapsed_time)
74
+ print("Answer size : ", len(answer))
75
+ # print(f"\n\nAnswer:\n{answer}")
76
+
77
+ return {"answer":answer}
78
 
79
+ return answer_rag
80
 
 
 
 
81
 
 
82
 
83
 
84
  def make_rag_papers_chain(llm):
85
 
86
  prompt = ChatPromptTemplate.from_template(papers_prompt_template)
 
87
  input_documents = {
88
  "context":lambda x : _combine_documents(x["docs"]),
89
  **pass_values(["question","language"])
 
109
  }
110
 
111
  illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
112
+ return illustration_chain
climateqa/engine/chains/chitchat_categorization.py ADDED
@@ -0,0 +1,43 @@
1
+
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import List
4
+ from typing import Literal
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.utils.function_calling import convert_to_openai_function
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+
9
+
10
+ class IntentCategorizer(BaseModel):
11
+ """Analyzing the user message input"""
12
+
13
+ environment: bool = Field(
14
+ description="Return 'True' if the question relates to climate change, the environment, nature, etc. (Example: should I eat fish?). Return 'False' if the question is just chit chat or not related to the environment or climate change.",
15
+ )
16
+
17
+
18
+ def make_chitchat_intent_categorization_chain(llm):
19
+
20
+ openai_functions = [convert_to_openai_function(IntentCategorizer)]
21
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
22
+
23
+ prompt = ChatPromptTemplate.from_messages([
24
+ ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
25
+ ("user", "input: {input}")
26
+ ])
27
+
28
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
29
+ return chain
30
+
31
+
32
+ def make_chitchat_intent_categorization_node(llm):
33
+
34
+ categorization_chain = make_chitchat_intent_categorization_chain(llm)
35
+
36
+ def categorize_message(state):
37
+ output = categorization_chain.invoke({"input": state["user_input"]})
38
+ print(f"\n\nChit chat output intent categorization: {output}\n")
39
+ state["search_graphs_chitchat"] = output["environment"]
40
+ print(f"\n\nChit chat output intent categorization: {state}\n")
41
+ return state
42
+
43
+ return categorize_message
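As a quick illustration (not part of this changeset, and the input sentence is just an example), the categorization chain can be exercised on its own with any function-calling chat model:

chain = make_chitchat_intent_categorization_chain(llm)
output = chain.invoke({"input": "Should I eat less meat to help the climate?"})
# Expected shape: {"environment": True} for environment-related inputs, {"environment": False} for pure chit chat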
climateqa/engine/chains/graph_retriever.py ADDED
@@ -0,0 +1,128 @@
1
+ import sys
2
+ import os
3
+ from contextlib import contextmanager
4
+
5
+ from ..reranker import rerank_docs
6
+ from ..graph_retriever import retrieve_graphs # GraphRetriever
7
+ from ...utils import remove_duplicates_keep_highest_score
8
+
9
+
10
+ def divide_into_parts(target, parts):
11
+ # Base value for each part
12
+ base = target // parts
13
+ # Remainder to distribute
14
+ remainder = target % parts
15
+ # List to hold the result
16
+ result = []
17
+
18
+ for i in range(parts):
19
+ if i < remainder:
20
+ # These parts get base value + 1
21
+ result.append(base + 1)
22
+ else:
23
+ # The rest get the base value
24
+ result.append(base)
25
+
26
+ return result
27
+
28
+
29
+ @contextmanager
30
+ def suppress_output():
31
+ # Open a null device
32
+ with open(os.devnull, 'w') as devnull:
33
+ # Store the original stdout and stderr
34
+ old_stdout = sys.stdout
35
+ old_stderr = sys.stderr
36
+ # Redirect stdout and stderr to the null device
37
+ sys.stdout = devnull
38
+ sys.stderr = devnull
39
+ try:
40
+ yield
41
+ finally:
42
+ # Restore stdout and stderr
43
+ sys.stdout = old_stdout
44
+ sys.stderr = old_stderr
45
+
46
+
47
+ def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
48
+
49
+ async def node_retrieve_graphs(state):
50
+ print("---- Retrieving graphs ----")
51
+
52
+ POSSIBLE_SOURCES = ["IEA", "OWID"]
53
+ questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
54
+ # sources_input = state["sources_input"]
55
+ sources_input = ["auto"]
56
+
57
+ auto_mode = "auto" in sources_input
58
+
59
+ # There are several options to get the final top k
60
+ # Option 1 - Get 100 documents by question and rerank by question
61
+ # Option 2 - Get 100/n documents by question and rerank the total
62
+ if rerank_by_question:
63
+ k_by_question = divide_into_parts(k_final,len(questions))
64
+
65
+ docs = []
66
+
67
+ for i,q in enumerate(questions):
68
+
69
+ question = q["question"] if isinstance(q, dict) else q
70
+
71
+ print(f"Subquestion {i}: {question}")
72
+
73
+ # If auto mode, we use all sources
74
+ if auto_mode:
75
+ sources = POSSIBLE_SOURCES
76
+ # Otherwise, we use the config
77
+ else:
78
+ sources = sources_input
79
+
80
+ if any([x in POSSIBLE_SOURCES for x in sources]):
81
+
82
+ sources = [x for x in sources if x in POSSIBLE_SOURCES]
83
+
84
+ # Search the document store using the retriever
85
+ docs_question = await retrieve_graphs(
86
+ query = question,
87
+ vectorstore = vectorstore,
88
+ sources = sources,
89
+ k_total = k_before_reranking,
90
+ threshold = 0.5,
91
+ )
92
+ # docs_question = retriever.get_relevant_documents(question)
93
+
94
+ # Rerank
95
+ if reranker is not None and docs_question!=[]:
96
+ with suppress_output():
97
+ docs_question = rerank_docs(reranker,docs_question,question)
98
+ else:
99
+ # Add a default reranking score
100
+ for doc in docs_question:
101
+ doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
102
+
103
+ # If rerank by question we select the top documents for each question
104
+ if rerank_by_question:
105
+ docs_question = docs_question[:k_by_question[i]]
106
+
107
+ # Add sources used in the metadata
108
+ for doc in docs_question:
109
+ doc.metadata["sources_used"] = sources
110
+
111
+ print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
112
+
113
+ docs.extend(docs_question)
114
+
115
+ else:
116
+ print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
117
+
118
+ # Remove duplicates and keep the duplicate document with the highest reranking score
119
+ docs = remove_duplicates_keep_highest_score(docs)
120
+
121
+ # Sorting the list in descending order by rerank_score
122
+ # Then select the top k
123
+ docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
124
+ docs = docs[:k_final]
125
+
126
+ return {"recommended_content": docs}
127
+
128
+ return node_retrieve_graphs
climateqa/engine/chains/intent_categorization.py ADDED
@@ -0,0 +1,90 @@
1
+
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import List
4
+ from typing import Literal
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.utils.function_calling import convert_to_openai_function
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+
9
+
10
+ class IntentCategorizer(BaseModel):
11
+ """Analyzing the user message input"""
12
+
13
+ language: str = Field(
14
+ description="Find the language of the message input in full words (ex: French, English, Spanish, ...), defaults to English",
15
+ default="English",
16
+ )
17
+ intent: str = Field(
18
+ enum=[
19
+ "ai_impact",
20
+ # "geo_info",
21
+ # "esg",
22
+ "search",
23
+ "chitchat",
24
+ ],
25
+ description="""
26
+ Categorize the user input into one of the following categories
27
+ Any question
28
+
29
+ Examples:
30
+ - ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
31
+ - search = Searching for any question about climate change, energy, biodiversity, nature, and everything we can find in the IPCC or IPBES reports or scientific papers,
32
+ - chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
33
+ """,
34
+ # - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
35
+ # - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
36
+
37
+ )
38
+
39
+
40
+
41
+ def make_intent_categorization_chain(llm):
42
+
43
+ openai_functions = [convert_to_openai_function(IntentCategorizer)]
44
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
45
+
46
+ prompt = ChatPromptTemplate.from_messages([
47
+ ("system", "You are a helpful assistant, you will analyze, translate and categorize the user input message using the function provided. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
48
+ ("user", "input: {input}")
49
+ ])
50
+
51
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
52
+ return chain
53
+
54
+
55
+ def make_intent_categorization_node(llm):
56
+
57
+ categorization_chain = make_intent_categorization_chain(llm)
58
+
59
+ def categorize_message(state):
60
+ print("---- Categorize_message ----")
61
+
62
+ output = categorization_chain.invoke({"input": state["user_input"]})
63
+ print(f"\n\nOutput intent categorization: {output}\n")
64
+ if "language" not in output: output["language"] = "English"
65
+ output["query"] = state["user_input"]
66
+ return output
67
+
68
+ return categorize_message
69
+
70
+
71
+
72
+
73
+ # SAMPLE_QUESTIONS = [
74
+ # "Est-ce que l'IA a un impact sur l'environnement ?",
75
+ # "Que dit le GIEC sur l'impact de l'IA",
76
+ # "Qui sont les membres du GIEC",
77
+ # "What is the impact of El Nino ?",
78
+ # "Yo",
79
+ # "Hello ça va bien ?",
80
+ # "Par qui as tu été créé ?",
81
+ # "What role do cloud formations play in modulating the Earth's radiative balance, and how are they represented in current climate models?",
82
+ # "Which industries have the highest GHG emissions?",
83
+ # "What are invasive alien species and how do they threaten biodiversity and ecosystems?",
84
+ # "Are human activities causing global warming?",
85
+ # "What is the motivation behind mining the deep seabed?",
86
+ # "Tu peux m'écrire un poème sur le changement climatique ?",
87
+ # "Tu peux m'écrire un poème sur les bonbons ?",
88
+ # "What will be the temperature in 2100 in Strasbourg?",
89
+ # "C'est quoi le lien entre biodiversity and changement climatique ?",
90
+ # ]
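A rough sketch of how this node could route a conversation inside a LangGraph workflow (illustrative only; the state schema, node names and routing targets below are assumptions, the actual wiring lives in climateqa/engine/graph.py, and llm is assumed to come from get_llm as elsewhere in the codebase):

from typing_extensions import TypedDict
from langgraph.graph import StateGraph, END

from climateqa.engine.chains.intent_categorization import make_intent_categorization_node
from climateqa.engine.chains.answer_chitchat import make_chitchat_node
from climateqa.engine.chains.answer_ai_impact import make_ai_impact_node

class GraphState(TypedDict, total=False):
    user_input: str
    query: str
    language: str
    intent: str
    answer: str

workflow = StateGraph(GraphState)
workflow.add_node("categorize_intent", make_intent_categorization_node(llm))
workflow.add_node("answer_chitchat", make_chitchat_node(llm))
workflow.add_node("answer_ai_impact", make_ai_impact_node(llm))
workflow.set_entry_point("categorize_intent")
workflow.add_conditional_edges(
    "categorize_intent",
    lambda state: state["intent"],  # key produced by categorize_message
    {"chitchat": "answer_chitchat", "ai_impact": "answer_ai_impact", "search": END},
)
app = workflow.compile()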
climateqa/engine/chains/keywords_extraction.py ADDED
@@ -0,0 +1,40 @@
1
+
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import List
4
+ from typing import Literal
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.utils.function_calling import convert_to_openai_function
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+
9
+
10
+ class KeywordExtraction(BaseModel):
11
+ """
12
+ Analyzing the user query to extract keywords to feed a search engine
13
+ """
14
+
15
+ keywords: List[str] = Field(
16
+ description="""
17
+ Extract the keywords from the user query to feed a search engine as a list
18
+ Avoid adding super specific keywords to prefer general keywords
19
+ Maximum 3 keywords
20
+
21
+ Examples:
22
+ - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
23
+ - "How will El Nino be impacted by climate change" -> ["el nino","climate change"]
24
+ - "Is climate change a hoax" -> ["climate change","hoax"]
25
+ """
26
+ )
27
+
28
+
29
+ def make_keywords_extraction_chain(llm):
30
+
31
+ openai_functions = [convert_to_openai_function(KeywordExtraction)]
32
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"KeywordExtraction"})
33
+
34
+ prompt = ChatPromptTemplate.from_messages([
35
+ ("system", "You are a helpful assistant"),
36
+ ("user", "input: {input}")
37
+ ])
38
+
39
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
40
+ return chain
climateqa/engine/{prompts.py → chains/prompts.py} RENAMED
@@ -56,7 +56,7 @@ Passages:
56
  {context}
57
 
58
  -----------------------
59
- Question: {question} - Explained to {audience}
60
  Answer in {language} with the passages citations:
61
  """
62
 
@@ -137,7 +137,7 @@ Guidelines:
137
  - If the question is not related to environmental issues, never never answer it. Say it's not your role.
138
  - Make paragraphs by starting new lines to make your answers more readable.
139
 
140
- Question: {question}
141
  Answer in {language}:
142
  """
143
 
@@ -147,4 +147,27 @@ audience_prompts = {
147
  "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
148
  "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
149
  "experts": "expert and climate scientists that are not afraid of technical terms",
150
- }
56
  {context}
57
 
58
  -----------------------
59
+ Question: {query} - Explained to {audience}
60
  Answer in {language} with the passages citations:
61
  """
62
 
 
137
  - If the question is not related to environmental issues, never never answer it. Say it's not your role.
138
  - Make paragraphs by starting new lines to make your answers more readable.
139
 
140
+ Question: {query}
141
  Answer in {language}:
142
  """
143
 
 
147
  "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
148
  "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
149
  "experts": "expert and climate scientists that are not afraid of technical terms",
150
+ }
151
+
152
+
153
+ answer_prompt_graph_template = """
154
+ Given the user question and a list of graphs which are related to the question, rank the graphs based on relevance to the user question. ALWAYS follow the guidelines given below.
155
+
156
+ ### Guidelines ###
157
+ - Keep all the graphs that are given to you.
158
+ - NEVER modify the graph HTML embedding, the category or the source leave them exactly as they are given.
159
+ - Return the ranked graphs as a list of dictionaries with keys 'embedding', 'category', and 'source'.
160
+ - Return a valid JSON output.
161
+
162
+ -----------------------
163
+ User question:
164
+ {query}
165
+
166
+ Graphs and their HTML embedding:
167
+ {recommended_content}
168
+
169
+ -----------------------
170
+ {format_instructions}
171
+
172
+ Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
173
+ """
climateqa/engine/chains/query_transformation.py ADDED
@@ -0,0 +1,201 @@
1
+
2
+
3
+ from langchain_core.pydantic_v1 import BaseModel, Field
4
+ from typing import List
5
+ from typing import Literal
6
+ from langchain.prompts import ChatPromptTemplate
7
+ from langchain_core.utils.function_calling import convert_to_openai_function
8
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
9
+
10
+
11
+ ROUTING_INDEX = {
12
+ "Vector":["IPCC","IPBES","IPOS"],
13
+ "OpenAlex":["OpenAlex"],
14
+ }
15
+
16
+ POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
17
+
18
+ # Prompt from the original paper https://arxiv.org/pdf/2305.14283
19
+ # Query Rewriting for Retrieval-Augmented Large Language Models
20
+ class QueryDecomposition(BaseModel):
21
+ """
22
+ Decompose the user query into smaller parts to think step by step to answer this question
23
+ Act as a simple planning agent
24
+ """
25
+
26
+ questions: List[str] = Field(
27
+ description="""
28
+ Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
29
+ Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
30
+ - If it's already a standalone and explicit question, just return the reformulated question for the search engine
31
+ - If you need to decompose the question, output a list of maximum 2 to 3 questions
32
+ """
33
+ )
34
+
35
+
36
+ class Location(BaseModel):
37
+ country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, addresses), ex: France, USA, ...")
38
+ location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
39
+
40
+ class QueryAnalysis(BaseModel):
41
+ """
42
+ Analyzing the user query to extract topics, sources and date
43
+ Also do query expansion to get alternative search queries
44
+ Also provide simple keywords to feed a search engine
45
+ """
46
+
47
+ # keywords: List[str] = Field(
48
+ # description="""
49
+ # Extract the keywords from the user query to feed a search engine as a list
50
+ # Maximum 3 keywords
51
+
52
+ # Examples:
53
+ # - "What is the impact of deep sea mining ?" -> deep sea mining
54
+ # - "How will El Nino be impacted by climate change" -> el nino;climate change
55
+ # - "Is climate change a hoax" -> climate change;hoax
56
+ # """
57
+ # )
58
+
59
+ # alternative_queries: List[str] = Field(
60
+ # description="""
61
+ # Generate alternative search questions from the user query to feed a search engine
62
+ # """
63
+ # )
64
+
65
+ # step_back_question: str = Field(
66
+ # description="""
67
+ # You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.
68
+ # This questions should help you get more context and information about the user query
69
+ # """
70
+ # )
71
+
72
+ sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field( #,"OpenAlex"]] = Field(
73
+ ...,
74
+ description="""
75
+ Given a user question choose which documents would be most relevant for answering their question,
76
+ - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
77
+ - IPBES is for questions about biodiversity and nature
78
+ - IPOS is for questions about the ocean and deep sea mining
79
+ """,
80
+ # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
81
+ )
82
+ # topics: List[Literal[
83
+ # "Climate change",
84
+ # "Biodiversity",
85
+ # "Energy",
86
+ # "Decarbonization",
87
+ # "Climate science",
88
+ # "Nature",
89
+ # "Climate policy and justice",
90
+ # "Oceans",
91
+ # "Deep sea mining",
92
+ # "ESG and regulations",
93
+ # "CSRD",
94
+ # ]] = Field(
95
+ # ...,
96
+ # description = """
97
+ # Choose the topics that are most relevant to the user query, ex: Climate change, Energy, Biodiversity, ...
98
+ # """,
99
+ # )
100
+ # date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
101
+ # location:Location
102
+
103
+
104
+ def make_query_decomposition_chain(llm):
105
+
106
+ openai_functions = [convert_to_openai_function(QueryDecomposition)]
107
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryDecomposition"})
108
+
109
+ prompt = ChatPromptTemplate.from_messages([
110
+ ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
111
+ ("user", "input: {input}")
112
+ ])
113
+
114
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
115
+ return chain
116
+
117
+
118
+ def make_query_rewriter_chain(llm):
119
+
120
+ openai_functions = [convert_to_openai_function(QueryAnalysis)]
121
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryAnalysis"})
122
+
123
+
124
+
125
+ prompt = ChatPromptTemplate.from_messages([
126
+ ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
127
+ ("user", "input: {input}")
128
+ ])
129
+
130
+
131
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
132
+ return chain
133
+
134
+
135
+ def make_query_transform_node(llm,k_final=15):
136
+
137
+ decomposition_chain = make_query_decomposition_chain(llm)
138
+ rewriter_chain = make_query_rewriter_chain(llm)
139
+
140
+ def transform_query(state):
141
+ print("---- Transform query ----")
142
+
143
+
144
+ if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
145
+ auto_mode = False
146
+ else:
147
+ auto_mode = True
148
+
149
+ sources_input = state.get("sources_input")
150
+ if sources_input is None: sources_input = ROUTING_INDEX["Vector"]
151
+
152
+ new_state = {}
153
+
154
+ # Decomposition
155
+ decomposition_output = decomposition_chain.invoke({"input":state["query"]})
156
+ new_state.update(decomposition_output)
157
+
158
+ # Query Analysis
159
+ questions = []
160
+ for question in new_state["questions"]:
161
+ question_state = {"question":question}
162
+ analysis_output = rewriter_chain.invoke({"input":question})
163
+
164
+ # TODO WARNING llm should always return smthg
165
+ # The case when the llm does not return any sources
166
+ if not analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS"] for source in analysis_output["sources"]):
167
+ analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
168
+
169
+ question_state.update(analysis_output)
170
+ questions.append(question_state)
171
+
172
+ # Explode the questions into multiple questions with different sources
173
+ new_questions = []
174
+ for q in questions:
175
+ question,sources = q["question"],q["sources"]
176
+
177
+ # If not auto mode we take the configuration
178
+ if not auto_mode:
179
+ sources = sources_input
180
+
181
+ for index,index_sources in ROUTING_INDEX.items():
182
+ selected_sources = list(set(sources).intersection(index_sources))
183
+ if len(selected_sources) > 0:
184
+ new_questions.append({"question":question,"sources":selected_sources,"index":index})
185
+
186
+ # # Add the number of questions to search
187
+ # k_by_question = k_final // len(new_questions)
188
+ # for q in new_questions:
189
+ # q["k"] = k_by_question
190
+
191
+ # new_state["questions"] = new_questions
192
+ # new_state["remaining_questions"] = new_questions
193
+
194
+
195
+ new_state = {
196
+ "remaining_questions":new_questions,
197
+ "n_questions":len(new_questions),
198
+ }
199
+ return new_state
200
+
201
+ return transform_query
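For clarity, a sketch of the state update this node is expected to produce (the example question and the exact decomposition are assumptions; only the keys follow from the code above):

transform_query = make_query_transform_node(llm)
new_state = transform_query({"query": "How does climate change affect coral reefs?", "sources_auto": True})
# Roughly:
# {
#   "remaining_questions": [
#       {"question": "Impact of climate change on coral reefs", "sources": ["IPCC", "IPBES"], "index": "Vector"},
#   ],
#   "n_questions": 1,
# }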
climateqa/engine/{reformulation.py → chains/reformulation.py} RENAMED
@@ -3,7 +3,7 @@ from langchain.output_parsers.structured import StructuredOutputParser, Response
3
  from langchain_core.prompts import PromptTemplate
4
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
5
 
6
- from climateqa.engine.prompts import reformulation_prompt_template
7
  from climateqa.engine.utils import pass_values, flatten_dict
8
 
9
 
 
3
  from langchain_core.prompts import PromptTemplate
4
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
5
 
6
+ from climateqa.engine.chains.prompts import reformulation_prompt_template
7
  from climateqa.engine.utils import pass_values, flatten_dict
8
 
9
 
climateqa/engine/chains/retrieve_documents.py ADDED
@@ -0,0 +1,303 @@
1
+ import sys
2
+ import os
3
+ from contextlib import contextmanager
4
+
5
+ from langchain_core.tools import tool
6
+ from langchain_core.runnables import chain
7
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
8
+ from langchain_core.runnables import RunnableLambda
9
+
10
+ from ..reranker import rerank_docs
11
+ # from ...knowledge.retriever import ClimateQARetriever
12
+ from ...knowledge.openalex import OpenAlexRetriever
13
+ from .keywords_extraction import make_keywords_extraction_chain
14
+ from ..utils import log_event
15
+ from langchain_core.vectorstores import VectorStore
16
+ from typing import List
17
+ from langchain_core.documents.base import Document
18
+
19
+
20
+
21
+ def divide_into_parts(target, parts):
22
+ # Base value for each part
23
+ base = target // parts
24
+ # Remainder to distribute
25
+ remainder = target % parts
26
+ # List to hold the result
27
+ result = []
28
+
29
+ for i in range(parts):
30
+ if i < remainder:
31
+ # These parts get base value + 1
32
+ result.append(base + 1)
33
+ else:
34
+ # The rest get the base value
35
+ result.append(base)
36
+
37
+ return result
38
+
39
+
40
+ @contextmanager
41
+ def suppress_output():
42
+ # Open a null device
43
+ with open(os.devnull, 'w') as devnull:
44
+ # Store the original stdout and stderr
45
+ old_stdout = sys.stdout
46
+ old_stderr = sys.stderr
47
+ # Redirect stdout and stderr to the null device
48
+ sys.stdout = devnull
49
+ sys.stderr = devnull
50
+ try:
51
+ yield
52
+ finally:
53
+ # Restore stdout and stderr
54
+ sys.stdout = old_stdout
55
+ sys.stderr = old_stderr
56
+
57
+
58
+ @tool
59
+ def query_retriever(question):
60
+ """Just a dummy tool to simulate the retriever query"""
61
+ return question
62
+
63
+ def _add_sources_used_in_metadata(docs,sources,question,index):
64
+ for doc in docs:
65
+ doc.metadata["sources_used"] = sources
66
+ doc.metadata["question_used"] = question
67
+ doc.metadata["index_used"] = index
68
+ return docs
69
+
70
+ def _get_k_summary_by_question(n_questions):
71
+ if n_questions == 0:
72
+ return 0
73
+ elif n_questions == 1:
74
+ return 5
75
+ elif n_questions == 2:
76
+ return 3
77
+ elif n_questions == 3:
78
+ return 2
79
+ else:
80
+ return 1
81
+
82
+ def _get_k_images_by_question(n_questions):
83
+ if n_questions == 0:
84
+ return 0
85
+ elif n_questions == 1:
86
+ return 7
87
+ elif n_questions == 2:
88
+ return 5
89
+ elif n_questions == 3:
90
+ return 2
91
+ else:
92
+ return 1
93
+
94
+ def _add_metadata_and_score(docs: List) -> List[Document]:
95
+ # Add score to metadata
96
+ docs_with_metadata = []
97
+ for i,(doc,score) in enumerate(docs):
98
+ doc.page_content = doc.page_content.replace("\r\n"," ")
99
+ doc.metadata["similarity_score"] = score
100
+ doc.metadata["content"] = doc.page_content
101
+ doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
102
+ # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
103
+ docs_with_metadata.append(doc)
104
+ return docs_with_metadata
105
+
106
+ async def get_IPCC_relevant_documents(
107
+ query: str,
108
+ vectorstore:VectorStore,
109
+ sources:list = ["IPCC","IPBES","IPOS"],
110
+ search_figures:bool = False,
111
+ reports:list = [],
112
+ threshold:float = 0.6,
113
+ k_summary:int = 3,
114
+ k_total:int = 10,
115
+ k_images: int = 5,
116
+ namespace:str = "vectors",
117
+ min_size:int = 200,
118
+ search_only:bool = False,
119
+ ) :
120
+
121
+ # Check if all elements in the list are either IPCC or IPBES
122
+ assert isinstance(sources,list)
123
+ assert sources
124
+ assert all([x in ["IPCC","IPBES","IPOS"] for x in sources])
125
+ assert k_total > k_summary, "k_total should be greater than k_summary"
126
+
127
+ # Prepare base search kwargs
128
+ filters = {}
129
+
130
+ if len(reports) > 0:
131
+ filters["short_name"] = {"$in":reports}
132
+ else:
133
+ filters["source"] = { "$in": sources}
134
+
135
+ # INIT
136
+ docs_summaries = []
137
+ docs_full = []
138
+ docs_images = []
139
+
140
+ if search_only:
141
+ # Only search for images if search_only is True
142
+ if search_figures:
143
+ filters_image = {
144
+ **filters,
145
+ "chunk_type":"image"
146
+ }
147
+ docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
148
+ docs_images = _add_metadata_and_score(docs_images)
149
+ else:
150
+ # Regular search flow for text and optionally images
151
+ # Search for k_summary documents in the summaries dataset
152
+ filters_summaries = {
153
+ **filters,
154
+ "chunk_type":"text",
155
+ "report_type": { "$in":["SPM"]},
156
+ }
157
+
158
+ docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
159
+ docs_summaries = [x for x in docs_summaries if x[1] > threshold]
160
+
161
+ # Search for k_total - k_summary documents in the full reports dataset
162
+ filters_full = {
163
+ **filters,
164
+ "chunk_type":"text",
165
+ "report_type": { "$nin":["SPM"]},
166
+ }
167
+ k_full = k_total - len(docs_summaries)
168
+ docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
169
+
170
+ if search_figures:
171
+ # Images
172
+ filters_image = {
173
+ **filters,
174
+ "chunk_type":"image"
175
+ }
176
+ docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
177
+
178
+ docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
179
+
180
+ # Filter if length are below threshold
181
+ docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
182
+ docs_full = [x for x in docs_full if len(x.page_content) > min_size]
183
+
184
+ return {
185
+ "docs_summaries" : docs_summaries,
186
+ "docs_full" : docs_full,
187
+ "docs_images" : docs_images,
188
+ }
189
+
190
+
191
+
192
+ # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
193
+ # @chain
194
+ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
195
+ """
196
+ Retrieve and rerank documents based on the current question in the state.
197
+
198
+ Args:
199
+ state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
200
+ config (dict): Configuration settings for logging and other purposes.
201
+ vectorstore (object): The vector store used to retrieve relevant documents.
202
+ reranker (object): The reranker used to rerank the retrieved documents.
203
+ llm (object): The language model used for processing.
204
+ rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
205
+ k_final (int, optional): The final number of documents to retrieve. Defaults to 15.
206
+ k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
207
+ k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5.
208
+ k_images (int, optional): The number of image documents to retrieve. Defaults to 5.
209
+ Returns:
210
+ dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
211
+ """
212
+ print("---- Retrieve documents ----")
213
+
214
+ # Get the documents from the state
215
+ if "documents" in state and state["documents"] is not None:
216
+ docs = state["documents"]
217
+ else:
218
+ docs = []
219
+ # Get the related_content from the state
220
+ if "related_content" in state and state["related_content"] is not None:
221
+ related_content = state["related_content"]
222
+ else:
223
+ related_content = []
224
+
225
+ search_figures = "IPCC figures" in state["relevant_content_sources"]
226
+ search_only = state["search_only"]
227
+
228
+ # Get the current question
229
+ current_question = state["remaining_questions"][0]
230
+ remaining_questions = state["remaining_questions"][1:]
231
+
232
+ k_by_question = k_final // state["n_questions"]
233
+ k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
234
+ k_images_by_question = _get_k_images_by_question(state["n_questions"])
235
+
236
+ sources = current_question["sources"]
237
+ question = current_question["question"]
238
+ index = current_question["index"]
239
+
240
+ print(f"Retrieve documents for question: {question}")
241
+ await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
242
+
243
+
244
+ if index == "Vector": # always true for now
245
+ docs_question_dict = await get_IPCC_relevant_documents(
246
+ query = question,
247
+ vectorstore=vectorstore,
248
+ search_figures = search_figures,
249
+ sources = sources,
250
+ min_size = 200,
251
+ k_summary = k_summary_by_question,
252
+ k_total = k_before_reranking,
253
+ k_images = k_images_by_question,
254
+ threshold = 0.5,
255
+ search_only = search_only,
256
+ )
257
+
258
+
259
+ # Rerank
260
+ if reranker is not None:
261
+ with suppress_output():
262
+ docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
263
+ docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
264
+ docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
265
+ if rerank_by_question:
266
+ docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
267
+ docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
268
+ docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
269
+ else:
270
+ docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
271
+ # Add a default reranking score
272
+ for doc in docs_question:
273
+ doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
274
+
275
+ docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
276
+ docs_question = docs_question[:k_by_question]
277
+ images_question = docs_question_images_reranked[:k_images]
278
+
279
+ if reranker is not None and rerank_by_question:
280
+ docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
281
+
282
+ # Add sources used in the metadata
283
+ docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
284
+ images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
285
+
286
+ # Add to the list of docs
287
+ docs.extend(docs_question)
288
+ related_content.extend(images_question)
289
+ new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
290
+ return new_state
291
+
292
+
293
+
294
+ def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
295
+ @chain
296
+ async def retrieve_docs(state, config):
297
+ state = await retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
298
+ return state
299
+
300
+ return retrieve_docs
301
+
302
+
303
+
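A minimal wiring sketch for the node factory above (illustrative, not part of this changeset; vectorstore, llm and the workflow object are assumed to be built elsewhere, for example in app.py and climateqa/engine/graph.py):

from climateqa.engine.reranker import get_reranker

reranker = get_reranker("nano")  # same helper as used in retrieve_papers.py
retriever_node = make_retriever_node(vectorstore, reranker, llm, k_final=15, k_before_reranking=100)
workflow.add_node("retrieve_documents", retriever_node)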
climateqa/engine/chains/retrieve_papers.py ADDED
@@ -0,0 +1,95 @@
1
+ from climateqa.engine.keywords import make_keywords_chain
2
+ from climateqa.engine.llm import get_llm
3
+ from climateqa.knowledge.openalex import OpenAlex
4
+ from climateqa.engine.chains.answer_rag import make_rag_papers_chain
5
+ from front.utils import make_html_papers
6
+ from climateqa.engine.reranker import get_reranker
7
+
8
+ oa = OpenAlex()
9
+
10
+ llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
11
+ reranker = get_reranker("nano")
12
+
13
+
14
+ papers_cols_widths = {
15
+ "id":100,
16
+ "title":300,
17
+ "doi":100,
18
+ "publication_year":100,
19
+ "abstract":500,
20
+ "is_oa":50,
21
+ }
22
+
23
+ papers_cols = list(papers_cols_widths.keys())
24
+ papers_cols_widths = list(papers_cols_widths.values())
25
+
26
+
27
+
28
+ def generate_keywords(query):
29
+ chain = make_keywords_chain(llm)
30
+ keywords = chain.invoke(query)
31
+ keywords = " AND ".join(keywords["keywords"])
32
+ return keywords
33
+
34
+
35
+ async def find_papers(query,after, relevant_content_sources, reranker= reranker):
36
+ if "OpenAlex" in relevant_content_sources:
37
+ summary = ""
38
+ keywords = generate_keywords(query)
39
+ df_works = oa.search(keywords,after = after)
40
+
41
+ print(f"Found {len(df_works)} papers")
42
+
43
+ if not df_works.empty:
44
+ df_works = df_works.dropna(subset=["abstract"])
45
+ df_works = df_works[df_works["abstract"] != ""].reset_index(drop = True)
46
+ df_works = oa.rerank(query,df_works,reranker)
47
+ df_works = df_works.sort_values("rerank_score",ascending=False)
48
+ docs_html = []
49
+ for i in range(10):
50
+ docs_html.append(make_html_papers(df_works, i))
51
+ docs_html = "".join(docs_html)
52
+ G = oa.make_network(df_works)
53
+
54
+ height = "750px"
55
+ network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
56
+ network_html = network.generate_html()
57
+
58
+ network_html = network_html.replace("'", "\"")
59
+ css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
60
+ network_html = network_html + css_to_inject
61
+
62
+
63
+ network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
64
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
65
+ allow-scripts allow-same-origin allow-popups
66
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
67
+ allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
68
+
69
+
70
+ docs = df_works["content"].head(10).tolist()
71
+
72
+ df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
73
+ df_works["doc"] = df_works["doc"] + 1
74
+ df_works = df_works[papers_cols]
75
+
76
+ yield docs_html, network_html, summary
77
+
78
+ chain = make_rag_papers_chain(llm)
79
+ result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
80
+ path_answer = "/logs/StrOutputParser/streamed_output/-"
81
+
82
+ async for op in result:
83
+
84
+ op = op.ops[0]
85
+
86
+ if op['path'] == path_answer: # streamed answer tokens from the papers RAG chain
87
+ new_token = op['value'] # str
88
+ summary += new_token
89
+ else:
90
+ continue
91
+ yield docs_html, network_html, summary
92
+ else :
93
+ print("No papers found")
94
+ else :
95
+ yield "","", ""
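Since find_papers is an async generator that progressively yields the papers HTML, the citation network and the streamed summary, a consumer could look like this (illustrative; the query and the after year are made-up values):

import asyncio

async def demo():
    async for docs_html, network_html, summary in find_papers(
        "ocean acidification impacts on coral reefs",
        after=2015,
        relevant_content_sources=["OpenAlex"],
    ):
        print(f"{len(summary)} characters of summary streamed so far")

asyncio.run(demo())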
climateqa/engine/chains/retriever.py ADDED
@@ -0,0 +1,126 @@
1
+ # import sys
2
+ # import os
3
+ # from contextlib import contextmanager
4
+
5
+ # from ..reranker import rerank_docs
6
+ # from ...knowledge.retriever import ClimateQARetriever
7
+
8
+
9
+
10
+
11
+ # def divide_into_parts(target, parts):
12
+ # # Base value for each part
13
+ # base = target // parts
14
+ # # Remainder to distribute
15
+ # remainder = target % parts
16
+ # # List to hold the result
17
+ # result = []
18
+
19
+ # for i in range(parts):
20
+ # if i < remainder:
21
+ # # These parts get base value + 1
22
+ # result.append(base + 1)
23
+ # else:
24
+ # # The rest get the base value
25
+ # result.append(base)
26
+
27
+ # return result
28
+
29
+
30
+ # @contextmanager
31
+ # def suppress_output():
32
+ # # Open a null device
33
+ # with open(os.devnull, 'w') as devnull:
34
+ # # Store the original stdout and stderr
35
+ # old_stdout = sys.stdout
36
+ # old_stderr = sys.stderr
37
+ # # Redirect stdout and stderr to the null device
38
+ # sys.stdout = devnull
39
+ # sys.stderr = devnull
40
+ # try:
41
+ # yield
42
+ # finally:
43
+ # # Restore stdout and stderr
44
+ # sys.stdout = old_stdout
45
+ # sys.stderr = old_stderr
46
+
47
+
48
+
49
+ # def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
50
+
51
+ # def retrieve_documents(state):
52
+
53
+ # POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
54
+ # questions = state["questions"]
55
+
56
+ # # Use sources from the user input or from the LLM detection
57
+ # if "sources_input" not in state or state["sources_input"] is None:
58
+ # sources_input = ["auto"]
59
+ # else:
60
+ # sources_input = state["sources_input"]
61
+ # auto_mode = "auto" in sources_input
62
+
63
+ # # There are several options to get the final top k
64
+ # # Option 1 - Get 100 documents by question and rerank by question
65
+ # # Option 2 - Get 100/n documents by question and rerank the total
66
+ # if rerank_by_question:
67
+ # k_by_question = divide_into_parts(k_final,len(questions))
68
+
69
+ # docs = []
70
+
71
+ # for i,q in enumerate(questions):
72
+
73
+ # sources = q["sources"]
74
+ # question = q["question"]
75
+
76
+ # # If auto mode, we use the sources detected by the LLM
77
+ # if auto_mode:
78
+ # sources = [x for x in sources if x in POSSIBLE_SOURCES]
79
+
80
+ # # Otherwise, we use the config
81
+ # else:
82
+ # sources = sources_input
83
+
84
+ # # Search the document store using the retriever
85
+ # # Configure high top k for further reranking step
86
+ # retriever = ClimateQARetriever(
87
+ # vectorstore=vectorstore,
88
+ # sources = sources,
89
+ # # reports = ias_reports,
90
+ # min_size = 200,
91
+ # k_summary = k_summary,
92
+ # k_total = k_before_reranking,
93
+ # threshold = 0.5,
94
+ # )
95
+ # docs_question = retriever.get_relevant_documents(question)
96
+
97
+ # # Rerank
98
+ # if reranker is not None:
99
+ # with suppress_output():
100
+ # docs_question = rerank_docs(reranker,docs_question,question)
101
+ # else:
102
+ # # Add a default reranking score
103
+ # for doc in docs_question:
104
+ # doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
105
+
106
+ # # If rerank by question we select the top documents for each question
107
+ # if rerank_by_question:
108
+ # docs_question = docs_question[:k_by_question[i]]
109
+
110
+ # # Add sources used in the metadata
111
+ # for doc in docs_question:
112
+ # doc.metadata["sources_used"] = sources
113
+
114
+ # # Add to the list of docs
115
+ # docs.extend(docs_question)
116
+
117
+ # # Sorting the list in descending order by rerank_score
118
+ # # Then select the top k
119
+ # docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
120
+ # docs = docs[:k_final]
121
+
122
+ # new_state = {"documents":docs}
123
+ # return new_state
124
+
125
+ # return retrieve_documents
126
+
climateqa/engine/chains/sample_router.py ADDED
@@ -0,0 +1,66 @@
1
+
2
+ # from typing import List
3
+ # from typing import Literal
4
+ # from langchain.prompts import ChatPromptTemplate
5
+ # from langchain_core.utils.function_calling import convert_to_openai_function
6
+ # from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
7
+
8
+ # # https://livingdatalab.com/posts/2023-11-05-openai-function-calling-with-langchain.html
9
+
10
+ # class Location(BaseModel):
11
+ # country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
12
+ # location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
13
+
14
+ # class QueryAnalysis(BaseModel):
15
+ # """Analyzing the user query"""
16
+
17
+ # language: str = Field(
18
+ # description="Find the language of the query in full words (ex: French, English, Spanish, ...), defaults to English"
19
+ # )
20
+ # intent: str = Field(
21
+ # enum=[
22
+ # "Environmental impacts of AI",
23
+ # "Geolocated info about climate change",
24
+ # "Climate change",
25
+ # "Biodiversity",
26
+ # "Deep sea mining",
27
+ # "Chitchat",
28
+ # ],
29
+ # description="""
30
+ # Categorize the user query in one of the following category,
31
+
32
+ # Examples:
33
+ # - Geolocated info about climate change: "What will be the temperature in Marseille in 2050"
34
+ # - Climate change: "What is radiative forcing", "How much will
35
+ # """,
36
+ # )
37
+ # sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field(
38
+ # ...,
39
+ # description="""
40
+ # Given a user question choose which documents would be most relevant for answering their question,
41
+ # - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
42
+ # - IPBES is for questions about biodiversity and nature
43
+ # - IPOS is for questions about the ocean and deep sea mining
44
+
45
+ # """,
46
+ # )
47
+ # date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
48
+ # location:Location
49
+ # # query: str = Field(
50
+ # # description = """
51
+ # # Translate to english and reformulate the following user message to be a short standalone question, in the context of an educational discussion about climate change.
52
+ # # The reformulated question will used in a search engine
53
+ # # By default, assume that the user is asking information about the last century,
54
+ # # Use the following examples
55
+
56
+ # # ### Examples:
57
+ # # La technologie nous sauvera-t-elle ? -> Can technology help humanity mitigate the effects of climate change?
58
+ # # what are our reserves in fossil fuel? -> What are the current reserves of fossil fuels and how long will they last?
59
+ # # what are the main causes of climate change? -> What are the main causes of climate change in the last century?
60
+
61
+ # # Question in English:
62
+ # # """
63
+ # # )
64
+
65
+ # openai_functions = [convert_to_openai_function(QueryAnalysis)]
66
+ # llm2 = llm.bind(functions = openai_functions,function_call={"name":"QueryAnalysis"})
climateqa/engine/chains/set_defaults.py ADDED
@@ -0,0 +1,13 @@
1
+ def set_defaults(state):
2
+ print("---- Setting defaults ----")
3
+
4
+ if not state["audience"] or state["audience"] is None:
5
+ state.update({"audience": "experts"})
6
+
7
+ sources_input = state["sources_input"] if "sources_input" in state else ["auto"]
8
+ state.update({"sources_input": sources_input})
9
+
10
+ # if not state["sources_input"] or state["sources_input"] is None:
11
+ # state.update({"sources_input": ["auto"]})
12
+
13
+ return state
climateqa/engine/chains/translation.py ADDED
@@ -0,0 +1,42 @@
1
+
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import List
4
+ from typing import Literal
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.utils.function_calling import convert_to_openai_function
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+
9
+
10
+ class Translation(BaseModel):
11
+ """Analyzing the user message input"""
12
+
13
+ translation: str = Field(
14
+ description="Translate the message input to English",
15
+ )
16
+
17
+
18
+ def make_translation_chain(llm):
19
+
20
+ openai_functions = [convert_to_openai_function(Translation)]
21
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"Translation"})
22
+
23
+ prompt = ChatPromptTemplate.from_messages([
24
+ ("system", "You are a helpful assistant, you will translate the user input message to English using the function provided"),
25
+ ("user", "input: {input}")
26
+ ])
27
+
28
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
29
+ return chain
30
+
31
+
32
+ def make_translation_node(llm):
33
+ translation_chain = make_translation_chain(llm)
34
+
35
+ def translate_query(state):
36
+ print("---- Translate query ----")
37
+
38
+ user_input = state["user_input"]
39
+ translation = translation_chain.invoke({"input":user_input})
40
+ return {"query":translation["translation"]}
41
+
42
+ return translate_query
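For illustration, a minimal sketch of how this node factory could be wired up (the `get_llm` call reuses the provider helper elsewhere in this PR; the state dict shape is an assumption for the example, not part of the diff):

    # Hypothetical usage of make_translation_node (sketch only)
    from climateqa.engine.llm import get_llm
    from climateqa.engine.chains.translation import make_translation_node

    llm = get_llm(provider="openai")  # any function-calling chat model
    translate_query = make_translation_node(llm)

    state = {"user_input": "Quelles sont les causes du changement climatique ?"}
    update = translate_query(state)
    print(update["query"])  # expected: an English reformulation of the question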
climateqa/engine/embeddings.py CHANGED
@@ -2,7 +2,7 @@
2
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
 
5
- def get_embeddings_function(version = "v1.2"):
6
 
7
  if version == "v1.2":
8
 
@@ -10,12 +10,12 @@ def get_embeddings_function(version = "v1.2"):
10
  # Best embedding model at a reasonable size at the moment (2023-11-22)
11
 
12
  model_name = "BAAI/bge-base-en-v1.5"
13
- encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
14
  print("Loading embeddings model: ", model_name)
15
  embeddings_function = HuggingFaceBgeEmbeddings(
16
  model_name=model_name,
17
  encode_kwargs=encode_kwargs,
18
- query_instruction="Represent this sentence for searching relevant passages: "
19
  )
20
 
21
  else:
@@ -23,3 +23,6 @@ def get_embeddings_function(version = "v1.2"):
23
  embeddings_function = HuggingFaceEmbeddings(model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1")
24
 
25
  return embeddings_function
 
 
 
 
2
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
 
5
+ def get_embeddings_function(version = "v1.2",query_instruction = "Represent this sentence for searching relevant passages: "):
6
 
7
  if version == "v1.2":
8
 
 
10
  # Best embedding model at a reasonable size at the moment (2023-11-22)
11
 
12
  model_name = "BAAI/bge-base-en-v1.5"
13
+ encode_kwargs = {'normalize_embeddings': True,"show_progress_bar":False} # set True to compute cosine similarity
14
  print("Loading embeddings model: ", model_name)
15
  embeddings_function = HuggingFaceBgeEmbeddings(
16
  model_name=model_name,
17
  encode_kwargs=encode_kwargs,
18
+ query_instruction=query_instruction,
19
  )
20
 
21
  else:
 
23
  embeddings_function = HuggingFaceEmbeddings(model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1")
24
 
25
  return embeddings_function
26
+
27
+
28
+
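A short sketch of the new optional `query_instruction` argument (passing an empty instruction is an assumption made to show the override, not a recommended setting):

    from climateqa.engine.embeddings import get_embeddings_function

    # Default keeps the BGE retrieval prefix; pass query_instruction="" to disable it.
    embeddings = get_embeddings_function(version="v1.2", query_instruction="")
    vector = embeddings.embed_query("sea level rise projections")
    print(len(vector))  # 768 dimensions for BAAI/bge-base-en-v1.5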
climateqa/engine/graph.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ from contextlib import contextmanager
4
+
5
+ from langchain.schema import Document
6
+ from langgraph.graph import END, StateGraph
7
+ from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
8
+
9
+ from typing_extensions import TypedDict
10
+ from typing import List, Dict
11
+
12
+ from IPython.display import display, HTML, Image
13
+
14
+ from .chains.answer_chitchat import make_chitchat_node
15
+ from .chains.answer_ai_impact import make_ai_impact_node
16
+ from .chains.query_transformation import make_query_transform_node
17
+ from .chains.translation import make_translation_node
18
+ from .chains.intent_categorization import make_intent_categorization_node
19
+ from .chains.retrieve_documents import make_retriever_node
20
+ from .chains.answer_rag import make_rag_node
21
+ from .chains.graph_retriever import make_graph_retriever_node
22
+ from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
23
+ # from .chains.set_defaults import set_defaults
24
+
25
+ class GraphState(TypedDict):
26
+ """
27
+ Represents the state of our graph.
28
+ """
29
+ user_input : str
30
+ language : str
31
+ intent : str
32
+ search_graphs_chitchat : bool
33
+ query: str
34
+ remaining_questions : List[dict]
35
+ n_questions : int
36
+ answer: str
37
+ audience: str = "experts"
38
+ sources_input: List[str] = ["IPCC","IPBES"]
39
+ relevant_content_sources: List[str] = ["IPCC figures"]
40
+ sources_auto: bool = True
41
+ min_year: int = 1960
42
+ max_year: int = None
43
+ documents: List[Document]
44
+ related_contents : Dict[str,Document]
45
+ recommended_content : List[Document]
46
+ search_only : bool = False
47
+
48
+ def search(state): #TODO
49
+ return state
50
+
51
+ def answer_search(state):#TODO
52
+ return state
53
+
54
+ def route_intent(state):
55
+ intent = state["intent"]
56
+ if intent in ["chitchat","esg"]:
57
+ return "answer_chitchat"
58
+ # elif intent == "ai_impact":
59
+ # return "answer_ai_impact"
60
+ else:
61
+ # Search route
62
+ return "search"
63
+
64
+ def chitchat_route_intent(state):
65
+ intent = state["search_graphs_chitchat"]
66
+ if intent is True:
67
+ return "retrieve_graphs_chitchat"
68
+ elif intent is False:
69
+ return END
70
+
71
+ def route_translation(state):
72
+ if state["language"].lower() == "english":
73
+ return "transform_query"
74
+ else:
75
+ return "translate_query"
76
+
77
+ def route_based_on_relevant_docs(state,threshold_docs=0.2):
78
+ docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
79
+ if len(docs) > 0:
80
+ return "answer_rag"
81
+ else:
82
+ return "answer_rag_no_docs"
83
+
84
+ def route_retrieve_documents(state):
85
+ if state["search_only"] :
86
+ return END
87
+ elif len(state["remaining_questions"]) > 0:
88
+ return "retrieve_documents"
89
+ else:
90
+ return "answer_search"
91
+
92
+ def make_id_dict(values):
93
+ return {k:k for k in values}
94
+
95
+ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
96
+
97
+ workflow = StateGraph(GraphState)
98
+
99
+ # Define the node functions
100
+ categorize_intent = make_intent_categorization_node(llm)
101
+ transform_query = make_query_transform_node(llm)
102
+ translate_query = make_translation_node(llm)
103
+ answer_chitchat = make_chitchat_node(llm)
104
+ answer_ai_impact = make_ai_impact_node(llm)
105
+ retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker, llm)
106
+ retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
107
+ answer_rag = make_rag_node(llm, with_docs=True)
108
+ answer_rag_no_docs = make_rag_node(llm, with_docs=False)
109
+ chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
110
+
111
+ # Define the nodes
112
+ # workflow.add_node("set_defaults", set_defaults)
113
+ workflow.add_node("categorize_intent", categorize_intent)
114
+ workflow.add_node("search", search)
115
+ workflow.add_node("answer_search", answer_search)
116
+ workflow.add_node("transform_query", transform_query)
117
+ workflow.add_node("translate_query", translate_query)
118
+ workflow.add_node("answer_chitchat", answer_chitchat)
119
+ workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
120
+ workflow.add_node("retrieve_graphs", retrieve_graphs)
121
+ workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
122
+ workflow.add_node("retrieve_documents", retrieve_documents)
123
+ workflow.add_node("answer_rag", answer_rag)
124
+ workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
125
+
126
+ # Entry point
127
+ workflow.set_entry_point("categorize_intent")
128
+
129
+ # CONDITIONAL EDGES
130
+ workflow.add_conditional_edges(
131
+ "categorize_intent",
132
+ route_intent,
133
+ make_id_dict(["answer_chitchat","search"])
134
+ )
135
+
136
+ workflow.add_conditional_edges(
137
+ "chitchat_categorize_intent",
138
+ chitchat_route_intent,
139
+ make_id_dict(["retrieve_graphs_chitchat", END])
140
+ )
141
+
142
+ workflow.add_conditional_edges(
143
+ "search",
144
+ route_translation,
145
+ make_id_dict(["translate_query","transform_query"])
146
+ )
147
+ workflow.add_conditional_edges(
148
+ "retrieve_documents",
149
+ # lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
150
+ route_retrieve_documents,
151
+ make_id_dict([END,"retrieve_documents","answer_search"])
152
+ )
153
+
154
+ workflow.add_conditional_edges(
155
+ "answer_search",
156
+ lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
157
+ make_id_dict(["answer_rag","answer_rag_no_docs"])
158
+ )
159
+ workflow.add_conditional_edges(
160
+ "transform_query",
161
+ lambda state : "retrieve_graphs" if "OurWorldInData" in state["relevant_content_sources"] else END,
162
+ make_id_dict(["retrieve_graphs", END])
163
+ )
164
+
165
+ # Define the edges
166
+ workflow.add_edge("translate_query", "transform_query")
167
+ workflow.add_edge("transform_query", "retrieve_documents")
168
+
169
+ workflow.add_edge("retrieve_graphs", END)
170
+ workflow.add_edge("answer_rag", END)
171
+ workflow.add_edge("answer_rag_no_docs", END)
172
+ workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
173
+
174
+
175
+ # Compile
176
+ app = workflow.compile()
177
+ return app
178
+
179
+
180
+
181
+
182
+ def display_graph(app):
183
+
184
+ display(
185
+ Image(
186
+ app.get_graph(xray = True).draw_mermaid_png(
187
+ draw_method=MermaidDrawMethod.API,
188
+ )
189
+ )
190
+ )
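A rough end-to-end sketch of how the compiled graph is meant to be driven. The OWID index name and the exact state keys passed below are assumptions based on GraphState; the app itself streams events rather than calling ainvoke directly.

    import asyncio
    from climateqa.engine.llm import get_llm
    from climateqa.engine.embeddings import get_embeddings_function
    from climateqa.engine.vectorstore import get_pinecone_vectorstore
    from climateqa.engine.reranker import get_reranker
    from climateqa.engine.graph import make_graph_agent

    llm = get_llm(provider="openai")
    embeddings = get_embeddings_function()
    vectorstore_ipcc = get_pinecone_vectorstore(embeddings)
    vectorstore_graphs = get_pinecone_vectorstore(embeddings, index_name="climateqa-owid")  # hypothetical index name
    reranker = get_reranker("nano")

    agent = make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2)

    async def ask(question: str):
        return await agent.ainvoke({
            "user_input": question,
            "audience": "experts",
            "sources_input": ["IPCC"],
            "relevant_content_sources": ["IPCC figures"],
            "search_only": False,
        })

    # state = asyncio.run(ask("Will sea levels rise faster in the Mediterranean?"))
    # print(state["answer"])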
climateqa/engine/graph_retriever.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.retrievers import BaseRetriever
2
+ from langchain_core.documents.base import Document
3
+ from langchain_core.vectorstores import VectorStore
4
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
5
+
6
+ from typing import List
7
+
8
+ # class GraphRetriever(BaseRetriever):
9
+ # vectorstore:VectorStore
10
+ # sources:list = ["OWID"] # later add OurWorldInData # will need to be integrated with the other retriever
11
+ # threshold:float = 0.5
12
+ # k_total:int = 10
13
+
14
+ # def _get_relevant_documents(
15
+ # self, query: str, *, run_manager: CallbackManagerForRetrieverRun
16
+ # ) -> List[Document]:
17
+
18
+ # # Check if all elements in the list are IEA or OWID
19
+ # assert isinstance(self.sources,list)
20
+ # assert self.sources
21
+ # assert any([x in ["OWID"] for x in self.sources])
22
+
23
+ # # Prepare base search kwargs
24
+ # filters = {}
25
+
26
+ # filters["source"] = {"$in": self.sources}
27
+
28
+ # docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
29
+
30
+ # # Filter if scores are below threshold
31
+ # docs = [x for x in docs if x[1] > self.threshold]
32
+
33
+ # # Remove duplicate documents
34
+ # unique_docs = []
35
+ # seen_docs = []
36
+ # for i, doc in enumerate(docs):
37
+ # if doc[0].page_content not in seen_docs:
38
+ # unique_docs.append(doc)
39
+ # seen_docs.append(doc[0].page_content)
40
+
41
+ # # Add score to metadata
42
+ # results = []
43
+ # for i,(doc,score) in enumerate(unique_docs):
44
+ # doc.metadata["similarity_score"] = score
45
+ # doc.metadata["content"] = doc.page_content
46
+ # results.append(doc)
47
+
48
+ # return results
49
+
50
+ async def retrieve_graphs(
51
+ query: str,
52
+ vectorstore:VectorStore,
53
+ sources:list = ["OWID"], # later add OurWorldInData # will need to be integrated with the other retriever
54
+ threshold:float = 0.5,
55
+ k_total:int = 10,
56
+ )-> List[Document]:
57
+
58
+ # Check if all elements in the list are IEA or OWID
59
+ assert isinstance(sources,list)
60
+ assert sources
61
+ assert any([x in ["OWID"] for x in sources])
62
+
63
+ # Prepare base search kwargs
64
+ filters = {}
65
+
66
+ filters["source"] = {"$in": sources}
67
+
68
+ docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
69
+
70
+ # Filter if scores are below threshold
71
+ docs = [x for x in docs if x[1] > threshold]
72
+
73
+ # Remove duplicate documents
74
+ unique_docs = []
75
+ seen_docs = []
76
+ for i, doc in enumerate(docs):
77
+ if doc[0].page_content not in seen_docs:
78
+ unique_docs.append(doc)
79
+ seen_docs.append(doc[0].page_content)
80
+
81
+ # Add score to metadata
82
+ results = []
83
+ for i,(doc,score) in enumerate(unique_docs):
84
+ doc.metadata["similarity_score"] = score
85
+ doc.metadata["content"] = doc.page_content
86
+ results.append(doc)
87
+
88
+ return results
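A sketch of calling the async helper above; `owid_vectorstore` is assumed to be a vectorstore whose documents carry `source` and `category` metadata:

    import asyncio

    async def fetch_graphs(owid_vectorstore):
        return await retrieve_graphs(
            "global temperature anomaly",
            owid_vectorstore,
            sources=["OWID"],
            threshold=0.5,
            k_total=10,
        )

    # graphs = asyncio.run(fetch_graphs(owid_vectorstore))
    # for doc in graphs:
    #     print(doc.metadata["similarity_score"], doc.metadata.get("category"))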
climateqa/engine/keywords.py CHANGED
@@ -11,10 +11,12 @@ class KeywordsOutput(BaseModel):
11
 
12
  keywords: list = Field(
13
  description="""
14
- Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
 
15
 
16
  Example:
17
  - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
 
18
  - "How will El Nino be impacted by climate change" -> ["el nino"]
19
  - "Is climate change a hoax" -> [Climate change","hoax"]
20
  """
 
11
 
12
  keywords: list = Field(
13
  description="""
14
+ Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers. Answer only with English keywords.
15
+ Do not use special characters or accents.
16
 
17
  Example:
18
  - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
19
+ - "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
20
  - "How will El Nino be impacted by climate change" -> ["el nino"]
21
  - "Is climate change a hoax" -> [Climate change","hoax"]
22
  """
climateqa/engine/llm/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from climateqa.engine.llm.openai import get_llm as get_openai_llm
2
  from climateqa.engine.llm.azure import get_llm as get_azure_llm
 
3
 
4
 
5
  def get_llm(provider="openai",**kwargs):
@@ -8,6 +9,8 @@ def get_llm(provider="openai",**kwargs):
8
  return get_openai_llm(**kwargs)
9
  elif provider == "azure":
10
  return get_azure_llm(**kwargs)
 
 
11
  else:
12
  raise ValueError(f"Unknown provider: {provider}")
13
 
 
1
  from climateqa.engine.llm.openai import get_llm as get_openai_llm
2
  from climateqa.engine.llm.azure import get_llm as get_azure_llm
3
+ from climateqa.engine.llm.ollama import get_llm as get_ollama_llm
4
 
5
 
6
  def get_llm(provider="openai",**kwargs):
 
9
  return get_openai_llm(**kwargs)
10
  elif provider == "azure":
11
  return get_azure_llm(**kwargs)
12
+ elif provider == "ollama":
13
+ return get_ollama_llm(**kwargs)
14
  else:
15
  raise ValueError(f"Unknown provider: {provider}")
16
 
climateqa/engine/llm/ollama.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+
3
+ from langchain_community.llms import Ollama
4
+
5
+ def get_llm(model="llama3", **kwargs):
6
+ return Ollama(model=model, **kwargs)
climateqa/engine/llm/openai.py CHANGED
@@ -7,7 +7,7 @@ try:
7
  except Exception:
8
  pass
9
 
10
- def get_llm(model="gpt-3.5-turbo-0125",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
11
 
12
  llm = ChatOpenAI(
13
  model=model,
 
7
  except Exception:
8
  pass
9
 
10
+ def get_llm(model="gpt-4o-mini",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
11
 
12
  llm = ChatOpenAI(
13
  model=model,
climateqa/engine/reranker.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from scipy.special import expit, logit
4
+ from rerankers import Reranker
5
+ from sentence_transformers import CrossEncoder
6
+
7
+ load_dotenv()
8
+
9
+ def get_reranker(model = "nano", cohere_api_key = None):
10
+
11
+ assert model in ["nano","tiny","small","large", "jina"]
12
+
13
+ if model == "nano":
14
+ reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
15
+ elif model == "tiny":
16
+ reranker = Reranker('ms-marco-MiniLM-L-12-v2', model_type='flashrank')
17
+ elif model == "small":
18
+ reranker = Reranker("mixedbread-ai/mxbai-rerank-xsmall-v1", model_type='cross-encoder')
19
+ elif model == "large":
20
+ if cohere_api_key is None:
21
+ cohere_api_key = os.environ["COHERE_API_KEY"]
22
+ reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
23
+ elif model == "jina":
24
+ # Reached token quota so does not work
25
+ reranker = Reranker("jina-reranker-v2-base-multilingual", api_key = os.getenv("JINA_RERANKER_API_KEY"))
26
+ # does not work without a GPU? it also returns a different structure, so the retriever node code would need to change
27
+ # reranker = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", automodel_args={"torch_dtype": "auto"}, trust_remote_code=True,)
28
+ return reranker
29
+
30
+
31
+
32
+ def rerank_docs(reranker,docs,query):
33
+ if docs == []:
34
+ return []
35
+
36
+ # Get a list of texts from langchain docs
37
+ input_docs = [x.page_content for x in docs]
38
+
39
+ # Rerank using rerankers library
40
+ results = reranker.rank(query=query, docs=input_docs)
41
+
42
+ # Prepare langchain list of docs
43
+ docs_reranked = []
44
+ for result in results.results:
45
+ doc_id = result.document.doc_id
46
+ doc = docs[doc_id]
47
+ doc.metadata["reranking_score"] = result.score
48
+ doc.metadata["query_used_for_retrieval"] = query
49
+ docs_reranked.append(doc)
50
+ return docs_reranked
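A small self-contained sketch of the two helpers above (the toy documents are invented for the example):

    from langchain.schema import Document

    docs = [
        Document(page_content="Sea level rise is driven by thermal expansion and ice melt.", metadata={}),
        Document(page_content="The 2022 World Cup was held in Qatar.", metadata={}),
    ]

    reranker = get_reranker("nano")  # flashrank TinyBERT, no API key needed
    ranked = rerank_docs(reranker, docs, query="Why are sea levels rising?")
    for doc in ranked:
        print(round(doc.metadata["reranking_score"], 3), doc.page_content[:50])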
climateqa/engine/retriever.py DELETED
@@ -1,163 +0,0 @@
1
- # https://github.com/langchain-ai/langchain/issues/8623
2
-
3
- import pandas as pd
4
-
5
- from langchain_core.retrievers import BaseRetriever
6
- from langchain_core.vectorstores import VectorStoreRetriever
7
- from langchain_core.documents.base import Document
8
- from langchain_core.vectorstores import VectorStore
9
- from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
10
-
11
- from typing import List
12
- from pydantic import Field
13
-
14
- class ClimateQARetriever(BaseRetriever):
15
- vectorstore:VectorStore
16
- sources:list = ["IPCC","IPBES","IPOS"]
17
- reports:list = []
18
- threshold:float = 0.6
19
- k_summary:int = 3
20
- k_total:int = 10
21
- namespace:str = "vectors",
22
- min_size:int = 200,
23
-
24
-
25
- def _get_relevant_documents(
26
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun
27
- ) -> List[Document]:
28
-
29
- # Check if all elements in the list are either IPCC or IPBES
30
- assert isinstance(self.sources,list)
31
- assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
32
- assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
33
-
34
- # Prepare base search kwargs
35
- filters = {}
36
-
37
- if len(self.reports) > 0:
38
- filters["short_name"] = {"$in":self.reports}
39
- else:
40
- filters["source"] = { "$in":self.sources}
41
-
42
- # Search for k_summary documents in the summaries dataset
43
- filters_summaries = {
44
- **filters,
45
- "report_type": { "$in":["SPM"]},
46
- }
47
-
48
- docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
49
- docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
50
-
51
- # Search for k_total - k_summary documents in the full reports dataset
52
- filters_full = {
53
- **filters,
54
- "report_type": { "$nin":["SPM"]},
55
- }
56
- k_full = self.k_total - len(docs_summaries)
57
- docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
58
-
59
- # Concatenate documents
60
- docs = docs_summaries + docs_full
61
-
62
- # Filter if scores are below threshold
63
- docs = [x for x in docs if len(x[0].page_content) > self.min_size]
64
- # docs = [x for x in docs if x[1] > self.threshold]
65
-
66
- # Add score to metadata
67
- results = []
68
- for i,(doc,score) in enumerate(docs):
69
- doc.metadata["similarity_score"] = score
70
- doc.metadata["content"] = doc.page_content
71
- doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
72
- # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
73
- results.append(doc)
74
-
75
- # Sort by score
76
- # results = sorted(results,key = lambda x : x.metadata["similarity_score"],reverse = True)
77
-
78
- return results
79
-
80
-
81
-
82
-
83
- # def filter_summaries(df,k_summary = 3,k_total = 10):
84
- # # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
85
-
86
- # # # Filter by source
87
- # # if source == "IPCC":
88
- # # df = df.loc[df["source"]=="IPCC"]
89
- # # elif source == "IPBES":
90
- # # df = df.loc[df["source"]=="IPBES"]
91
- # # else:
92
- # # pass
93
-
94
- # # Separate summaries and full reports
95
- # df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
96
- # df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
97
-
98
- # # Find passages from summaries dataset
99
- # passages_summaries = df_summaries.head(k_summary)
100
-
101
- # # Find passages from full reports dataset
102
- # passages_fullreports = df_full.head(k_total - len(passages_summaries))
103
-
104
- # # Concatenate passages
105
- # passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
106
- # return passages
107
-
108
-
109
-
110
-
111
- # def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
112
- # assert max_k > k_total
113
-
114
- # validated_sources = ["IPCC","IPBES"]
115
- # sources = [x for x in sources if x in validated_sources]
116
- # filters = {
117
- # "source": { "$in": sources },
118
- # }
119
- # print(filters)
120
-
121
- # # Retrieve documents
122
- # docs = retriever.retrieve(query,top_k = max_k,filters = filters)
123
-
124
- # # Filter by score
125
- # docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
126
-
127
- # if len(docs) == 0:
128
- # return []
129
- # res = pd.DataFrame(docs)
130
- # passages_df = filter_summaries(res,k_summary,k_total)
131
- # if as_dict:
132
- # contents = passages_df["content"].tolist()
133
- # meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
134
- # passages = []
135
- # for i in range(len(contents)):
136
- # passages.append({"content":contents[i],"meta":meta[i]})
137
- # return passages
138
- # else:
139
- # return passages_df
140
-
141
-
142
-
143
- # def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
144
-
145
-
146
- # print("hellooooo")
147
-
148
- # # Reformulate queries
149
- # reformulated_query,language = reformulate(query)
150
-
151
- # print(reformulated_query)
152
-
153
- # # Retrieve documents
154
- # passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
155
- # response = {
156
- # "query":query,
157
- # "reformulated_query":reformulated_query,
158
- # "language":language,
159
- # "sources":passages,
160
- # "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
161
- # }
162
- # return response
163
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/utils.py CHANGED
@@ -1,8 +1,15 @@
1
  from operator import itemgetter
2
  from typing import Any, Dict, Iterable, Tuple
 
3
  from langchain_core.runnables import RunnablePassthrough
4
 
5
 
 
 
 
 
 
 
6
  def pass_values(x):
7
  if not isinstance(x, list):
8
  x = [x]
@@ -67,3 +74,13 @@ def flatten_dict(
67
  """
68
  flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
69
  return flat_dict
 
 
 
 
 
 
 
 
 
 
 
1
  from operator import itemgetter
2
  from typing import Any, Dict, Iterable, Tuple
3
+ import tiktoken
4
  from langchain_core.runnables import RunnablePassthrough
5
 
6
 
7
+ def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
8
+ encoding = tiktoken.get_encoding(encoding_name)
9
+ num_tokens = len(encoding.encode(string))
10
+ return num_tokens
11
+
12
+
13
  def pass_values(x):
14
  if not isinstance(x, list):
15
  x = [x]
 
74
  """
75
  flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
76
  return flat_dict
77
+
78
+
79
+
80
+ async def log_event(info,name,config):
81
+ """Helper function that will run a dummy chain with the given info
82
+ The astream_event function will catch this chain and stream the dict info to the logger
83
+ """
84
+
85
+ chain = RunnablePassthrough().with_config(run_name=name)
86
+ _ = await chain.ainvoke(info,config)
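For reference, a quick sketch of the token counter added above (cl100k_base is its default encoding):

    from climateqa.engine.utils import num_tokens_from_string

    print(num_tokens_from_string("What is radiative forcing?"))  # small integer, tokenizer dependent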
climateqa/engine/vectorstore.py CHANGED
@@ -4,6 +4,7 @@
4
  import os
5
  from pinecone import Pinecone
6
  from langchain_community.vectorstores import Pinecone as PineconeVectorstore
 
7
 
8
  # LOAD ENVIRONMENT VARIABLES
9
  try:
@@ -13,7 +14,12 @@ except:
13
  pass
14
 
15
 
16
- def get_pinecone_vectorstore(embeddings,text_key = "content"):
 
 
 
 
 
17
 
18
  # # initialize pinecone
19
  # pinecone.init(
@@ -27,7 +33,7 @@ def get_pinecone_vectorstore(embeddings,text_key = "content"):
27
  # return vectorstore
28
 
29
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
30
- index = pc.Index(os.getenv("PINECONE_API_INDEX"))
31
 
32
  vectorstore = PineconeVectorstore(
33
  index, embeddings, text_key,
 
4
  import os
5
  from pinecone import Pinecone
6
  from langchain_community.vectorstores import Pinecone as PineconeVectorstore
7
+ from langchain_chroma import Chroma
8
 
9
  # LOAD ENVIRONMENT VARIABLES
10
  try:
 
14
  pass
15
 
16
 
17
+ def get_chroma_vectorstore(embedding_function, persist_directory="/home/dora/climate-question-answering/data/vectorstore"):
18
+ vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
19
+ return vectorstore
20
+
21
+
22
+ def get_pinecone_vectorstore(embeddings,text_key = "content", index_name = os.getenv("PINECONE_API_INDEX")):
23
 
24
  # # initialize pinecone
25
  # pinecone.init(
 
33
  # return vectorstore
34
 
35
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
36
+ index = pc.Index(index_name)
37
 
38
  vectorstore = PineconeVectorstore(
39
  index, embeddings, text_key,
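A sketch of the new Chroma entry point; the local persist directory below is an assumption standing in for the machine-specific default:

    from climateqa.engine.embeddings import get_embeddings_function
    from climateqa.engine.vectorstore import get_chroma_vectorstore

    embeddings = get_embeddings_function()
    vectorstore = get_chroma_vectorstore(embeddings, persist_directory="./data/vectorstore")
    docs = vectorstore.similarity_search("ocean acidification", k=4)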
climateqa/event_handler.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.runnables.schema import StreamEvent
2
+ from gradio import ChatMessage
3
+ from climateqa.engine.chains.prompts import audience_prompts
4
+ from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
5
+ import numpy as np
6
+
7
+ def init_audience(audience :str) -> str:
8
+ if audience == "Children":
9
+ audience_prompt = audience_prompts["children"]
10
+ elif audience == "General public":
11
+ audience_prompt = audience_prompts["general"]
12
+ elif audience == "Experts":
13
+ audience_prompt = audience_prompts["experts"]
14
+ else:
15
+ audience_prompt = audience_prompts["experts"]
16
+ return audience_prompt
17
+
18
+ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[list, str, list[ChatMessage], list[str], list]:
19
+ """
20
+ Handles the retrieved documents and returns the HTML representation of the documents
21
+
22
+ Args:
23
+ event (StreamEvent): The event containing the retrieved documents
24
+ history (list[ChatMessage]): The current message history
25
+ used_documents (list[str]): The list of used documents
26
+
27
+ Returns:
28
+ tuple[list, str, list[ChatMessage], list[str], list]: The retrieved documents, their HTML representation, the updated message history, the updated list of used documents and the related contents
29
+ """
30
+ try:
31
+ docs = event["data"]["output"]["documents"]
32
+ docs_html = []
33
+ textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
34
+ for i, d in enumerate(textual_docs, 1):
35
+ if d.metadata["chunk_type"] == "text":
36
+ docs_html.append(make_html_source(d, i))
37
+
38
+ used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
39
+ if used_documents!=[]:
40
+ history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
41
+
42
+ docs_html = "".join(docs_html)
43
+
44
+ related_contents = event["data"]["output"]["related_contents"]
45
+
46
+ except Exception as e:
47
+ print(f"Error getting documents: {e}")
48
+ print(event)
49
+ return docs, docs_html, history, used_documents, related_contents
50
+
51
+ def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
52
+ """
53
+ Handles the streaming of the answer and updates the history with the new message content
54
+
55
+ Args:
56
+ history (list[ChatMessage]): The current message history
57
+ event (StreamEvent): The event containing the streamed answer
58
+ start_streaming (bool): A flag indicating if the streaming has started
59
+ new_message_content (str): The content of the new message
60
+
61
+ Returns:
62
+ tuple[list[ChatMessage], bool, str]: The updated history, the updated streaming flag and the updated message content
63
+ """
64
+ if start_streaming == False:
65
+ start_streaming = True
66
+ history.append(ChatMessage(role="assistant", content = ""))
67
+ answer_message_content += event["data"]["chunk"].content
68
+ answer_message_content = parse_output_llm_with_sources(answer_message_content)
69
+ history[-1] = ChatMessage(role="assistant", content = answer_message_content)
70
+ # history.append(ChatMessage(role="assistant", content = new_message_content))
71
+ return history, start_streaming, answer_message_content
72
+
73
+ def handle_retrieved_owid_graphs(event :StreamEvent, graphs_html: str) -> str:
74
+ """
75
+ Handles the retrieved OWID graphs and returns the HTML representation of the graphs
76
+
77
+ Args:
78
+ event (StreamEvent): The event containing the retrieved graphs
79
+ graphs_html (str): The current HTML representation of the graphs
80
+
81
+ Returns:
82
+ str: The updated HTML representation
83
+ """
84
+ try:
85
+ recommended_content = event["data"]["output"]["recommended_content"]
86
+
87
+ unique_graphs = []
88
+ seen_embeddings = set()
89
+
90
+ for x in recommended_content:
91
+ embedding = x.metadata["returned_content"]
92
+
93
+ # Check if the embedding has already been seen
94
+ if embedding not in seen_embeddings:
95
+ unique_graphs.append({
96
+ "embedding": embedding,
97
+ "metadata": {
98
+ "source": x.metadata["source"],
99
+ "category": x.metadata["category"]
100
+ }
101
+ })
102
+ # Add the embedding to the seen set
103
+ seen_embeddings.add(embedding)
104
+
105
+
106
+ categories = {}
107
+ for graph in unique_graphs:
108
+ category = graph['metadata']['category']
109
+ if category not in categories:
110
+ categories[category] = []
111
+ categories[category].append(graph['embedding'])
112
+
113
+
114
+ for category, embeddings in categories.items():
115
+ graphs_html += f"<h3>{category}</h3>"
116
+ for embedding in embeddings:
117
+ graphs_html += f"<div>{embedding}</div>"
118
+
119
+
120
+ except Exception as e:
121
+ print(f"Error getting graphs: {e}")
122
+
123
+ return graphs_html
climateqa/knowledge/__init__.py ADDED
File without changes
climateqa/{papers → knowledge}/openalex.py RENAMED
@@ -3,18 +3,32 @@ import networkx as nx
3
  import matplotlib.pyplot as plt
4
  from pyvis.network import Network
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
7
  import pyalex
8
 
9
  pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"
10
 
 
 
 
 
11
  class OpenAlex():
12
  def __init__(self):
13
  pass
14
 
15
 
16
-
17
- def search(self,keywords,n_results = 100,after = None,before = None):
18
 
19
  if isinstance(keywords,str):
20
  works = Works().search(keywords)
@@ -27,29 +41,36 @@ class OpenAlex():
27
  break
28
 
29
  df_works = pd.DataFrame(page)
30
- df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x))
 
 
 
 
 
 
31
  df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
32
  df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
33
- df_works["content"] = df_works["title"] + "\n" + df_works["abstract"]
 
 
 
 
 
 
34
 
 
35
  else:
36
- df_works = []
37
- for keyword in keywords:
38
- df_keyword = self.search(keyword,n_results = n_results,after = after,before = before)
39
- df_works.append(df_keyword)
40
- df_works = pd.concat(df_works,ignore_index=True,axis = 0)
41
- return df_works
42
 
43
 
44
  def rerank(self,query,df,reranker):
45
 
46
  scores = reranker.rank(
47
  query,
48
- df["content"].tolist(),
49
- top_k = len(df),
50
  )
51
- scores.sort(key = lambda x : x["corpus_id"])
52
- scores = [x["score"] for x in scores]
53
  df["rerank_score"] = scores
54
  return df
55
 
@@ -139,4 +160,36 @@ class OpenAlex():
139
  reconstructed[position] = token
140
 
141
  # Join the tokens to form the reconstructed sentence(s)
142
- return ' '.join(reconstructed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import matplotlib.pyplot as plt
4
  from pyvis.network import Network
5
 
6
+ from langchain_core.retrievers import BaseRetriever
7
+ from langchain_core.vectorstores import VectorStoreRetriever
8
+ from langchain_core.documents.base import Document
9
+ from langchain_core.vectorstores import VectorStore
10
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
11
+
12
+ from ..engine.utils import num_tokens_from_string
13
+
14
+ from typing import List
15
+ from pydantic import Field
16
+
17
  from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
18
  import pyalex
19
 
20
  pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"
21
 
22
+
23
+ def replace_nan_with_empty_dict(x):
24
+ return x if pd.notna(x) else {}
25
+
26
  class OpenAlex():
27
  def __init__(self):
28
  pass
29
 
30
 
31
+ def search(self,keywords:str,n_results = 100,after = None,before = None):
 
32
 
33
  if isinstance(keywords,str):
34
  works = Works().search(keywords)
 
41
  break
42
 
43
  df_works = pd.DataFrame(page)
44
+
45
+ if df_works.empty:
46
+ return df_works
47
+
48
+ df_works = df_works.dropna(subset = ["title"])
49
+ df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
50
+ df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
51
  df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
52
  df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
53
+ df_works["url"] = df_works["id"]
54
+ df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x : x.strip())
55
+ df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
56
+
57
+ df_works = df_works.drop(columns = ["abstract_inverted_index"])
58
+ df_works["display_name"] = df_works["primary_location"].apply(lambda x :x["source"] if type(x) == dict and 'source' in x else "").apply(lambda x : x["display_name"] if type(x) == dict and "display_name" in x else "")
59
+ df_works["subtitle"] = df_works["title"].astype(str) + " - " + df_works["display_name"].astype(str) + " - " + df_works["publication_year"].astype(str)
60
 
61
+ return df_works
62
  else:
63
+ raise Exception("Keywords must be a string")
 
 
 
 
 
64
 
65
 
66
  def rerank(self,query,df,reranker):
67
 
68
  scores = reranker.rank(
69
  query,
70
+ df["content"].tolist()
 
71
  )
72
+ scores = sorted(scores.results, key = lambda x : x.document.doc_id)
73
+ scores = [x.score for x in scores]
74
  df["rerank_score"] = scores
75
  return df
76
 
 
160
  reconstructed[position] = token
161
 
162
  # Join the tokens to form the reconstructed sentence(s)
163
+ return ' '.join(reconstructed)
164
+
165
+
166
+
167
+ class OpenAlexRetriever(BaseRetriever):
168
+ min_year:int = 1960
169
+ max_year:int = None
170
+ k:int = 100
171
+
172
+ def _get_relevant_documents(
173
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
174
+ ) -> List[Document]:
175
+
176
+ openalex = OpenAlex()
177
+
178
+ # Search for documents
179
+ df_docs = openalex.search(query,n_results=self.k,after = self.min_year,before = self.max_year)
180
+
181
+ docs = []
182
+ for i,row in df_docs.iterrows():
183
+ num_tokens = row["num_tokens"]
184
+
185
+ if num_tokens < 50 or num_tokens > 1000:
186
+ continue
187
+
188
+ doc = Document(
189
+ page_content = row["content"],
190
+ metadata = row.to_dict()
191
+ )
192
+ docs.append(doc)
193
+ return docs
194
+
195
+
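A sketch of using the retriever defined above through the standard LangChain Runnable interface:

    retriever = OpenAlexRetriever(min_year=2015, max_year=2024, k=50)
    papers = retriever.invoke("deep sea mining impacts")
    for doc in papers[:3]:
        print(doc.metadata["subtitle"])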
climateqa/knowledge/retriever.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # https://github.com/langchain-ai/langchain/issues/8623
2
+
3
+ # import pandas as pd
4
+
5
+ # from langchain_core.retrievers import BaseRetriever
6
+ # from langchain_core.vectorstores import VectorStoreRetriever
7
+ # from langchain_core.documents.base import Document
8
+ # from langchain_core.vectorstores import VectorStore
9
+ # from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
10
+
11
+ # from typing import List
12
+ # from pydantic import Field
13
+
14
+ # def _add_metadata_and_score(docs: List) -> Document:
15
+ # # Add score to metadata
16
+ # docs_with_metadata = []
17
+ # for i,(doc,score) in enumerate(docs):
18
+ # doc.page_content = doc.page_content.replace("\r\n"," ")
19
+ # doc.metadata["similarity_score"] = score
20
+ # doc.metadata["content"] = doc.page_content
21
+ # doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
22
+ # # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
23
+ # docs_with_metadata.append(doc)
24
+ # return docs_with_metadata
25
+
26
+ # class ClimateQARetriever(BaseRetriever):
27
+ # vectorstore:VectorStore
28
+ # sources:list = ["IPCC","IPBES","IPOS"]
29
+ # reports:list = []
30
+ # threshold:float = 0.6
31
+ # k_summary:int = 3
32
+ # k_total:int = 10
33
+ # namespace:str = "vectors",
34
+ # min_size:int = 200,
35
+
36
+
37
+
38
+ # def _get_relevant_documents(
39
+ # self, query: str, *, run_manager: CallbackManagerForRetrieverRun
40
+ # ) -> List[Document]:
41
+
42
+ # # Check if all elements in the list are either IPCC or IPBES
43
+ # assert isinstance(self.sources,list)
44
+ # assert self.sources
45
+ # assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
46
+ # assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
47
+
48
+ # # Prepare base search kwargs
49
+ # filters = {}
50
+
51
+ # if len(self.reports) > 0:
52
+ # filters["short_name"] = {"$in":self.reports}
53
+ # else:
54
+ # filters["source"] = { "$in":self.sources}
55
+
56
+ # # Search for k_summary documents in the summaries dataset
57
+ # filters_summaries = {
58
+ # **filters,
59
+ # "chunk_type":"text",
60
+ # "report_type": { "$in":["SPM"]},
61
+ # }
62
+
63
+ # docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
64
+ # docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
65
+ # # docs_summaries = []
66
+
67
+ # # Search for k_total - k_summary documents in the full reports dataset
68
+ # filters_full = {
69
+ # **filters,
70
+ # "chunk_type":"text",
71
+ # "report_type": { "$nin":["SPM"]},
72
+ # }
73
+ # k_full = self.k_total - len(docs_summaries)
74
+ # docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
75
+
76
+ # # Images
77
+ # filters_image = {
78
+ # **filters,
79
+ # "chunk_type":"image"
80
+ # }
81
+ # docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
82
+
83
+ # # docs_images = []
84
+
85
+ # # Concatenate documents
86
+ # # docs = docs_summaries + docs_full + docs_images
87
+
88
+ # # Filter if scores are below threshold
89
+ # # docs = [x for x in docs if x[1] > self.threshold]
90
+
91
+ # docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
92
+
93
+ # # Filter if length are below threshold
94
+ # docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
95
+ # docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
96
+
97
+
98
+ # return {
99
+ # "docs_summaries" : docs_summaries,
100
+ # "docs_full" : docs_full,
101
+ # "docs_images" : docs_images,
102
+ # }
climateqa/papers/__init__.py DELETED
@@ -1,43 +0,0 @@
1
- import pandas as pd
2
-
3
- from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
4
- import pyalex
5
-
6
- pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"
7
-
8
- class OpenAlex():
9
- def __init__(self):
10
- pass
11
-
12
-
13
-
14
- def search(self,keywords,n_results = 100,after = None,before = None):
15
- works = Works().search(keywords).get()
16
-
17
- for page in works.paginate(per_page=n_results):
18
- break
19
-
20
- df_works = pd.DataFrame(page)
21
-
22
- return works
23
-
24
-
25
- def make_network(self):
26
- pass
27
-
28
-
29
- def get_abstract_from_inverted_index(self,index):
30
-
31
- # Determine the maximum index to know the length of the reconstructed array
32
- max_index = max([max(positions) for positions in index.values()])
33
-
34
- # Initialize a list with placeholders for all positions
35
- reconstructed = [''] * (max_index + 1)
36
-
37
- # Iterate through the inverted index and place each token at its respective position(s)
38
- for token, positions in index.items():
39
- for position in positions:
40
- reconstructed[position] = token
41
-
42
- # Join the tokens to form the reconstructed sentence(s)
43
- return ' '.join(reconstructed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/utils.py CHANGED
@@ -20,3 +20,16 @@ def get_image_from_azure_blob_storage(path):
20
  file_object = get_file_from_azure_blob_storage(path)
21
  image = Image.open(file_object)
22
  return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  file_object = get_file_from_azure_blob_storage(path)
21
  image = Image.open(file_object)
22
  return image
23
+
24
+ def remove_duplicates_keep_highest_score(documents):
25
+ unique_docs = {}
26
+
27
+ for doc in documents:
28
+ doc_id = doc.metadata.get('doc_id')
29
+ if doc_id in unique_docs:
30
+ if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
31
+ unique_docs[doc_id] = doc
32
+ else:
33
+ unique_docs[doc_id] = doc
34
+
35
+ return list(unique_docs.values())
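A minimal sketch of the deduplication helper above, with invented documents:

    from langchain.schema import Document

    docs = [
        Document(page_content="A", metadata={"doc_id": 1, "reranking_score": 0.4}),
        Document(page_content="A bis", metadata={"doc_id": 1, "reranking_score": 0.9}),
        Document(page_content="B", metadata={"doc_id": 2, "reranking_score": 0.7}),
    ]
    best = remove_duplicates_keep_highest_score(docs)
    print([(d.metadata["doc_id"], d.metadata["reranking_score"]) for d in best])  # [(1, 0.9), (2, 0.7)]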
front/__init__.py ADDED
File without changes
front/callbacks.py ADDED
File without changes
front/utils.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ from collections import defaultdict
4
+ from climateqa.utils import get_image_from_azure_blob_storage
5
+ from climateqa.engine.chains.prompts import audience_prompts
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import base64
9
+
10
+
11
+ def make_pairs(lst:list)->list:
12
+ """from a list of even lenght, make tupple pairs"""
13
+ return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
14
+
15
+
16
+ def serialize_docs(docs:list)->list:
17
+ new_docs = []
18
+ for doc in docs:
19
+ new_doc = {}
20
+ new_doc["page_content"] = doc.page_content
21
+ new_doc["metadata"] = doc.metadata
22
+ new_docs.append(new_doc)
23
+ return new_docs
24
+
25
+
26
+
27
+ def parse_output_llm_with_sources(output:str)->str:
28
+ # Split the content into a list of text and "[Doc X]" references
29
+ content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
30
+ parts = []
31
+ for part in content_parts:
32
+ if part.startswith("Doc"):
33
+ subparts = part.split(",")
34
+ subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts]
35
+ subparts = [f"""<a href="#doc{subpart}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{subpart}</sup></span></a>""" for subpart in subparts]
36
+ parts.append("".join(subparts))
37
+ else:
38
+ parts.append(part)
39
+ content_parts = "".join(parts)
40
+ return content_parts
41
+
42
+ def process_figures(docs:list)->tuple:
43
+ gallery=[]
44
+ used_figures =[]
45
+ figures = '<div class="figures-container"><p></p> </div>'
46
+ docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
47
+ for i, doc in enumerate(docs_figures):
48
+ if doc.metadata["chunk_type"] == "image":
49
+ if doc.metadata["figure_code"] != "N/A":
50
+ title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
51
+ else:
52
+ title = f"{doc.metadata['short_name']}"
53
+
54
+
55
+ if title not in used_figures:
56
+ used_figures.append(title)
57
+ try:
58
+ key = f"Image {i+1}"
59
+
60
+ image_path = doc.metadata["image_path"].split("documents/")[1]
61
+ img = get_image_from_azure_blob_storage(image_path)
62
+
63
+ # Convert the image to a byte buffer
64
+ buffered = BytesIO()
65
+ max_image_length = 500
66
+ img_resized = img.resize((max_image_length, int(max_image_length * img.size[1]/img.size[0])))
67
+ img_resized.save(buffered, format="PNG")
68
+
69
+ img_str = base64.b64encode(buffered.getvalue()).decode()
70
+
71
+ figures = figures + make_html_figure_sources(doc, i, img_str)
72
+ gallery.append(img)
73
+ except Exception as e:
74
+ print(f"Skipped adding image {i} because of {e}")
75
+
76
+ return figures, gallery
77
+
78
+
79
+ def generate_html_graphs(graphs:list)->str:
80
+ # Organize graphs by category
81
+ categories = defaultdict(list)
82
+ for graph in graphs:
83
+ category = graph['metadata']['category']
84
+ categories[category].append(graph['embedding'])
85
+
86
+ # Begin constructing the HTML
87
+ html_code = '''
88
+ <!DOCTYPE html>
89
+ <html lang="en">
90
+ <head>
91
+ <meta charset="UTF-8">
92
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
93
+ <title>Graphs by Category</title>
94
+ <style>
95
+ .tab-content {
96
+ display: none;
97
+ }
98
+ .tab-content.active {
99
+ display: block;
100
+ }
101
+ .tabs {
102
+ margin-bottom: 20px;
103
+ }
104
+ .tab-button {
105
+ background-color: #ddd;
106
+ border: none;
107
+ padding: 10px 20px;
108
+ cursor: pointer;
109
+ margin-right: 5px;
110
+ }
111
+ .tab-button.active {
112
+ background-color: #ccc;
113
+ }
114
+ </style>
115
+ <script>
116
+ function showTab(tabId) {
117
+ var contents = document.getElementsByClassName('tab-content');
118
+ var buttons = document.getElementsByClassName('tab-button');
119
+ for (var i = 0; i < contents.length; i++) {
120
+ contents[i].classList.remove('active');
121
+ buttons[i].classList.remove('active');
122
+ }
123
+ document.getElementById(tabId).classList.add('active');
124
+ document.querySelector('button[data-tab="'+tabId+'"]').classList.add('active');
125
+ }
126
+ </script>
127
+ </head>
128
+ <body>
129
+ <div class="tabs">
130
+ '''
131
+
132
+ # Add buttons for each category
133
+ for i, category in enumerate(categories.keys()):
134
+ active_class = 'active' if i == 0 else ''
135
+ html_code += f'<button class="tab-button {active_class}" onclick="showTab(\'tab-{i}\')" data-tab="tab-{i}">{category}</button>'
136
+
137
+ html_code += '</div>'
138
+
139
+ # Add content for each category
140
+ for i, (category, embeds) in enumerate(categories.items()):
141
+ active_class = 'active' if i == 0 else ''
142
+ html_code += f'<div id="tab-{i}" class="tab-content {active_class}">'
143
+ for embed in embeds:
144
+ html_code += embed
145
+ html_code += '</div>'
146
+
147
+ html_code += '''
148
+ </body>
149
+ </html>
150
+ '''
151
+
152
+ return html_code
153
+
154
+
155
+
156
+ def make_html_source(source,i):
157
+ meta = source.metadata
158
+ # content = source.page_content.split(":",1)[1].strip()
159
+ content = source.page_content.strip()
160
+
161
+ toc_levels = []
162
+ for j in range(2):
163
+ level = meta[f"toc_level{j}"]
164
+ if level != "N/A":
165
+ toc_levels.append(level)
166
+ else:
167
+ break
168
+ toc_levels = " > ".join(toc_levels)
169
+
170
+ if len(toc_levels) > 0:
171
+ name = f"<b>{toc_levels}</b><br/>{meta['name']}"
172
+ else:
173
+ name = meta['name']
174
+
175
+ score = meta['reranking_score']
176
+ if score > 0.8:
177
+ color = "score-green"
178
+ elif score > 0.5:
179
+ color = "score-orange"
180
+ else:
181
+ color = "score-red"
182
+
183
+ relevancy_score = f"<p class=relevancy-score>Relevancy score: <span class='{color}'>{score:.1%}</span></p>"
184
+
185
+ if meta["chunk_type"] == "text":
186
+
187
+ card = f"""
188
+ <div class="card" id="doc{i}">
189
+ <div class="card-content">
190
+ <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
191
+ <p>{content}</p>
192
+ {relevancy_score}
193
+ </div>
194
+ <div class="card-footer">
195
+ <span>{name}</span>
196
+ <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
197
+ <span role="img" aria-label="Open PDF">🔗</span>
198
+ </a>
199
+ </div>
200
+ </div>
201
+ """
202
+
203
+ else:
204
+
205
+ if meta["figure_code"] != "N/A":
206
+ title = f"{meta['figure_code']} - {meta['short_name']}"
207
+ else:
208
+ title = f"{meta['short_name']}"
209
+
210
+ card = f"""
211
+ <div class="card card-image">
212
+ <div class="card-content">
213
+ <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
214
+ <p class='ai-generated'>AI-generated description</p>
215
+ <p>{content}</p>
216
+
217
+ {relevancy_score}
218
+ </div>
219
+ <div class="card-footer">
220
+ <span>{name}</span>
221
+ <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
222
+ <span role="img" aria-label="Open PDF">🔗</span>
223
+ </a>
224
+ </div>
225
+ </div>
226
+ """
227
+
228
+ return card
229
+
230
+
231
+ def make_html_papers(df,i):
232
+ title = df['title'][i]
233
+ content = df['abstract'][i]
234
+ url = df['doi'][i]
235
+ publication_date = df['publication_year'][i]
236
+ subtitle = df['subtitle'][i]
237
+
238
+ card = f"""
239
+ <div class="card" id="doc{i}">
240
+ <div class="card-content">
241
+ <h2>Doc {i+1} - {title}</h2>
242
+ <p>{content}</p>
243
+ </div>
244
+ <div class="card-footer">
245
+ <span>{subtitle}</span>
246
+ <a href="{url}" target="_blank" class="pdf-link">
247
+ <span role="img" aria-label="Open paper">🔗</span>
248
+ </a>
249
+ </div>
250
+ </div>
251
+ """
252
+
253
+ return card
254
+
255
+
256
+ def make_html_figure_sources(source,i,img_str):
257
+ meta = source.metadata
258
+ content = source.page_content.strip()
259
+
260
+ score = meta['reranking_score']
261
+ if score > 0.8:
262
+ color = "score-green"
263
+ elif score > 0.5:
264
+ color = "score-orange"
265
+ else:
266
+ color = "score-red"
267
+
268
+ toc_levels = []
269
+ if len(toc_levels) > 0:
270
+ name = f"<b>{toc_levels}</b><br/>{meta['name']}"
271
+ else:
272
+ name = meta['name']
273
+
274
+ relevancy_score = f"<p class=relevancy-score>Relevancy score: <span class='{color}'>{score:.1%}</span></p>"
275
+
276
+ if meta["figure_code"] != "N/A":
277
+ title = f"{meta['figure_code']} - {meta['short_name']}"
278
+ else:
279
+ title = f"{meta['short_name']}"
280
+
281
+ card = f"""
282
+ <div class="card card-image">
283
+ <div class="card-content">
284
+ <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
285
+ <img src="data:image/png;base64, { img_str }" alt="Alt text" />
286
+ <p class='ai-generated'>AI-generated description</p>
287
+
288
+ <p>{content}</p>
289
+
290
+ {relevancy_score}
291
+ </div>
292
+ <div class="card-footer">
293
+ <span>{name}</span>
294
+ <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
295
+ <span role="img" aria-label="Open PDF">🔗</span>
296
+ </a>
297
+ </div>
298
+ </div>
299
+ """
300
+ return card
301
+
302
+
303
+
304
+ def make_toolbox(tool_name,description = "",checked = False,elem_id = "toggle"):
305
+
306
+ if checked:
307
+ span = "<span class='checkmark'>&#10003;</span>"
308
+ else:
309
+ span = "<span class='loader'></span>"
310
+
311
+ # toolbox = f"""
312
+ # <div class="dropdown">
313
+ # <label for="{elem_id}" class="dropdown-toggle">
314
+ # {span}
315
+ # {tool_name}
316
+ # <span class="caret"></span>
317
+ # </label>
318
+ # <input type="checkbox" id="{elem_id}" hidden/>
319
+ # <div class="dropdown-content">
320
+ # <p>{description}</p>
321
+ # </div>
322
+ # </div>
323
+ # """
324
+
325
+
326
+ toolbox = f"""
327
+ <div class="dropdown">
328
+ <label for="{elem_id}" class="dropdown-toggle">
329
+ {span}
330
+ {tool_name}
331
+ </label>
332
+ </div>
333
+ """
334
+
335
+ return toolbox
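A sketch of the citation parser defined earlier in this file, which turns "[Doc X]" markers emitted by the LLM into superscript anchors pointing at the source cards:

    text = "Sea level rise is accelerating [Doc 1, Doc 3]."
    print(parse_output_llm_with_sources(text))
    # ...accelerating <a href="#doc1" ...><sup>1</sup></a><a href="#doc3" ...><sup>3</sup></a>.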
requirements.txt CHANGED
@@ -1,13 +1,21 @@
1
- gradio==4.19.1
2
  azure-storage-file-share==12.11.1
3
  azure-storage-blob
4
  python-dotenv==1.0.0
5
- langchain==0.1.4
6
- langchain_openai==0.0.6
7
- pinecone-client==3.0.2
 
8
  sentence-transformers==2.6.0
9
  huggingface-hub
10
- msal
11
  pyalex==0.13
12
  networkx==3.2.1
13
- pyvis==0.3.2
 
 
 
 
 
 
 
 
 
1
+ gradio==5.0.2
2
  azure-storage-file-share==12.11.1
3
  azure-storage-blob
4
  python-dotenv==1.0.0
5
+ langchain==0.2.1
6
+ langchain_openai==0.1.7
7
+ langgraph==0.0.55
8
+ pinecone-client==4.1.0
9
  sentence-transformers==2.6.0
10
  huggingface-hub
 
11
  pyalex==0.13
12
  networkx==3.2.1
13
+ pyvis==0.3.2
14
+ flashrank==0.2.5
15
+ rerankers==0.3.0
16
+ torch==2.3.0
17
+ nvidia-cudnn-cu12==8.9.2.26
18
+ langchain-community==0.2
19
+ msal==1.31
20
+ matplotlib==3.9.2
21
+ gradio-modal==0.0.4
sandbox/20240310 - CQA - Semantic Routing 1.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
sandbox/20240702 - CQA - Graph Functionality.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
sandbox/20241104 - CQA - StepByStep CQA.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
style.css CHANGED
@@ -3,6 +3,87 @@
  --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
  } */
 
  .warning-box {
  background-color: #fff3cd;
  border: 1px solid #ffeeba;
@@ -57,14 +138,26 @@ body.dark .tip-box * {
 
  .message{
  font-size:14px !important;
  }
-
 
  a {
  text-decoration: none;
  color: inherit;
  }
 
  .card {
  background-color: white;
  border-radius: 10px;
@@ -128,94 +221,183 @@ a {
  border:none;
  }
 
- /* .gallery-item > div:hover{
- background-color:#7494b0 !important;
- color:white!important;
- }
 
- .gallery-item:hover{
- border:#7494b0 !important;
  }
 
- .gallery-item > div{
- background-color:white !important;
- color:#577b9b!important;
  }
 
- .label{
- color:#577b9b!important;
- } */
 
- /* .paginate{
- color:#577b9b!important;
  } */
 
 
 
- /* span[data-testid="block-info"]{
- background:none !important;
- color:#577b9b;
- } */
 
- /* Pseudo-element for the circularly cropped picture */
- /* .message.bot::before {
- content: '';
  position: absolute;
- top: -10px;
- left: -10px;
- width: 30px;
- height: 30px;
- background-image: var(--user-image);
- background-size: cover;
- background-position: center;
  border-radius: 50%;
- z-index: 10;
- }
- */
-
- label.selected{
- background:none !important;
  }
-
- #submit-button{
- padding:0px !important;
  }
 
 
  @media screen and (min-width: 1024px) {
  div#tab-examples{
  height:calc(100vh - 190px) !important;
- overflow-y: auto;
  }
 
  div#sources-textbox{
  height:calc(100vh - 190px) !important;
- overflow-y: auto !important;
  }
 
  div#tab-config{
  height:calc(100vh - 190px) !important;
- overflow-y: auto !important;
  }
 
- div#chatbot-row{
- height:calc(100vh - 90px) !important;
  }
 
- div#chatbot{
- height:calc(100vh - 170px) !important;
  }
 
  .max-height{
  height:calc(100vh - 90px) !important;
  overflow-y: auto;
  }
 
- /* .tabitem:nth-child(n+3) {
- padding-top:30px;
- padding-left:40px;
- padding-right:40px;
- } */
  }
 
  footer {
@@ -258,21 +440,33 @@
  /* ... add other mobile-specific styles ... */
  }
 
 
- body.dark .card{
- background-color: #374151;
- }
-
- body.dark .card-content h2{
- color:#f4dbd3 !important;
- }
-
- body.dark .card-footer {
- background-color: #404652;
- }
 
- body.dark .card-footer span {
- color:white !important;
  }
 
 
@@ -318,7 +512,7 @@ span.chatbot > p > img{
  }
 
  .card-image > .card-content{
- background-color:#f1f7fa !important;
  }
 
 
@@ -344,8 +538,7 @@ span.chatbot > p > img{
  }
 
  #dropdown-samples{
- /*! border:none !important; */
- /*! border-width:0px !important; */
  background:none !important;
 
  }
@@ -363,3 +556,147 @@ span.chatbot > p > img{
  .a-doc-ref{
  text-decoration: none !important;
  }
3
  --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
4
  } */
5
 
6
+ #tab-recommended_content{
7
+ padding-top: 0px;
8
+ padding-left : 0px;
9
+ padding-right: 0px;
10
+ }
11
+ #group-subtabs {
12
+ /* display: block; */
13
+ width: 100%; /* Ensures the parent uses the full width */
14
+ position : sticky;
15
+ }
16
+
17
+ #group-subtabs .tab-container {
18
+ display: flex;
19
+ text-align: center;
20
+ width: 100%; /* Ensures the tabs span the full width */
21
+ }
22
+
23
+ #group-subtabs .tab-container button {
24
+ flex: 1; /* Makes each button take equal width */
25
+ }
26
+
27
+
28
+ #papers-summary-popup button span{
29
+ /* make label of accordio in bold, center, and bigger */
30
+ font-size: 16px;
31
+ font-weight: bold;
32
+ text-align: center;
33
+
34
+ }
35
+
36
+ #papers-relevant-popup span{
37
+ /* make label of accordio in bold, center, and bigger */
38
+ font-size: 16px;
39
+ font-weight: bold;
40
+ text-align: center;
41
+ }
42
+
43
+
44
+
45
+ #tab-citations .button{
46
+ padding: 12px 16px;
47
+ font-size: 16px;
48
+ font-weight: bold;
49
+ cursor: pointer;
50
+ border: none;
51
+ outline: none;
52
+ text-align: left;
53
+ transition: background-color 0.3s ease;
54
+ }
55
+
56
+
57
+ .gradio-container {
58
+ width: 100%!important;
59
+ max-width: 100% !important;
60
+ }
61
+
62
+ /* fix for huggingface infinite growth*/
63
+ main.flex.flex-1.flex-col {
64
+ max-height: 95vh !important;
65
+ }
66
+
67
+ button#show-figures{
68
+ /* Base styles */
69
+ background-color: #f5f5f5;
70
+ border: 1px solid #e0e0e0;
71
+ border-radius: 4px;
72
+ color: #333333;
73
+ cursor: pointer;
74
+ width: 100%;
75
+ text-align: center;
76
+ }
77
+
78
+ .avatar-container.svelte-1x5p6hu:not(.thumbnail-item) img {
79
+ width: 100%;
80
+ height: 100%;
81
+ object-fit: cover;
82
+ border-radius: 50%;
83
+ padding: 0px;
84
+ margin: 0px;
85
+ }
86
+
87
  .warning-box {
88
  background-color: #fff3cd;
89
  border: 1px solid #ffeeba;
 
138
 
139
  .message{
140
  font-size:14px !important;
141
+
142
+ }
143
+ .card-content img {
144
+ display: block;
145
+ margin: auto;
146
+ max-width: 100%; /* Ensures the image is responsive */
147
+ height: auto;
148
  }
 
149
 
150
  a {
151
  text-decoration: none;
152
  color: inherit;
153
  }
154
 
155
+ .doc-ref sup{
156
+ color:#dc2626!important;
157
+ /* margin-right:1px; */
158
+ }
159
+
160
+
161
  .card {
162
  background-color: white;
163
  border-radius: 10px;
 
221
  border:none;
222
  }
223
 
 
 
 
 
224
 
225
+ label.selected{
226
+ background: #93c5fd !important;
227
  }
228
 
229
+ #submit-button{
230
+ padding:0px !important;
 
231
  }
232
 
233
+ #modal-config .block.modal-block.padded {
234
+ padding-top: 25px;
235
+ height: 100vh;
236
 
237
+ }
238
+ #modal-config .modal-container{
239
+ margin: 0px;
240
+ padding: 0px;
241
+ }
242
+ /* Modal styles */
243
+ #modal-config {
244
+ position: fixed;
245
+ top: 0;
246
+ left: 0;
247
+ height: 100vh;
248
+ width: 500px;
249
+ background-color: white;
250
+ box-shadow: 2px 0 10px rgba(0, 0, 0, 0.1);
251
+ z-index: 1000;
252
+ padding: 15px;
253
+ transform: none;
254
+ }
255
+ #modal-config .close{
256
+ display: none;
257
+ }
258
+
259
+ /* Push main content to the right when modal is open */
260
+ /* .modal ~ * {
261
+ margin-left: 300px;
262
+ transition: margin-left 0.3s ease;
263
  } */
264
 
265
+ #modal-config .modal .wrap ul{
266
+ position:static;
267
+ top: 100%;
268
+ left: 0;
269
+ /* min-height: 100px; */
270
+ height: 100%;
271
+ /* margin-top: 0; */
272
+ z-index: 9999;
273
+ pointer-events: auto;
274
+ height: 200px;
275
+ }
276
+ #config-button{
277
+ background: none;
278
+ border: none;
279
+ padding: 8px;
280
+ cursor: pointer;
281
+ width: 40px;
282
+ height: 40px;
283
+ display: flex;
284
+ align-items: center;
285
+ justify-content: center;
286
+ border-radius: 50%;
287
+ transition: background-color 0.2s;
288
+ }
289
 
290
+ #config-button::before {
291
+ content: '⚙️';
292
+ font-size: 20px;
293
+ }
294
 
295
+ #config-button:hover {
296
+ background-color: rgba(0, 0, 0, 0.1);
297
+ }
 
298
 
299
+ #checkbox-config{
300
+ display: block;
 
301
  position: absolute;
302
+ background: none;
303
+ border: none;
304
+ padding: 8px;
305
+ cursor: pointer;
306
+ width: 40px;
307
+ height: 40px;
308
+ display: flex;
309
+ align-items: center;
310
+ justify-content: center;
311
  border-radius: 50%;
312
+ transition: background-color 0.2s;
313
+ font-size: 20px;
314
+ text-align: center;
 
 
 
315
  }
316
+ #checkbox-config:checked{
317
+ display: block;
 
318
  }
319
 
320
 
321
+
322
  @media screen and (min-width: 1024px) {
323
+ /* Additional style for scrollable tab content */
324
+ /* div#tab-recommended_content {
325
+ overflow-y: auto;
326
+ max-height: 80vh;
327
+ } */
328
+
329
+ .gradio-container {
330
+ max-height: calc(100vh - 190px) !important;
331
+ overflow: hidden;
332
+ }
333
+ /* div#chatbot{
334
+ height:calc(100vh - 170px) !important;
335
+ max-height:calc(100vh - 170px) !important;
336
+
337
+ } */
338
+
339
+
340
+
341
  div#tab-examples{
342
  height:calc(100vh - 190px) !important;
343
+ overflow-y: scroll !important;
344
+ /* overflow-y: auto; */
345
  }
346
 
347
  div#sources-textbox{
348
  height:calc(100vh - 190px) !important;
349
+ overflow-y: scroll !important;
350
+ /* overflow-y: auto !important; */
351
+ }
352
+ div#graphs-container{
353
+ height:calc(100vh - 210px) !important;
354
+ overflow-y: scroll !important;
355
+ }
356
+
357
+ div#sources-figures{
358
+ height:calc(100vh - 300px) !important;
359
+ max-height: 90vh !important;
360
+ overflow-y: scroll !important;
361
+ }
362
+
363
+ div#graphs-container{
364
+ height:calc(100vh - 300px) !important;
365
+ max-height: 90vh !important;
366
+ overflow-y: scroll !important;
367
+ }
368
+
369
+ div#tab-citations{
370
+ height:calc(100vh - 300px) !important;
371
+ max-height: 90vh !important;
372
+ overflow-y: scroll !important;
373
  }
374
 
375
  div#tab-config{
376
  height:calc(100vh - 190px) !important;
377
+ overflow-y: scroll !important;
378
+ /* overflow-y: auto !important; */
379
  }
380
 
381
+ /* Force container to respect height limits */
382
+ .main-component{
383
+ contain: size layout;
384
+ overflow: hidden;
385
  }
386
 
387
+
388
+ div#chatbot-row{
389
+ max-height:calc(100vh - 90px) !important;
390
  }
391
+ /*
392
+
393
 
394
  .max-height{
395
  height:calc(100vh - 90px) !important;
396
+ max-height:calc(100vh - 90px) !important;
397
  overflow-y: auto;
398
  }
399
+ */
400
 
 
 
 
 
 
401
  }
402
 
403
  footer {
 
440
  /* ... add other mobile-specific styles ... */
441
  }
442
 
443
+ @media (prefers-color-scheme: dark) {
444
+ .card{
445
+ background-color: #374151;
446
+ }
447
+ .card-image > .card-content{
448
+ background-color: rgb(55, 65, 81) !important;
449
+ }
450
 
451
+ .card-footer {
452
+ background-color: #404652;
453
+ }
 
 
 
 
 
 
 
 
454
 
455
+ .container > .wrap{
456
+ background-color: #374151 !important;
457
+ color:white !important;
458
+ }
459
+ .card-content h2{
460
+ color:#e7754f !important;
461
+ }
462
+ .doc-ref sup{
463
+ color:rgb(235 109 35)!important;
464
+ /* margin-right:1px; */
465
+ }
466
+ .card-footer span {
467
+ color:white !important;
468
+ }
469
+
470
  }
471
 
472
 
 
512
  }
513
 
514
  .card-image > .card-content{
515
+ background-color:#f1f7fa;
516
  }
517
 
518
 
 
538
  }
539
 
540
  #dropdown-samples{
541
+
 
542
  background:none !important;
543
 
544
  }
 
556
  .a-doc-ref{
557
  text-decoration: none !important;
558
  }
559
+
560
+
561
+ .dropdown {
562
+ position: relative;
563
+ display:inline-block;
564
+ margin-bottom: 10px;
565
+ }
566
+
567
+ .dropdown-toggle {
568
+ background-color: #f2f2f2;
569
+ color: black;
570
+ padding: 10px;
571
+ font-size: 16px;
572
+ cursor: pointer;
573
+ display: block;
574
+ width: 400px; /* Adjust width as needed */
575
+ position: relative;
576
+ display: flex;
577
+ align-items: center; /* Vertically center the contents */
578
+ justify-content: left;
579
+ }
580
+
581
+ .dropdown-toggle .caret {
582
+ content: "";
583
+ position: absolute;
584
+ right: 10px;
585
+ top: 50%;
586
+ border-left: 5px solid transparent;
587
+ border-right: 5px solid transparent;
588
+ border-top: 5px solid black;
589
+ transform: translateY(-50%);
590
+ }
591
+
592
+ input[type="checkbox"] {
593
+ display: none !important;
594
+ }
595
+
596
+ input[type="checkbox"]:checked + .dropdown-content {
597
+ display: block;
598
+ }
599
+
600
+ #checkbox-chat input[type="checkbox"] {
601
+ display: flex !important;
602
+ }
603
+
604
+ .dropdown-content {
605
+ display: none;
606
+ position: absolute;
607
+ background-color: #f9f9f9;
608
+ min-width: 300px;
609
+ box-shadow: 0 8px 16px 0 rgba(0,0,0,0.2);
610
+ z-index: 1;
611
+ padding: 12px;
612
+ border: 1px solid #ccc;
613
+ }
614
+
615
+ input[type="checkbox"]:checked + .dropdown-toggle + .dropdown-content {
616
+ display: block;
617
+ }
618
+
619
+ input[type="checkbox"]:checked + .dropdown-toggle .caret {
620
+ border-top: 0;
621
+ border-bottom: 5px solid black;
622
+ }
623
+
624
+ .loader {
625
+ border: 1px solid #d0d0d0 !important; /* Light grey background */
626
+ border-top: 1px solid #db3434 !important; /* Blue color */
627
+ border-right: 1px solid #3498db !important; /* Blue color */
628
+ border-radius: 50%;
629
+ width: 20px;
630
+ height: 20px;
631
+ animation: spin 2s linear infinite;
632
+ display:inline-block;
633
+ margin-right:10px !important;
634
+ }
635
+
636
+ .checkmark{
637
+ color:green !important;
638
+ font-size:18px;
639
+ margin-right:10px !important;
640
+ }
641
+
642
+ @keyframes spin {
643
+ 0% { transform: rotate(0deg); }
644
+ 100% { transform: rotate(360deg); }
645
+ }
646
+
647
+
648
+ .relevancy-score{
649
+ margin-top:10px !important;
650
+ font-size:10px !important;
651
+ font-style:italic;
652
+ }
653
+
654
+ .score-green{
655
+ color:green !important;
656
+ }
657
+
658
+ .score-orange{
659
+ color:orange !important;
660
+ }
661
+
662
+ .score-red{
663
+ color:red !important;
664
+ }
665
+
666
+ /* Mobile specific adjustments */
667
+ @media screen and (max-width: 767px) {
668
+ div#tab-recommended_content {
669
+ max-height: 50vh; /* Reduce height for smaller screens */
670
+ overflow-y: auto;
671
+ }
672
+ }
673
+
674
+ /* Additional style for scrollable tab content */
675
+ div#tab-saved-graphs {
676
+ overflow-y: auto; /* Enable vertical scrolling */
677
+ max-height: 80vh; /* Adjust height as needed */
678
+ }
679
+
680
+ /* Mobile specific adjustments */
681
+ @media screen and (max-width: 767px) {
682
+ div#tab-saved-graphs {
683
+ max-height: 50vh; /* Reduce height for smaller screens */
684
+ overflow-y: auto;
685
+ }
686
+ }
687
+ .message-buttons-left.panel.message-buttons.with-avatar {
688
+ display: none;
689
+ }
690
+
691
+
692
+ /* Specific fixes for Hugging Face Space iframe */
693
+ .h-full {
694
+ height: auto !important;
695
+ min-height: 0 !important;
696
+ }
697
+
698
+ .space-content {
699
+ height: auto !important;
700
+ max-height: 100vh !important;
701
+ overflow: hidden;
702
+ }
test.json ADDED
File without changes
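
The .relevancy-score, .score-green, .score-orange and .score-red rules added to style.css colour the reranker score displayed under each source card. A minimal sketch of the kind of helper they pair with when the card HTML is built — the function name and the thresholds are illustrative assumptions, not the exact logic in front/utils.py:

def make_relevancy_html(score: float) -> str:
    # Map a reranker relevancy score to one of the CSS classes defined above.
    if score > 0.85:
        color = "score-green"
    elif score > 0.5:
        color = "score-orange"
    else:
        color = "score-red"
    return f"<p class='relevancy-score {color}'>Relevancy: {round(score, 2)}</p>"

# e.g. appended to a source card: card_html += make_relevancy_html(0.91)
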