"""Per-session conversation-memory helpers and chain construction.

Memories live in ``st.session_state`` under the keys of the ``memories``
mapping passed to these helpers; each value is a dict of parameters that
includes at least ``'source'`` (and, for DB registration, ``'issue'``).
"""
import datetime as dt

import streamlit as st
from streamlit.logger import get_logger
import langchain
from langchain.memory import ConversationBufferMemory

from app_config import ENVIRON
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
from models.openai.role_models import get_role_chain, role_templates
from models.databricks.scenario_sim_biz import get_databricks_chain
from mongo_utils import new_convo

langchain.verbose = ENVIRON == "dev"
logger = get_logger(__name__)

# Model sources for which change_memories creates a memory and get_chain
# builds a chain. Keep the two in sync when adding a source.
_SUPPORTED_SOURCES = ("OA_rolemodel", "OA_finetuned", "CTL_llama2")

# Language name -> code passed to the Databricks (CTL_llama2) chain.
# Languages not listed are passed through unchanged.
_DATABRICKS_LANGUAGE_CODES = {"English": "en", "Spanish": "es"}


def add_initial_message(model_name, memory):
    """Seed ``memory`` with the texter's opening AI message.

    Args:
        model_name: Any string; if it contains "Spanish" the Spanish greeting
            is used, otherwise the English one. Within this module the
            conversation language is passed here.
        memory: Object exposing ``chat_memory.add_ai_message`` (such as the
            ``ConversationBufferMemory`` created by ``change_memories``).
    """
    if "Spanish" in model_name:
        memory.chat_memory.add_ai_message("Hola necesito ayuda")
    else:
        memory.chat_memory.add_ai_message("Hi I need help")


def push_convo2db(memories, username, language):
    """Register a new conversation through ``mongo_utils.new_convo``.

    Args:
        memories: Mapping of memory key -> params dict. A single-entry mapping
            must use the key ``'memory'``; otherwise ``'commonMemory'``
            (for the issue) and ``'memoryA'`` / ``'memoryB'`` (for the two
            model sources) are read and the conversation is flagged as a
            two-model comparison.
        username: Passed through to ``new_convo``.
        language: Passed through to ``new_convo``.
    """
    if len(memories) == 1:
        issue = memories['memory']['issue']
        model_one = memories['memory']['source']
        new_convo(st.session_state['db_client'], issue, language, username, False, model_one)
    else:
        issue = memories['commonMemory']['issue']
        model_one = memories['memoryA']['source']
        model_two = memories['memoryB']['source']
        new_convo(st.session_state['db_client'], issue, language, username, True, model_one, model_two)


def change_memories(memories, username, language, changed_source=False):
    """Create a fresh ``ConversationBufferMemory`` in session state per entry.

    A memory is (re)created when its key is not yet in ``st.session_state``
    or when ``changed_source`` is True. Only sources in ``_SUPPORTED_SOURCES``
    get a memory; others are logged and skipped, so callers that later index
    ``st.session_state[memory]`` will fail for them.

    When ``changed_source`` is True, any existing ``convo_id`` is removed so a
    new conversation id is assigned later.

    ``username`` and ``language`` are accepted for interface compatibility and
    are not used here.
    """
    for memory, params in memories.items():
        # Keep an existing memory unless the model source was switched.
        if memory in st.session_state and not changed_source:
            continue

        source = params['source']
        logger.info("Source for memory %s is %s", memory, source)
        if source in _SUPPORTED_SOURCES:
            st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
        else:
            logger.warning("Unsupported source %r for memory %s; no memory created", source, memory)

        if changed_source and "convo_id" in st.session_state:
            del st.session_state['convo_id']


def clear_memory(memories, username, language):
    """Clear every memory named in ``memories`` and drop the current ``convo_id``.

    Raises:
        KeyError: if a memory key is not present in ``st.session_state``.
    """
    for memory in memories:
        st.session_state[memory].clear()
    if "convo_id" in st.session_state:
        del st.session_state['convo_id']


def create_memory_add_initial_message(memories, username, language, changed_source=False):
    """Ensure each memory exists and starts with the texter's opening message.

    Memories are (re)created via ``change_memories``; any memory whose buffer
    is empty afterwards is seeded by ``add_initial_message`` using ``language``.
    """
    change_memories(memories, username, language, changed_source=changed_source)
    for memory in memories:
        if not st.session_state[memory].buffer_as_messages:
            add_initial_message(language, st.session_state[memory])


def get_chain(issue, language, source, memory, temperature):
    """Build the conversation chain for ``source``.

    Args:
        issue: Scenario identifier; combined with ``language`` as
            ``"{issue}-{language}"`` to look up OpenAI models/templates.
        language: Language name (e.g. "English", "Spanish").
        source: One of ``_SUPPORTED_SOURCES``; matched exactly.
        memory: Conversation memory handed to the chain builder.
        temperature: Sampling temperature handed to the chain builder.

    Returns:
        The chain from the matching builder, or ``None`` (logged as an error)
        when ``source`` is not supported.

    Raises:
        KeyError: if ``"{issue}-{language}"`` is missing from the OpenAI
            model/template tables.
    """
    # Exact comparison: the previous ``source in ("OA_finetuned")`` tested
    # against a plain string (no trailing comma), i.e. a substring match.
    if source == "OA_finetuned":
        oa_engine = finetuned_models[f"{issue}-{language}"]
        return get_finetuned_chain(oa_engine, memory, temperature)
    if source == "OA_rolemodel":
        template = role_templates[f"{issue}-{language}"]
        return get_role_chain(template, memory, temperature)
    if source == "CTL_llama2":
        language_code = _DATABRICKS_LANGUAGE_CODES.get(language, language)
        return get_databricks_chain(issue, language_code, memory, temperature)

    logger.error("Unsupported chain source %r (issue=%s, language=%s)", source, issue, language)
    return None