lqhl commited on
Commit
0e573d0
Β·
verified Β·
1 Parent(s): ff7aa95

Synced repo using 'sync_with_huggingface' Github Action

Browse files
.streamlit/secrets.example.toml CHANGED
@@ -1,6 +1,9 @@
1
- MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud"
2
  MYSCALE_PORT = 443
3
  MYSCALE_USER = "chatdata"
4
  MYSCALE_PASSWORD = "myscale_rocks"
5
  OPENAI_API_BASE = "https://api.openai.com/v1"
6
  OPENAI_API_KEY = "<your-openai-key>"
 
 
 
 
1
+ MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud" # read-only database provided by MyScale
2
  MYSCALE_PORT = 443
3
  MYSCALE_USER = "chatdata"
4
  MYSCALE_PASSWORD = "myscale_rocks"
5
  OPENAI_API_BASE = "https://api.openai.com/v1"
6
  OPENAI_API_KEY = "<your-openai-key>"
7
+ UNSTRUCTURED_API = "<your-unstructured-io-api>" # optional if you don't upload documents
8
+ AUTH0_DOMAIN = "<your-auth0-domain>" # optional if you don't user management
9
+ AUTH0_CLIENT_ID = "<your-auth0-client-id>" # optiona
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import json
2
- import time
3
  import pandas as pd
4
  from os import environ
5
  import streamlit as st
@@ -13,10 +11,10 @@ from login import login, back_to_main
13
  from lib.helper import build_tools, build_all, sel_map, display
14
 
15
 
16
-
17
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
18
 
19
- st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
 
20
  st.markdown(
21
  f"""
22
  <style>
@@ -36,11 +34,12 @@ if login():
36
  if "user_name" in st.session_state:
37
  chat_page()
38
  elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
39
-
40
  sel = st.selectbox('Choose the knowledge base you want to ask with:',
41
- options=['ArXiv Papers', 'Wikipedia'])
42
  sel_map[sel]['hint']()
43
- tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
 
44
  with tab_sql:
45
  sel_map[sel]['hint_sql']()
46
  st.text_input("Ask a question:", key='query_sql')
@@ -85,7 +84,6 @@ if login():
85
  st.write('Oops 😡 Something bad happened...')
86
  raise e
87
 
88
-
89
  with tab_self_query:
90
  st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='πŸ’‘')
91
  st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
@@ -132,4 +130,4 @@ if login():
132
  docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
133
  except Exception as e:
134
  st.write('Oops 😡 Something bad happened...')
135
- raise e
 
 
 
1
  import pandas as pd
2
  from os import environ
3
  import streamlit as st
 
11
  from lib.helper import build_tools, build_all, sel_map, display
12
 
13
 
 
14
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
15
 
16
+ st.set_page_config(page_title="ChatData",
17
+ page_icon="https://myscale.com/favicon.ico")
18
  st.markdown(
19
  f"""
20
  <style>
 
34
  if "user_name" in st.session_state:
35
  chat_page()
36
  elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
37
+
38
  sel = st.selectbox('Choose the knowledge base you want to ask with:',
39
+ options=['ArXiv Papers', 'Wikipedia'])
40
  sel_map[sel]['hint']()
41
+ tab_sql, tab_self_query = st.tabs(
42
+ ['Vector SQL', 'Self-Query Retrievers'])
43
  with tab_sql:
44
  sel_map[sel]['hint_sql']()
45
  st.text_input("Ask a question:", key='query_sql')
 
84
  st.write('Oops 😡 Something bad happened...')
85
  raise e
86
 
 
87
  with tab_self_query:
88
  st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='πŸ’‘')
89
  st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
 
130
  docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
131
  except Exception as e:
132
  st.write('Oops 😡 Something bad happened...')
133
+ raise e
callbacks/arxiv_callbacks.py CHANGED
@@ -8,7 +8,6 @@ from langchain.callbacks.streamlit.streamlit_callback_handler import (
8
  StreamlitCallbackHandler,
9
  )
10
  from langchain.schema.output import LLMResult
11
- from streamlit.delta_generator import DeltaGenerator
12
 
13
 
14
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
@@ -26,7 +25,8 @@ class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
26
  self.progress_bar.progress(value=0.6, text="Searching in DB...")
27
  if "repr" in outputs:
28
  st.markdown("### Generated Filter")
29
- st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
 
30
 
31
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
32
  pass
@@ -88,7 +88,8 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
88
  st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
89
  print(f"Vector SQL: {text}")
90
  self.prog_value += self.prog_interval
91
- self.progress_bar.progress(value=self.prog_value, text="Searching in DB...")
 
92
 
93
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
94
  cid = ".".join(serialized["id"])
 
