import logging

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from openai.error import AuthenticationError
import streamlit as st


def setup_memory():
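    """Create a windowed conversation memory backed by Streamlit session state.

    k=3 keeps only the last three exchanges, matching the note shown to
    users in the "How conversation history works" expander.
    """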
    msgs = StreamlitChatMessageHistory(key="basic_chat_app")
    memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
                                            chat_memory=msgs,
                                            return_messages=True)
    logging.info("setting up new chat memory")
    return memory


def use_existing_chain(model, provider, model_kwargs):
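    """Return True if the chain cached in st.session_state can be reused.

    Mistral chains are always rebuilt: the full history is re-rendered
    into the prompt on every turn, so there is nothing worth reusing.
    """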
    # TODO: consider whether prompt needs to be checked here
    if "mistral" in model:
        return False
    if "current_chain" in st.session_state:
        current_chain = st.session_state.current_chain
        if (current_chain.model == model) \
                and (current_chain.provider == provider) \
                and (current_chain.model_kwargs == model_kwargs):
            return True
    return False


class CurrentChain:
    def __init__(self, model, provider, prompt, memory, model_kwargs):
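        """Build an LLMChain for the given model/provider pair.

        model_kwargs is passed through whole to HuggingFaceHub; for
        OpenAI only the temperature is used.
        """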
        self.model = model
        self.provider = provider
        self.model_kwargs = model_kwargs

        logging.info(f"setting up new chain with params {model_name}, {provider}, {temp}")
        if provider == "OpenAI":
            llm = ChatOpenAI(model_name=model,
                             temperature=model_kwargs['temperature']
                             )
        elif provider == "HuggingFace":
            llm = HuggingFaceHub(repo_id=model,
                                 model_kwargs=model_kwargs
                                 )

        self.conversation = LLMChain(
            llm=llm,
            prompt=prompt,
            verbose=True,
            memory=memory
        )


def format_mistral_prompt(message, history):
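    """Render the chat history and new message in Mistral's instruct format.

    Example:
        >>> format_mistral_prompt("How are you?", [("Hi", "Hello!")])
        '<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]'
    """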
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
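    # Streamlit re-runs this script top to bottom on every interaction, so
    # state that must survive between turns (the memory, the chain, and the
    # chat transcript) is kept in st.session_state.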

    st.header("Basic chatbot")
    st.write("On small screens, click the `>` at top left to choose options")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable,"
                 " only the past three turns of conversation "
                 " are used for OpenAI models. Otherwise the entire chat history is used.")
        st.write("To clear all memory and start fresh, click 'Clear history'")
    st.sidebar.title("Choose options")

    #### USER INPUT ######
    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo (OpenAI)",
                 # "bigscience/bloom (HuggingFace)",  # runs
                 # "google/flan-t5-xxl (HuggingFace)",  # runs
                 "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
                 ],
        help="Which LLM to use",
    )

    temp = st.sidebar.slider(
        label="Temperature",
        min_value=0.0,
        max_value=2.0,
        step=0.1,
        value=0.4,
        help="Set the decoding temperature. "
             "Higher temps give more unpredictable outputs."
    )
    ##########################

    model = model_name.split("(")[0].rstrip()  # remove name of model provider
    provider = model_name.split("(")[-1].split(")")[0]  # text inside the parentheses

    model_kwargs = {"temperature": temp,
                    "max_new_tokens": 256,
                    "repetition_penalty": 1.0,
                    "top_p": 0.95,
                    "do_sample": True,
                    "seed": 42}
    # TODO: maybe expose more of these to the user
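    # Only `temperature` reaches ChatOpenAI; HuggingFaceHub receives the
    # whole dict as decoding parameters.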

    if "session_memory" not in st.session_state:
        st.session_state.session_memory = setup_memory()  # for openai

    if "history" not in st.session_state:
        st.session_state.history = []  # for mistral

    if "mistral" in model:
        prompt = PromptTemplate(input_variables=["input"],
                                template="{input}")
    else:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(
                    "You are a nice chatbot having a conversation with a human."
                ),
                # variable_name must match the memory_key set in setup_memory()
                MessagesPlaceholder(variable_name="chat_history"),
                HumanMessagePromptTemplate.from_template("{input}")
            ]
        )

    if use_existing_chain(model, provider, model_kwargs):
        chain = st.session_state.current_chain
    else:
        chain = CurrentChain(model,
                             provider,
                             prompt,
                             st.session_state.session_memory,
                             model_kwargs)
        st.session_state.current_chain = chain

    conversation = chain.conversation

    if st.button("Clear history"):
        conversation.memory.clear()  # for openai
        st.session_state.history = []  # for mistral
        logging.info("history cleared")

    # Replay the stored transcript so earlier turns stay visible across reruns.
    for user_msg, asst_msg in st.session_state.history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.write(asst_msg)

    text = st.chat_input()
    if text:
        with st.chat_message("user"):
            st.write(text)
            logging.info(text)
        try:
            if "mistral" in model:
                full_prompt = format_mistral_prompt(text, st.session_state.history)
                result = conversation.predict(input=full_prompt)
            else:
                result = conversation.predict(input=text)

            st.session_state.history.append((text, result))
            logging.info(repr(result))
            with st.chat_message("assistant"):
                st.write(result)
        except (AuthenticationError, ValueError):
            st.warning("Supply a valid API key", icon="⚠️")