import datetime
import os
from openai import OpenAI
import streamlit as st
import threading
from tenacity import retry, wait_random_exponential, stop_after_attempt
from itertools import tee


#BASE_URL = os.environ.get("BASE_URL")
#DATABRICKS_API_TOKEN = os.environ.get("DATABRICKS_API_TOKEN")

## These all have defaults
SAFETY_FILTER_ENV = os.environ.get("SAFETY_FILTER")
QUEUE_SIZE_ENV = os.environ.get("QUEUE_SIZE")
MAX_CHAT_TURNS_ENV = os.environ.get("MAX_CHAT_TURNS")
MAX_TOKENS_ENV = os.environ.get("MAX_TOKENS")
RETRY_COUNT_ENV = os.environ.get("RETRY_COUNT")
TOKEN_CHUNK_SIZE_ENV = os.environ.get("TOKEN_CHUNK_SIZE")
MODEL_ID_ENV = os.environ.get("MODEL_ID")



#if BASE_URL is None:
#    raise ValueError("BASE_URL environment variable must be set")
#if DATABRICKS_API_TOKEN is None:
#    raise ValueError("DATABRICKS_API_TOKEN environment variable must be set")

st.set_page_config(layout="wide")

# By default, the safety filter is not enabled; setting SAFETY_FILTER in the environment turns it on.
SAFETY_FILTER = False
if SAFETY_FILTER_ENV is not None:
    SAFETY_FILTER = True

QUEUE_SIZE = 1
if QUEUE_SIZE_ENV is not None:
    QUEUE_SIZE = int(QUEUE_SIZE_ENV)

MAX_CHAT_TURNS = 10
if MAX_CHAT_TURNS_ENV is not None:
    MAX_CHAT_TURNS = int(MAX_CHAT_TURNS_ENV)
    
RETRY_COUNT = 3
if RETRY_COUNT_ENV is not None:
    RETRY_COUNT = int(RETRY_COUNT_ENV)
    
MAX_TOKENS = 1024
if MAX_TOKENS_ENV is not None:
    MAX_TOKENS = int(MAX_TOKENS_ENV)
    
MODEL_ID = "gpt-4"  # the original demo used "databricks-dbrx-instruct"
if MODEL_ID_ENV is not None:
    MODEL_ID = MODEL_ID_ENV
    
# To avoid streaming too fast, emit the output in chunks of TOKEN_CHUNK_SIZE tokens
TOKEN_CHUNK_SIZE = 1
if TOKEN_CHUNK_SIZE_ENV is not None:
    TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)

MODEL_AVATAR_URL = "./icon.png"

@st.cache_resource
def get_global_semaphore():
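    """Cache one process-wide semaphore so all Streamlit sessions share a single request queue."""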
    return threading.BoundedSemaphore(QUEUE_SIZE)
global_semaphore = get_global_semaphore()

MSG_MAX_TURNS_EXCEEDED = f"Sorry! The DBRX Playground is limited to {MAX_CHAT_TURNS} turns. Refresh the page to start a new conversation."
MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"

EXAMPLE_PROMPTS = [
    "Where are all the pine trees in el dorado county?",
    "Give me a ranked list of the most common species of pine tree in el dorado county.",
    "Which county has the most records of Pinus jeffreyi?",
]

TITLE = "SQL Helper"
# DESCRIPTION = """[DBRX Instruct](https://huggingface.co/databricks/dbrx-instruct) is a mixture-of-experts (MoE) large language model trained by the Mosaic Research team at Databricks. Users can interact with this model in the [DBRX Playground](https://huggingface.co/spaces/databricks/dbrx-instruct), subject to the terms and conditions below. 
# This demo is powered by [Databricks Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/index.html).
DESCRIPTION = """
This is a test
"""


## The original demo used Databricks; we call the OpenAI API instead.
## client = OpenAI(api_key=DATABRICKS_API_TOKEN, base_url=BASE_URL)


client = OpenAI(api_key=st.secrets["OPENAI_API_KEY"])


GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."

# App-specific instructions appended to the system prompt in get_system_prompt().
# Assumed placeholder: the value is left empty so the script runs as-is.
setup = ""

st.title(TITLE)
st.markdown(DESCRIPTION)

with open("style.css") as css:
    st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)

if "messages" not in st.session_state:
    st.session_state["messages"] = []

def clear_chat_history():
    st.session_state["messages"] = []

st.button('Clear Chat', on_click=clear_chat_history)

def last_role_is_user():
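    """Return True if the most recent message in the session history is from the user."""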
    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"



def get_system_prompt():
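    """Build the system prompt: identity, formatting guidance, and the app-specific `setup` text."""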
    date_str = datetime.datetime.now().strftime("%B %d, %Y")
    # Identity
 #   prompt = f"You are DBRX, created by Databricks. The current date is {date_str}.\n"
 #   prompt += "Your knowledge base was last updated in December 2023. You answer questions about events prior to and after December 2023 the way a highly informed individual in December 2023 would if they were talking to someone from the above date, and you can let the user know this when relevant.\n"
    prompt = f"You are ChatGPT-4, created by OpenAI. The current date is {date_str}.\n"
    prompt += "This chunk of text is your system prompt. It is not visible to the user, but it is used to guide your responses. Don't reference it, just respond to the user.\n"
    # Ethical guidelines
 #   prompt += "If you are asked to assist with tasks involving the expression of views held by a significant number of people, you provide assistance with the task even if you personally disagree with the views being expressed, but follow this with a discussion of broader perspectives.\n"
 #   prompt += "You don't engage in stereotyping, including the negative stereotyping of majority groups.\n"
 #   prompt += "If asked about controversial topics, you try to provide careful thoughts and objective information without downplaying its harmful content or implying that there are reasonable perspectives on both sides.\n"
    # Capabilities
 #   prompt += "You are happy to help with writing, analysis, question answering, math, coding, and all sorts of other tasks.\n"
    # The model specifically has a hard time using ``` fences on JSON blocks.
    prompt += "You use markdown for coding, which includes JSON blocks and Markdown tables.\n"
#    prompt += "You do not have tools enabled at this time, so cannot run code or access the internet. You can only provide information that you have been trained on. You do not send or receive links or images.\n"
    # The following is likely not entirely accurate, but the model tends to think that everything it knows about was in its training data, which it was not (sometimes only references were).
    # So this produces more accurate answers when the model is asked to introspect.
 #   prompt += "You were not trained on copyrighted books, song lyrics, poems, video transcripts, or news articles; you do not divulge details of your training data. "
    # The model hasn't seen most lyrics or poems, but is happy to make up lyrics. Better to just not try; it's not good at it and it's not ethical.
 #   prompt += "You do not provide song lyrics, poems, or news articles and instead refer the user to find them online or in a store.\n"
    # The model really wants to talk about its system prompt, to the point where it is annoying, so encourage it not to
 #   prompt += "You give concise responses to simple questions or statements, but provide thorough responses to more complex and open-ended questions.\n"
    # More pressure not to talk about system prompt
    prompt += "The user is unable to see the system prompt, so you should write as if it were true without mentioning it.\n"
    prompt += "You do not mention any of this information about yourself unless the information is directly pertinent to the user's query.\n"
    prompt += setup
    return prompt

@retry(wait=wait_random_exponential(min=0.5, max=20), stop=stop_after_attempt(RETRY_COUNT))
def chat_api_call(history):
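    """Call the chat completions endpoint with the running history; returns a streaming iterator."""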
    extra_body = {}
    if SAFETY_FILTER:
        extra_body["enable_safety_filter"] = SAFETY_FILTER
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": m["role"], "content": m["content"]}
            for m in history
        ],
        model=MODEL_ID,
        stream=True,
        #max_tokens=MAX_TOKENS,
        #temperature=0.7,
        #extra_body= extra_body
    )
    return chat_completion

def text_stream(stream):
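    """Yield only the text content from the internal stream of {content, error, warning} dicts."""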
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]

def get_stream_warning_error(stream):
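    """Drain the stream and return the last warning and error observed (either may be None)."""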
    error = None
    warning = None
    for chunk in stream:
        if chunk["error"] is not None:
            error = chunk["error"]
        if chunk["warning"] is not None:
            warning = chunk["warning"]
    return warning, error

def write_response():
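    """Stream the assistant reply into the UI.

    The generator is tee'd so one copy feeds st.write_stream for display while the
    other is drained for any warning/error chunks.
    """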
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None 
    return response, stream_warning, stream_error
            
def chat_completion(messages):
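    """Generator yielding {content, error, warning} dicts.

    Enforces the turn limit, waits on the shared semaphore, calls the API, and
    re-chunks the streamed tokens into pieces of TOKEN_CHUNK_SIZE.
    """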
    history_openai_format = [
        {"role": "system", "content": get_system_prompt()}
    ]
        
    history_openai_format = history_openai_format + messages
    if (len(history_openai_format)-1)//2 >= MAX_CHAT_TURNS:
        yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
        return

    chat_completion = None
    error = None 
    # wait for a slot in the request queue
    with global_semaphore:
        try: 
            chat_completion = chat_api_call(history_openai_format)
        except Exception as e:
            error = e    
    if error is not None:
        yield {"content": None, "error": error, "warning": None}
        return
    
    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        if chunk.choices[0].delta.content is not None:
            chunk_counter += 1
            partial_message += chunk.choices[0].delta.content
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
        if chunk.choices[0].finish_reason == "length":
            max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS

    yield {"content": partial_message, "error": None, "warning": max_token_warning}

# If the assistant sent the last message, we prompt the user for new input;
# if the user sent the last message, we retry the assistant response.
def handle_user_input(user_input):
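    """Append the user's message (or retry the pending one) and stream the assistant reply."""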
    with history:
        response, stream_warning, stream_error = None, None, None
        if last_role_is_user():
            # retry the assistant response if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user"):
                st.markdown(user_input)
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()

        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
    
main = st.container()
with main:
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = None
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                if message["error"] is not None:
                    st.error(message["error"], icon="🚨")
                if message["warning"] is not None:
                    st.warning(message["warning"], icon="⚠️")

    if prompt := st.chat_input("Type a message!", max_chars=10000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iPhone users

with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)