8
  StreamlitCallbackHandler,
9
  )
10
  from langchain.schema.output import LLMResult
 
11
 
12
 
13
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
 
25
  self.progress_bar.progress(value=0.6, text="Searching in DB...")
26
  if "repr" in outputs:
27
  st.markdown("### Generated Filter")
28
+ st.markdown(
29
+ f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
30
 
31
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
32
  pass
 
88
  st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
89
  print(f"Vector SQL: {text}")
90
  self.prog_value += self.prog_interval
91
+ self.progress_bar.progress(
92
+ value=self.prog_value, text="Searching in DB...")
93
 
94
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
95
  cid = ".".join(serialized["id"])
chains/arxiv_chains.py CHANGED
@@ -8,7 +8,6 @@ from langchain.callbacks.manager import (
8
  CallbackManagerForChainRun,
9
  )
10
  from langchain.embeddings.base import Embeddings
11
- from langchain.schema import BaseRetriever
12
  from langchain.callbacks.manager import Callbacks
13
  from langchain.schema.prompt_template import format_document
14
  from langchain.docstore.document import Document
@@ -20,11 +19,12 @@ from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
20
 
21
  logger = logging.getLogger()
22
 
 
23
  class MyScaleWithoutMetadataJson(MyScale):
24
  def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
25
  super().__init__(embedding, config, **kwargs)
26
  self.must_have_cols: List[str] = must_have_cols
27
-
28
  def _build_qstr(
29
  self, q_emb: List[float], topk: int, where_str: Optional[str] = None
30
  ) -> str:
@@ -43,7 +43,7 @@ class MyScaleWithoutMetadataJson(MyScale):
43
  LIMIT {topk}
44
  """
45
  return q_str
46
-
47
  def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
48
  q_str = self._build_qstr(embedding, k, where_str)
49
  try:
@@ -55,9 +55,11 @@ class MyScaleWithoutMetadataJson(MyScale):
55
  for r in self.client.query(q_str).named_results()
56
  ]
57
  except Exception as e:
58
- logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
 
59
  return []
60
 
 
61
  class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
62
  """Based on VectorSQLOutputParser
63
  It also modify the SQL to get all columns
@@ -73,9 +75,11 @@ class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
73
  start = text.upper().find("SELECT")
74
  if start >= 0:
75
  end = text.upper().find("FROM")
76
- text = text.replace(text[start + len("SELECT") + 1 : end - 1], ", ".join(self.must_have_columns))
 
77
  return super().parse(text)
78
 
 
79
  class ArXivStuffDocumentChain(StuffDocumentsChain):
80
  """Combine arxiv documents with PDF reference number"""
81
 
@@ -172,8 +176,7 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
172
  answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
173
  sources.append(d)
174
  ref_cnt += 1
175
-
176
-
177
  result: Dict[str, Any] = {
178
  self.answer_key: answer,
179
  self.sources_answer_key: sources,
@@ -191,4 +194,4 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
191
 
192
  @property
193
  def _chain_type(self) -> str:
194
- return "arxiv_qa_with_sources_chain"
 
8
  CallbackManagerForChainRun,
9
  )
10
  from langchain.embeddings.base import Embeddings
 
11
  from langchain.callbacks.manager import Callbacks
12
  from langchain.schema.prompt_template import format_document
13
  from langchain.docstore.document import Document
 
19
 
20
  logger = logging.getLogger()
21
 
22
+
23
  class MyScaleWithoutMetadataJson(MyScale):
24
  def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
25
  super().__init__(embedding, config, **kwargs)
26
  self.must_have_cols: List[str] = must_have_cols
27
+
28
  def _build_qstr(
29
  self, q_emb: List[float], topk: int, where_str: Optional[str] = None
30
  ) -> str:
 
43
  LIMIT {topk}
44
  """
45
  return q_str
46
+
47
  def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
48
  q_str = self._build_qstr(embedding, k, where_str)
49
  try:
 
55
  for r in self.client.query(q_str).named_results()
56
  ]
57
  except Exception as e:
58
+ logger.error(
59
+ f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
60
  return []
61
 
62
+
63
  class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
64
  """Based on VectorSQLOutputParser
65
  It also modify the SQL to get all columns
 
75
  start = text.upper().find("SELECT")
76
  if start >= 0:
77
  end = text.upper().find("FROM")
78
+ text = text.replace(
79
+ text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns))
80
  return super().parse(text)
81
 
82
+
83
  class ArXivStuffDocumentChain(StuffDocumentsChain):
84
  """Combine arxiv documents with PDF reference number"""
85
 
 
176
  answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
177
  sources.append(d)
178
  ref_cnt += 1
179
+
 
