File size: 2,688 Bytes
975a927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
import langchain
from langchain.memory import ConversationBufferMemory

from app_config import ENVIRON
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
from models.openai.role_models import get_role_chain, role_templates
from mongo_utils import new_convo

# Enable LangChain's global verbose flag only when running in the "dev" environment.
langchain.verbose = ENVIRON=="dev"
# Module-level logger obtained through Streamlit's logging helper.
logger = get_logger(__name__)

def add_initial_message(model_name, memory):
    """Seed ``memory`` with the texter's opening AI message.

    The greeting is Spanish when ``model_name`` contains the substring
    "Spanish", and English otherwise.

    Args:
        model_name: String checked for "Spanish"; callers in this module pass
            the conversation language here.
        memory: Object exposing ``chat_memory.add_ai_message`` (for example a
            LangChain ``ConversationBufferMemory``).
    """
    opening_line = "Hola necesito ayuda" if "Spanish" in model_name else "Hi I need help"
    memory.chat_memory.add_ai_message(opening_line)


def push_convo2db(memories, username, language):
    """Register a new conversation record through ``new_convo``.

    A ``memories`` mapping with exactly one entry is treated as a single-model
    conversation: the issue and source are read from its ``'memory'`` entry.
    Any other size is treated as a two-model comparison: the issue comes from
    ``'commonMemory'`` and the two sources from ``'memoryA'`` and ``'memoryB'``.

    Args:
        memories: Mapping of memory name -> params dict holding ``'issue'``
            and/or ``'source'`` keys.
        username: Passed through to ``new_convo``.
        language: Passed through to ``new_convo``.
    """
    is_single = len(memories) == 1
    if is_single:
        convo_issue = memories['memory']['issue']
        model_sources = (memories['memory']['source'],)
    else:
        convo_issue = memories['commonMemory']['issue']
        model_sources = (
            memories['memoryA']['source'],
            memories['memoryB']['source'],
        )
    # The fifth argument flags a two-model comparison conversation.
    new_convo(st.session_state['db_client'], convo_issue, language, username, not is_single, *model_sources)

def change_memories(memories, username, language, changed_source=False):
    """Create a fresh conversation memory for each slot that needs one.

    A slot is (re)built when it is missing from ``st.session_state`` or when
    ``changed_source`` is true. Only the sources ``'OA_rolemodel'`` and
    ``'OA_finetuned'`` get a ``ConversationBufferMemory``. Slots with any other
    source are logged but left as they are. When the source changed, any
    stored ``'convo_id'`` is removed from the session state.

    Args:
        memories: Mapping of session-state key -> params dict with a
            ``'source'`` entry.
        username: Unused here; kept for a uniform helper signature.
        language: Unused here; kept for a uniform helper signature.
        changed_source: Force every slot to be rebuilt.
    """
    buffered_sources = ('OA_rolemodel', 'OA_finetuned')
    for memory, params in memories.items():
        if not ((memory not in st.session_state) or changed_source):
            continue
        source = params['source']
        logger.info(f"Source for memory {memory} is {source}")
        if source in buffered_sources:
            st.session_state[memory] = ConversationBufferMemory(
                ai_prefix='texter',
                human_prefix='helper',
            )

    if ("convo_id" in st.session_state) and changed_source:
        del st.session_state['convo_id']
    

def clear_memory(memories, username, language):
    """Clear every memory named in ``memories`` and forget the conversation id.

    Args:
        memories: Mapping whose keys are ``st.session_state`` entries holding
            memory objects with a ``clear()`` method.
        username: Unused here; kept for a uniform helper signature.
        language: Unused here; kept for a uniform helper signature.
    """
    for session_key in memories:
        st.session_state[session_key].clear()

    # Drop the stored id so the cleared chat is not tied to the old record.
    if "convo_id" in st.session_state:
        del st.session_state['convo_id']
    

def create_memory_add_initial_message(memories, username, language, changed_source=False):
    """Ensure each memory exists and starts with the texter's opening line.

    Creation and reset are delegated to ``change_memories``. Any memory whose
    buffer is still empty afterwards is then seeded via ``add_initial_message``,
    with ``language`` selecting the greeting.

    Args:
        memories: Mapping of session-state key -> params dict.
        username: Forwarded to ``change_memories``.
        language: Forwarded to ``change_memories`` and used to pick the greeting.
        changed_source: Forwarded to ``change_memories`` to force a rebuild.
    """
    change_memories(memories, username, language, changed_source=changed_source)
    for session_key in memories:
        conversation = st.session_state[session_key]
        if not conversation.buffer_as_messages:
            add_initial_message(language, conversation)


def get_chain(issue, language, source, memory, temperature):
    """Return the LLM chain for an issue/language/source combination.

    Args:
        issue: Issue name; combined with ``language`` as ``"{issue}-{language}"``
            to select the fine-tuned model or role template.
        language: Conversation language (second half of the lookup key).
        source: Model source; must be exactly ``"OA_finetuned"`` or
            ``"OA_rolemodel"``.
        memory: Conversation memory passed through to the chain builder.
        temperature: Sampling temperature passed through to the chain builder.

    Returns:
        The chain built by ``get_finetuned_chain`` or ``get_role_chain``, or
        ``None`` when ``source`` is not one of the two supported values.

    Raises:
        Presumably ``KeyError`` when ``"{issue}-{language}"`` is missing from the
        selected registry (assumes ``finetuned_models`` / ``role_templates``
        are dicts; confirm).
    """
    lookup_key = f"{issue}-{language}"
    # Compare with ``==``: ``source in ("OA_finetuned")`` would be a substring
    # test on a plain string (the parentheses do not make a tuple), wrongly
    # accepting values such as "OA", "finetuned" or "".
    if source == "OA_finetuned":
        OA_engine = finetuned_models[lookup_key]
        return get_finetuned_chain(OA_engine, memory, temperature)
    if source == "OA_rolemodel":
        template = role_templates[lookup_key]
        return get_role_chain(template, memory, temperature)
    # NOTE(review): unsupported sources return None, as before; callers might
    # prefer an explicit error here.
    return None