ivnban27-ctl commited on
Commit
5b74371
1 Parent(s): a59350a

conversation end

Browse files
app_config.py CHANGED
@@ -9,7 +9,7 @@ SOURCES = [
9
  'OA_rolemodel',
10
  # 'OA_finetuned',
11
  ]
12
- SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
13
  "OA_finetuned":'Finetuned OpenAI',
14
  # "CTL_llama2": "Llama 2",
15
  "CTL_llama3": "Llama 3",
@@ -29,10 +29,13 @@ def source2label(source):
29
  def issue2label(issue):
30
  return seed2str.get(issue, "GCT")
31
 
32
- ENVIRON = "dev"
33
 
34
  DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
35
  DB_CONVOS = 'conversations'
36
  DB_COMPLETIONS = 'comparison_completions'
37
  DB_BATTLES = 'battles'
38
- DB_ERRORS = 'completion_errors'
 
 
 
 
9
  'OA_rolemodel',
10
  # 'OA_finetuned',
11
  ]
12
+ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
13
  "OA_finetuned":'Finetuned OpenAI',
14
  # "CTL_llama2": "Llama 2",
15
  "CTL_llama3": "Llama 3",
 
29
  def issue2label(issue):
30
  return seed2str.get(issue, "GCT")
31
 
32
+ ENVIRON = "prod"
33
 
34
  DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
35
  DB_CONVOS = 'conversations'
36
  DB_COMPLETIONS = 'comparison_completions'
37
  DB_BATTLES = 'battles'
38
+ DB_ERRORS = 'completion_errors'
39
+
40
+ MAX_MSG_COUNT = 10
41
+ WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
convosim.py CHANGED
@@ -6,7 +6,7 @@ from utils.mongo_utils import get_db_client
6
  from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
- from app_config import ISSUES, SOURCES, source2label, issue2label
10
 
11
  logger = get_logger(__name__)
12
  openai_api_key = os.environ['OPENAI_API_KEY']
@@ -15,6 +15,8 @@ temperature = 0.8
15
 
16
  if "sent_messages" not in st.session_state:
17
  st.session_state['sent_messages'] = 0
 
 
18
  if "issue" not in st.session_state:
19
  st.session_state['issue'] = ISSUES[0]
20
  if 'previous_source' not in st.session_state:
@@ -57,6 +59,7 @@ if changed_source:
57
  st.session_state['previous_source'] = source
58
  st.session_state['issue'] = issue
59
  st.session_state['sent_messages'] = 0
 
60
  create_memory_add_initial_message(memories,
61
  issue,
62
  language,
@@ -69,12 +72,12 @@ memoryA = st.session_state[list(memories.keys())[0]]
69
  llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
70
 
71
  st.title("💬 Simulator")
72
-
73
  for msg in memoryA.buffer_as_messages:
74
  role = "user" if type(msg) == HumanMessage else "assistant"
75
  st.chat_message(role).write(msg.content)
76
 
77
- if prompt := st.chat_input():
78
  st.session_state['sent_messages'] += 1
79
  st.chat_message("user").write(prompt)
80
  if 'convo_id' not in st.session_state:
@@ -85,6 +88,12 @@ if prompt := st.chat_input():
85
  for response in responses:
86
  st.chat_message("assistant").write(response)
87
 
 
 
 
 
 
 
88
  with st.sidebar:
89
  st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
90
- st.markdown(f"### Total Messages: :red[**{len(memoryA.chat_memory.messages)}**]")
 
6
  from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
+ from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
10
 
11
  logger = get_logger(__name__)
12
  openai_api_key = os.environ['OPENAI_API_KEY']
 
15
 
16
  if "sent_messages" not in st.session_state:
17
  st.session_state['sent_messages'] = 0
18
+ if "total_messages" not in st.session_state:
19
+ st.session_state['total_messages'] = 0
20
  if "issue" not in st.session_state:
21
  st.session_state['issue'] = ISSUES[0]
22
  if 'previous_source' not in st.session_state:
 
59
  st.session_state['previous_source'] = source
60
  st.session_state['issue'] = issue
61
  st.session_state['sent_messages'] = 0
62
+ st.session_state['total_messages'] = 0
63
  create_memory_add_initial_message(memories,
64
  issue,
65
  language,
 
72
  llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
73
 
74
  st.title("💬 Simulator")
75
+ st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
76
  for msg in memoryA.buffer_as_messages:
77
  role = "user" if type(msg) == HumanMessage else "assistant"
78
  st.chat_message(role).write(msg.content)
79
 
80
+ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
81
  st.session_state['sent_messages'] += 1
82
  st.chat_message("user").write(prompt)
83
  if 'convo_id' not in st.session_state:
 
88
  for response in responses:
89
  st.chat_message("assistant").write(response)
90
 
91
+ st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
92
+ if st.session_state['total_messages'] >= MAX_MSG_COUNT:
93
+ st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
94
+ elif st.session_state['total_messages'] >= WARN_MSG_COUT:
95
+ st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
96
+
97
  with st.sidebar:
98
  st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
99
+ st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
models/databricks/texter_sim_llm.py CHANGED
@@ -24,8 +24,8 @@ def get_databricks_chain(source, issue, language, memory, temperature=0.8, texte
24
  )
25
 
26
  llm = CustomDatabricksLLM(
27
- endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
28
- # endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
29
  bearer_token=os.environ["DATABRICKS_TOKEN"],
30
  texter_name=texter_name,
31
  issue=issue,
 
24
  )
25
 
26
  llm = CustomDatabricksLLM(
27
+ # endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
28
+ endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
29
  bearer_token=os.environ["DATABRICKS_TOKEN"],
30
  texter_name=texter_name,
31
  issue=issue,
requirements.txt CHANGED
@@ -2,4 +2,5 @@ scipy==1.11.1
2
  langchain==0.3.0
3
  pymongo==4.5.0
4
  mlflow==2.9.0
5
- langchain-openai==0.2.0
 
 
2
  langchain==0.3.0
3
  pymongo==4.5.0
4
  mlflow==2.9.0
5
+ langchain-openai==0.2.0
6
+ streamlit==1.38.0