180
  result: Dict[str, Any] = {
181
  self.answer_key: answer,
182
  self.sources_answer_key: sources,
 
194
 
195
  @property
196
  def _chain_type(self) -> str:
197
+ return "arxiv_qa_with_sources_chain"
chat.py CHANGED
@@ -8,9 +8,6 @@ from lib.sessions import SessionManager
8
  from lib.private_kb import PrivateKnowledgeBase
9
  from langchain.schema import HumanMessage, FunctionMessage
10
  from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
11
- from langchain.callbacks.streamlit.streamlit_callback_handler import (
12
- StreamlitCallbackHandler,
13
- )
14
  from lib.json_conv import CustomJSONDecoder
15
 
16
  from lib.helper import (
@@ -313,7 +310,8 @@ def chat_page():
313
  key="b_tool_files",
314
  format_func=lambda x: x["file_name"],
315
  )
316
- st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
 
317
  st.text_input(
318
  "Tool Description",
319
  "Searches among user's private files and returns related documents",
@@ -359,14 +357,16 @@ def chat_page():
359
  )
360
  st.markdown("### Uploaded Files")
361
  st.dataframe(
362
- st.session_state.private_kb.list_files(st.session_state.user_name),
 
363
  use_container_width=True,
364
  )
365
  col_1, col_2 = st.columns(2)
366
  with col_1:
367
  st.button("Add Files", on_click=add_file)
368
  with col_2:
369
- st.button("Clear Files and All Tools", on_click=clear_files)
 
370
 
371
  st.button("Clear Chat History", on_click=clear_history)
372
  st.button("Logout", on_click=back_to_main)
 
8
  from lib.private_kb import PrivateKnowledgeBase
9
  from langchain.schema import HumanMessage, FunctionMessage
10
  from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
 
 
 
11
  from lib.json_conv import CustomJSONDecoder
12
 
13
  from lib.helper import (
 
310
  key="b_tool_files",
311
  format_func=lambda x: x["file_name"],
312
  )
313
+ st.text_input(
314
+ "Tool Name", "get_relevant_documents", key="b_tool_name")
315
  st.text_input(
316
  "Tool Description",
317
  "Searches among user's private files and returns related documents",
 
357
  )
358
  st.markdown("### Uploaded Files")
359
  st.dataframe(
360
+ st.session_state.private_kb.list_files(
361
+ st.session_state.user_name),
362
  use_container_width=True,
363
  )
364
  col_1, col_2 = st.columns(2)
365
  with col_1:
366
  st.button("Add Files", on_click=add_file)
367
  with col_2:
368
+ st.button("Clear Files and All Tools",
369
+ on_click=clear_files)
370
 
371
  st.button("Clear Chat History", on_click=clear_history)
372
  st.button("Logout", on_click=back_to_main)
lib/helper.py CHANGED
@@ -4,10 +4,8 @@ import time
4
  import hashlib
5
  from typing import Dict, Any, List, Tuple
6
  import re
7
- import pandas as pd
8
  from os import environ
9
  import streamlit as st
10
- import datetime
11
  from langchain.schema import BaseRetriever
12
  from langchain.tools import Tool
13
  from langchain.pydantic_v1 import BaseModel, Field
@@ -20,7 +18,7 @@ except ImportError:
20
  from sqlalchemy.ext.declarative import declarative_base
21
  from sqlalchemy.orm import sessionmaker
22
  from clickhouse_sqlalchemy import (
23
- Table, make_session, get_declarative_base, types, engines
24
  )
25
  from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
26
  from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
@@ -43,12 +41,12 @@ from langchain.prompts.prompt import PromptTemplate
43
  from langchain.prompts.chat import MessagesPlaceholder
44
  from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
45
  from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
46
- from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage,\
47
  SystemMessage, ChatMessage, ToolMessage
48
  from langchain.memory import SQLChatMessageHistory
49
  from langchain.memory.chat_message_histories.sql import \
50
- BaseMessageConverter, DefaultMessageConverter
51
- from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
52
  # from langchain.agents.agent_toolkits import create_retriever_tool
53
  from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
54
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
@@ -73,7 +71,7 @@ UNSTRUCTURED_API = st.secrets['UNSTRUCTURED_API']
73
 
74
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
75
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
76
- (HumanMessagePromptTemplate, '{question}')])
77
  DEFAULT_SYSTEM_PROMPT = (
78
  "Do your best to answer the questions. "
79
  "Feel free to use any tools available to look up "
@@ -81,6 +79,7 @@ DEFAULT_SYSTEM_PROMPT = (
81
  "when calling search functions."
82
  )
83
 
 
84
  def hint_arxiv():
85
  st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
86
  "For example: \n\n"
@@ -150,7 +149,8 @@ sel_map = {
150
  "hint": hint_wiki,
151
  "hint_sql": hint_sql_wiki,
152
  "doc_prompt": PromptTemplate(
153
- input_variables=["page_content", "url", "title", "ref_id", "views"],
 
154
  template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
155
  "metadata_cols": [
156
  AttributeInfo(
@@ -224,6 +224,7 @@ sel_map = {
224
  }
225
  }
226
 
 
227
  def build_embedding_model(_sel):
228
  """Build embedding model
229
  """
@@ -253,7 +254,8 @@ def build_chains_retrievers(_sel: str) -> Dict[str, Any]:
253
  "sql_retriever": sql_retriever,
254
  "sql_chain": sql_chain
255
  }
256
-
 
257
  def build_self_query(_sel: str) -> SelfQueryRetriever:
258
  """Build self querying retriever
259
 
@@ -278,18 +280,20 @@ def build_self_query(_sel: str) -> SelfQueryRetriever:
278
  "vector": sel_map[_sel]["vector_col"],
279
  "metadata": sel_map[_sel]["metadata_col"]
280
  })
281
- doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
282
  must_have_cols=sel_map[_sel]['must_have_cols'])
283
 
284
  with st.spinner(f"Building Self Query Retriever for {_sel}..."):
285
  metadata_field_info = sel_map[_sel]["metadata_cols"]
286
  retriever = SelfQueryRetriever.from_llm(
287
- OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
 
288
  doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
289
  use_original_query=False, structured_query_translator=MyScaleTranslator())
290
  return retriever
291
 
292
- def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
 
293
  """Build Vector SQL Database Retriever
294
 
295
  :param _sel: selected knowledge base
@@ -308,7 +312,8 @@ def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
308
  output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
309
  model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
310
  sql_query_chain = VectorSQLDatabaseChain.from_llm(
311
- llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
 
312
  prompt=PROMPT,
313
  top_k=10,
314
  return_direct=True,
@@ -319,8 +324,9 @@ def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
319
  sql_retriever = VectorSQLDatabaseChainRetriever(
320
  sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
321
  return sql_retriever
322
-
323
- def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query") -> ArXivQAwithSourcesChain:
 
324
  """_summary_
325
 
326
  :param _sel: selected knowledge base
@@ -350,6 +356,7 @@ def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query")
350
  )
351
  return chain
352
 
 
353
  @st.cache_resource
354
  def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
355
  """build all resources
@@ -365,6 +372,7 @@ def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
365
  sel_map_obj[k] = build_chains_retrievers(k)
366
  return sel_map_obj, embeddings
367
 
 
368
  def create_message_model(table_name, DynamicBase): # type: ignore
369
  """
370
  Create a message model for a given table name.
@@ -397,6 +405,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
397
 
398
  return Message
399
 
 
400
  def _message_from_dict(message: dict) -> BaseMessage:
401
  _type = message["type"]
402
  if _type == "human":
@@ -417,6 +426,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
417
  else:
418
  raise ValueError(f"Got unexpected message type: {_type}")
419
 
 
420
  class DefaultClickhouseMessageConverter(DefaultMessageConverter):
421
  """The default message converter for SQLChatMessageHistory."""
422
 
@@ -425,27 +435,28 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
425
 
426
  def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
427
  tstamp = time.time()
428
- msg_id = hashlib.sha256(f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
 
429
  user_id, _ = session_id.split("?")
430
  return self.model_class(
431
- id=tstamp,
432
  msg_id=msg_id,
433
  user_id=user_id,
434
- session_id=session_id,
435
  type=message.type,
436
  addtionals=json.dumps(message.additional_kwargs),
437
  message=json.dumps({
438
- "type": message.type,
439
  "additional_kwargs": {"timestamp": tstamp},
440
  "data": message.dict()})
441
  )
442
-
443
  def from_sql_model(self, sql_message: Any) -> BaseMessage:
444
  msg_dump = json.loads(sql_message.message)
445
  msg = _message_from_dict(msg_dump)
446
  msg.additional_kwargs = msg_dump["additional_kwargs"]
447
  return msg
448
-
449
  def get_sql_model_class(self) -> Any:
450
  return self.model_class
451
 
@@ -458,7 +469,7 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
458
  connection_string=f'{conn_str}/chat?protocol=https',
459
  custom_message_converter=DefaultClickhouseMessageConverter(name))
460
  memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
461
-
462
  _system_message = SystemMessage(
463
  content=system_prompt
464
  )
@@ -475,10 +486,12 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
475
  return_intermediate_steps=True,
476
  **kwargs
477
  )
478
-
 
479
  class RetrieverInput(BaseModel):
480
  query: str = Field(description="query to look up in retriever")
481
 
 
482
  def create_retriever_tool(
483
  retriever: BaseRetriever, name: str, description: str
484
  ) -> Tool:
@@ -499,7 +512,7 @@ def create_retriever_tool(
499
  docs: List[Document] = func(*args, **kwargs)
500
  return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
501
  return wrapped_retrieve
502
-
503
  return Tool(
504
  name=name,
505
  description=description,
@@ -507,7 +520,8 @@ def create_retriever_tool(
507
  coroutine=retriever.aget_relevant_documents,
508
  args_schema=RetrieverInput,
509
  )
510
-
 
511
  @st.cache_resource
512
  def build_tools():
513
  """build all resources
@@ -531,8 +545,9 @@ def build_tools():
531
  })
532
  return sel_map_obj
533
 
 
534
  def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
535
- chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
536
  openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
537
  )
538
  tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
@@ -543,7 +558,7 @@ def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temper
543
  chat_llm,
544
  tools=sel_tools,
545
  system_prompt=system_prompt
546
- )
547
  return agent
548
 
549
 
@@ -556,4 +571,4 @@ def display(dataframe, columns_=None, index=None):
556
  else:
557
  st.dataframe(dataframe)
558
  else:
559
- st.write("Sorry 😡 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
 
4
  import hashlib
5
  from typing import Dict, Any, List, Tuple
6
  import re
 
7
  from os import environ
8
  import streamlit as st
 
9
  from langchain.schema import BaseRetriever
10
  from langchain.tools import Tool
11
  from langchain.pydantic_v1 import BaseModel, Field
 
18
  from sqlalchemy.ext.declarative import declarative_base
19
  from sqlalchemy.orm import sessionmaker
20
  from clickhouse_sqlalchemy import (
21
+ types, engines
22
  )
23
  from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
24
  from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
 
41
  from langchain.prompts.chat import MessagesPlaceholder
42
  from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
43
  from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
44
+ from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage, \
45
  SystemMessage, ChatMessage, ToolMessage
46
  from langchain.memory import SQLChatMessageHistory
47
  from langchain.memory.chat_message_histories.sql import \
48
+ DefaultMessageConverter
49
+ from langchain.schema.messages import BaseMessage
50
  # from langchain.agents.agent_toolkits import create_retriever_tool
51
  from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
52
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
 
71
 
72
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
73
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
74
+ (HumanMessagePromptTemplate, '{question}')])
75
  DEFAULT_SYSTEM_PROMPT = (
76
  "Do your best to answer the questions. "
77
  "Feel free to use any tools available to look up "
 
79
  "when calling search functions."
80
  )
81
 
82
+
83
  def hint_arxiv():
84
  st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
85
  "For example: \n\n"
 
149
  "hint": hint_wiki,
150
  "hint_sql": hint_sql_wiki,
151
  "doc_prompt": PromptTemplate(
152
+ input_variables=["page_content",
153
+ "url", "title", "ref_id", "views"],
154
  template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
155
  "metadata_cols": [
156
  AttributeInfo(
 
224
  }
225
  }
226
 
227
+
228
  def build_embedding_model(_sel):
229
  """Build embedding model
230
  """
 
254
  "sql_retriever": sql_retriever,
255
  "sql_chain": sql_chain
256
  }
257
+
258
+
259
  def build_self_query(_sel: str) -> SelfQueryRetriever:
260
  """Build self querying retriever
261
 
 
280
  "vector": sel_map[_sel]["vector_col"],
281
  "metadata": sel_map[_sel]["metadata_col"]
282
  })
283
+ doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
284
  must_have_cols=sel_map[_sel]['must_have_cols'])
285
 
286
  with st.spinner(f"Building Self Query Retriever for {_sel}..."):
287
  metadata_field_info = sel_map[_sel]["metadata_cols"]
288
  retriever = SelfQueryRetriever.from_llm(
289
+ OpenAI(model_name=query_model_name,
290
+ openai_api_key=OPENAI_API_KEY, temperature=0),
291
  doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
292
  use_original_query=False, structured_query_translator=MyScaleTranslator())
293
  return retriever
294
 
295
+
296
+ def build_vector_sql(_sel: str) -> VectorSQLDatabaseChainRetriever:
297
  """Build Vector SQL Database Retriever
298
 
299
  :param _sel: selected knowledge base
 
312
  output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
313
  model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
314
  sql_query_chain = VectorSQLDatabaseChain.from_llm(
315
+ llm=OpenAI(model_name=query_model_name,
316
+ openai_api_key=OPENAI_API_KEY, temperature=0),
317
  prompt=PROMPT,
318
  top_k=10,
319
  return_direct=True,
 
324
  sql_retriever = VectorSQLDatabaseChainRetriever(
325
  sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
326
  return sql_retriever
327
+
328
+
329
+ def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str = "Self-query") -> ArXivQAwithSourcesChain:
330
  """_summary_
331
 
332
  :param _sel: selected knowledge base
 
356
  )
357
  return chain
358
 
359
+
360
  @st.cache_resource
361
  def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
362
  """build all resources
 
372
  sel_map_obj[k] = build_chains_retrievers(k)
373
  return sel_map_obj, embeddings
374
 
375
+
376
  def create_message_model(table_name, DynamicBase): # type: ignore
377
  """
378
  Create a message model for a given table name.
 
405
 
406
  return Message
407
 
408
+
409
  def _message_from_dict(message: dict) -> BaseMessage:
410
  _type = message["type"]
411
  if _type == "human":
 
426
  else:
427
  raise ValueError(f"Got unexpected message type: {_type}")
428
 
429
+
430
  class DefaultClickhouseMessageConverter(DefaultMessageConverter):
431
  """The default message converter for SQLChatMessageHistory."""
432
 
 
435
 
436
  def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
437
  tstamp = time.time()
438
+ msg_id = hashlib.sha256(
439
+ f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
440
  user_id, _ = session_id.split("?")
441
  return self.model_class(
442
+ id=tstamp,
443
  msg_id=msg_id,
444
  user_id=user_id,
445
+ session_id=session_id,
446
  type=message.type,
447
  addtionals=json.dumps(message.additional_kwargs),
448
  message=json.dumps({
449
+ "type": message.type,
450
  "additional_kwargs": {"timestamp": tstamp},
451
  "data": message.dict()})
452
  )
453
+
454
  def from_sql_model(self, sql_message: Any) -> BaseMessage:
455
  msg_dump = json.loads(sql_message.message)
456
  msg = _message_from_dict(msg_dump)
457
  msg.additional_kwargs = msg_dump["additional_kwargs"]
458
  return msg
459
+
460
  def get_sql_model_class(self) -> Any:
461
  return self.model_class
462
 
 
469
  connection_string=f'{conn_str}/chat?protocol=https',
470
  custom_message_converter=DefaultClickhouseMessageConverter(name))
471
  memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
472
+
473
  _system_message = SystemMessage(
474
  content=system_prompt
475
  )
 
486
  return_intermediate_steps=True,
487
  **kwargs
488
  )
489
+
490
+
491
  class RetrieverInput(BaseModel):
492
  query: str = Field(description="query to look up in retriever")
493
 
494
+
495
  def create_retriever_tool(
496
  retriever: BaseRetriever, name: str, description: str
497
  ) -> Tool:
 
512
  docs: List[Document] = func(*args, **kwargs)
513
  return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
514
  return wrapped_retrieve
515
+
516
  return Tool(
517
  name=name,
518
  description=description,
 
520
  coroutine=retriever.aget_relevant_documents,
521
  args_schema=RetrieverInput,
522
  )
523
+
524
+
525
  @st.cache_resource
526
  def build_tools():
527
  """build all resources
 
545
  })
546
  return sel_map_obj
547
 
548
+
549
  def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
550
+ chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
551
  openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
552
  )
553
  tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
 
558
  chat_llm,
559
  tools=sel_tools,
560
  system_prompt=system_prompt
561
+ )
562
  return agent
563
 
564
 
 
571
  else:
572
  st.dataframe(dataframe)
573
  else:
574
+ st.write("Sorry 😡 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
lib/json_conv.py CHANGED
@@ -1,15 +1,18 @@
1
  import json
2
  import datetime
3
 
 
4
  class CustomJSONEncoder(json.JSONEncoder):
5
  def default(self, obj):
6
  if isinstance(obj, datetime.datetime):
7
  return datetime.datetime.isoformat(obj)
8
  return json.JSONEncoder.default(self, obj)
9
 
 
10
  class CustomJSONDecoder(json.JSONDecoder):
11
  def __init__(self, *args, **kwargs):
12
- json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
 
13
 
14
  def object_hook(self, source):
15
  for k, v in source.items():
@@ -18,4 +21,4 @@ class CustomJSONDecoder(json.JSONDecoder):
18
  source[k] = datetime.datetime.fromisoformat(str(v))
19
  except:
20
  pass
21
- return source
 
1
  import json
2
  import datetime
3
 
4
+
5
  class CustomJSONEncoder(json.JSONEncoder):
6
  def default(self, obj):
7
  if isinstance(obj, datetime.datetime):
8
  return datetime.datetime.isoformat(obj)
9
  return json.JSONEncoder.default(self, obj)
10
 
11
+
12
  class CustomJSONDecoder(json.JSONDecoder):
13
  def __init__(self, *args, **kwargs):
14
+ json.JSONDecoder.__init__(
15
+ self, object_hook=self.object_hook, *args, **kwargs)
16
 
17
  def object_hook(self, source):
18
  for k, v in source.items():
 
21
  source[k] = datetime.datetime.fromisoformat(str(v))
22
  except:
23
  pass
24
+ return source
lib/private_kb.py CHANGED
@@ -52,7 +52,8 @@ def parse_files(api_key, user_id, files: List[UploadedFile]):
52
 
53
  def extract_embedding(embeddings: Embeddings, texts):
54
  if len(texts) > 0:
55
- embs = embeddings.embed_documents([t["text"] for _, t in enumerate(texts)])
 
56
  for i, _ in enumerate(texts):
57
  texts[i]["vector"] = embs[i]
58
  return texts
 
52
 
53
  def extract_embedding(embeddings: Embeddings, texts):
54
  if len(texts) > 0:
55
+ embs = embeddings.embed_documents(
56
+ [t["text"] for _, t in enumerate(texts)])
57
  for i, _ in enumerate(texts):
58
  texts[i]["vector"] = embs[i]
59
  return texts
lib/schemas.py CHANGED
@@ -49,4 +49,4 @@ def create_session_table(table_name, DynamicBase): # type: ignore
49
  order_by=('session_id')),
50
  {'comment': 'Store Session and Prompts'}
51
  )
52
- return Session
 
49
  order_by=('session_id')),
50
  {'comment': 'Store Session and Prompts'}
51
  )
52
+ return Session
lib/sessions.py CHANGED
@@ -6,9 +6,9 @@ except ImportError:
6
  from langchain.schema import BaseChatMessageHistory
7
  from datetime import datetime
8
  from sqlalchemy import Column, Text, orm, create_engine
9
- from clickhouse_sqlalchemy import types, engines
10
  from .schemas import create_message_model, create_session_table
11
 
 
12
  def get_sessions(engine, model_class, user_id):
13
  with orm.sessionmaker(engine)() as session:
14
  result = (
@@ -20,14 +20,17 @@ def get_sessions(engine, model_class, user_id):
20
  )
21
  return json.loads(result)
22
 
 
23
  class SessionManager:
24
  def __init__(self, session_state, host, port, username, password,
25
  db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
26
  conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
27
  self.engine = create_engine(conn_str, echo=False)
28
- self.sess_model_class = create_session_table(sess_table, declarative_base())
 
29
  self.sess_model_class.metadata.create_all(self.engine)
30
- self.msg_model_class = create_message_model(msg_table, declarative_base())
 
31
  self.msg_model_class.metadata.create_all(self.engine)
32
  self.Session = orm.sessionmaker(self.engine)
33
  self.session_state = session_state
@@ -46,14 +49,15 @@ class SessionManager:
46
  sessions.append({
47
  "session_id": r.session_id.split("?")[-1],
48
  "system_prompt": r.system_prompt,
49
- })
50
  return sessions
51
-
52
  def modify_system_prompt(self, session_id, sys_prompt):
53
  with self.Session() as session:
54
- session.update(self.sess_model_class).where(self.sess_model_class==session_id).value(system_prompt=sys_prompt)
 
55
  session.commit()
56
-
57
  def add_session(self, user_id, session_id, system_prompt, **kwargs):
58
  with self.Session() as session:
59
  elem = self.sess_model_class(
@@ -62,14 +66,13 @@ class SessionManager:
62
  )
63
  session.add(elem)
64
  session.commit()
65
-
66
  def remove_session(self, session_id):
67
  with self.Session() as session:
68
- session.query(self.sess_model_class).where(self.sess_model_class.session_id==session_id).delete()
 
69
  # session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
70
  if "agent" in self.session_state:
71
  self.session_state.agent.memory.chat_memory.clear()
72
  if "file_analyzer" in self.session_state:
73
  self.session_state.file_analyzer.clear_files()
74
-
75
-
 
6
  from langchain.schema import BaseChatMessageHistory
7
  from datetime import datetime
8
  from sqlalchemy import Column, Text, orm, create_engine
 
9
  from .schemas import create_message_model, create_session_table
10
 
11
+
12
  def get_sessions(engine, model_class, user_id):
13
  with orm.sessionmaker(engine)() as session:
14
  result = (
 
20
  )
21
  return json.loads(result)
22
 
23
+
24
  class SessionManager:
25
  def __init__(self, session_state, host, port, username, password,
26
  db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
27
  conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
28
  self.engine = create_engine(conn_str, echo=False)
29
+ self.sess_model_class = create_session_table(
30
+ sess_table, declarative_base())
31
  self.sess_model_class.metadata.create_all(self.engine)
32
+ self.msg_model_class = create_message_model(
33
+ msg_table, declarative_base())
34
  self.msg_model_class.metadata.create_all(self.engine)
35
  self.Session = orm.sessionmaker(self.engine)
36
  self.session_state = session_state
 
49
  sessions.append({
50
  "session_id": r.session_id.split("?")[-1],
51
  "system_prompt": r.system_prompt,
52
+ })
53
  return sessions
54
+
55
  def modify_system_prompt(self, session_id, sys_prompt):
56
  with self.Session() as session:
57
+ session.update(self.sess_model_class).where(
58
+ self.sess_model_class == session_id).value(system_prompt=sys_prompt)
59
  session.commit()
60
+
61
  def add_session(self, user_id, session_id, system_prompt, **kwargs):
62
  with self.Session() as session:
63
  elem = self.sess_model_class(
 
66
  )
67
  session.add(elem)
68
  session.commit()
69
+
70
  def remove_session(self, session_id):
71
  with self.Session() as session:
72
+ session.query(self.sess_model_class).where(
73
+ self.sess_model_class.session_id == session_id).delete()
74
  # session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
75
  if "agent" in self.session_state:
76
  self.session_state.agent.memory.chat_memory.clear()
77
  if "file_analyzer" in self.session_state:
78
  self.session_state.file_analyzer.clear_files()
 
 
login.py CHANGED
@@ -1,21 +1,21 @@
1
- import json
2
- import time
3
- import pandas as pd
4
- from os import environ
5
  import streamlit as st
6
  from auth0_component import login_button
7
 
8
  AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
9
  AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
10
 
 
11
  def login():
12
  if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
13
  return True
14
- st.subheader("πŸ€— Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! πŸ€— ")
 
15
  st.write("You can now chat with ArXiv and Wikipedia! 🌟\n")
16
  st.write("Built purely with streamlit πŸ‘‘ , LangChain πŸ¦œπŸ”— and love ❀️ for AI!")
17
- st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
18
- st.write("For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
 
 
19
  st.divider()
20
  col1, col2 = st.columns(2, gap='large')
21
  with col1.container():
@@ -33,7 +33,7 @@ def login():
33
  st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
34
  "- [Terms of Sevice](https://myscale.com/terms/)")
35
  if st.session_state.auth0 is not None:
36
- st.session_state.user_info = dict(st.session_state.auth0)
37
  if 'email' in st.session_state.user_info:
38
  email = st.session_state.user_info["email"]
39
  else:
@@ -44,6 +44,7 @@ def login():
44
  if st.session_state.jump_query_ask:
45
  st.experimental_rerun()
46
 
 
47
  def back_to_main():
48
  if "user_info" in st.session_state:
49
  del st.session_state.user_info
 
 
 
 
 
1
  import streamlit as st
2
  from auth0_component import login_button
3
 
4
  AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
5
  AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
6
 
7
+
8
  def login():
9
  if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
10
  return True
11
+ st.subheader(
12
+ "πŸ€— Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! πŸ€— ")
13
  st.write("You can now chat with ArXiv and Wikipedia! 🌟\n")
14
  st.write("Built purely with streamlit πŸ‘‘ , LangChain πŸ¦œπŸ”— and love ❀️ for AI!")
15
+ st.write(
16
+ "Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
17
+ st.write(
18
+ "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
19
  st.divider()
20
  col1, col2 = st.columns(2, gap='large')
21
  with col1.container():
 
33
  st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
34
  "- [Terms of Sevice](https://myscale.com/terms/)")
35
  if st.session_state.auth0 is not None:
36
+ st.session_state.user_info = dict(st.session_state.auth0)
37
  if 'email' in st.session_state.user_info:
38
  email = st.session_state.user_info["email"]
39
  else:
 
44
  if st.session_state.jump_query_ask:
45
  st.experimental_rerun()
46
 
47
+
48
  def back_to_main():
49
  if "user_info" in st.session_state:
50
  del st.session_state.user_info
prompts/arxiv_prompt.py CHANGED
@@ -6,7 +6,7 @@ combine_prompt_template = (
6
  + "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
7
  + "corresponding section name and page that you refer to when answering. The following is the related information "
8
  + "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
9
- + "Now you should anwser user's question. Remember you must use `Doc #` to refer papers:\n\n"
10
  )
11
 
12
  _myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
 
6
  + "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
7
  + "corresponding section name and page that you refer to when answering. The following is the related information "
8
  + "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
9
+ + "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
10
  )
11
 
12
  _myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.