import gradio as gr
from openai import OpenAI
import time
import html
def _to_openai_messages(message, history):
    """Build the OpenAI ``messages`` list from chat history plus the new user turn.

    ``history`` holds ``[user, assistant]`` pairs. A pair whose assistant text is
    ``None`` (a turn that never received a reply, e.g. because the request for it
    failed) contributes only its user message: the Chat Completions API rejects an
    assistant message whose content is null.
    """
    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages


def predict(
    message,
    history,
    character,
    api_key,
    progress=gr.Progress(),
    *,
    model="gpt-4",
    temperature=1.0,
):
    """Stream a reply from the OpenAI Chat Completions API.

    Sends *history* plus the new *message* and yields the assistant reply
    accumulated so far every time a streamed chunk adds text, so the caller can
    refresh the UI incrementally.

    Args:
        message: The new user message.
        history: Prior ``[user, assistant]`` pairs. Pairs whose assistant value
            is ``None`` are sent without an assistant message.
        character: Selected character name. Currently unused (see NOTE below).
        api_key: OpenAI API key used to create the client for this call.
        progress: Gradio progress tracker wrapped around the stream.
        model: Chat model name. The default is the previously hard-coded "gpt-4".
        temperature: Sampling temperature. The default is the previous 1.0.

    Yields:
        The reply text received so far (a growing string).

    Errors raised by the OpenAI client (e.g. an invalid key) propagate to the
    caller.
    """
    # NOTE(review): `character` is accepted, and the UI clears the history when it
    # changes, but it does not influence the prompt. Confirm whether a
    # per-character system message was intended.
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=_to_openai_messages(message, history),
        temperature=temperature,
        stream=True,
    )
    partial_message = ""
    for chunk in progress.tqdm(response, desc="Generating"):
        # Some stream chunks carry no choices (e.g. a trailing usage chunk or a
        # content-filter annotation). Skip them instead of raising IndexError.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            partial_message += delta
            yield partial_message
            # Brief pause between UI updates.
            time.sleep(0.01)
def format_history(history):
    """Render the chat history as HTML for the ``gr.HTML`` chat display.

    Each ``[user, assistant]`` pair becomes a user ``<div>``. When the assistant
    text is non-empty, an assistant ``<div>`` follows it; a pending reply
    (``None``) is omitted. Message text is HTML-escaped so user or model output
    cannot inject markup. Newlines become ``<br>`` because HTML collapses raw
    line breaks. The ``message`` / ``user-message`` / ``ai-message`` classes are
    the ones styled by the page stylesheet.

    Args:
        history: Sequence of ``[user_text, assistant_text_or_None]`` pairs.

    Returns:
        The concatenated HTML string, or ``""`` for an empty history.
    """

    def _to_html(text):
        # Escape first, then add <br>, so only our own tags survive.
        return html.escape(text).replace("\n", "<br>")

    parts = []
    for human, ai in history:
        parts.append(
            f'<div class="message user-message"><strong>You:</strong> {_to_html(human)}</div>'
        )
        if ai:
            parts.append(
                f'<div class="message ai-message"><strong>AI:</strong> {_to_html(ai)}</div>'
            )
    return "".join(parts)
# Stylesheet passed to gr.Blocks(css=...).
# - #chat-display (the gr.HTML chat panel) is a fixed 600px-high, vertically
#   scrollable, bordered box with a custom WebKit scrollbar.
# - .message caps each message at 300px high with its own scrollbar and wraps
#   long words.
# - .user-message / .ai-message give the two speakers different background
#   colours.
# NOTE(review): the .message/.user-message/.ai-message rules only apply if the
# chat HTML (presumably built by format_history) emits those classes. Confirm.
css = """
#chat-display {
height: 600px;
overflow-y: auto;
border: 1px solid #ccc;
padding: 10px;
margin-bottom: 10px;
}
#chat-display::-webkit-scrollbar {
width: 10px;
}
#chat-display::-webkit-scrollbar-track {
background: #f1f1f1;
}
#chat-display::-webkit-scrollbar-thumb {
background: #888;
}
#chat-display::-webkit-scrollbar-thumb:hover {
background: #555;
}
.message {
margin-bottom: 10px;
max-height: 300px;
overflow-y: auto;
word-wrap: break-word;
}
.user-message {
background-color: #e6f3ff;
padding: 5px;
border-radius: 5px;
}
.ai-message {
background-color: #f0f0f0;
padding: 5px;
border-radius: 5px;
}
"""
# Client-side behaviour, run once after the Gradio app mounts.
# Gradio evaluates the Blocks `js` string as `await (<js>)()`, so it must be a
# single function expression; a bare list of statements never parses. The
# components are rendered by the frontend, not present at DOMContentLoaded, so
# elements are looked up lazily and key handling uses event delegation.
js = """
() => {
    // ---- Chat auto-scroll ------------------------------------------------
    // Follow new output while the reader is at the bottom of #chat-display;
    // once they scroll up to read older messages, leave the position alone.
    let stickToBottom = true;
    const isAtBottom = (el) => el.scrollHeight - el.scrollTop - el.clientHeight < 5;

    // 'scroll' does not bubble, so listen in the capture phase.
    document.addEventListener('scroll', (e) => {
        const el = e.target;
        if (el && el.id === 'chat-display') {
            stickToBottom = isAtBottom(el);
        }
    }, true);

    setInterval(() => {
        const el = document.getElementById('chat-display');
        if (el && stickToBottom && !isAtBottom(el)) {
            el.scrollTop = el.scrollHeight;
        }
    }, 100);

    // ---- Keyboard shortcuts for the message box --------------------------
    // Ctrl+Enter sends (via the Send button); plain Enter inserts a newline;
    // Shift+Enter is left to the default handling. Running in the capture
    // phase and stopping propagation keeps Gradio's own Enter handling on the
    // textarea from also firing.
    document.addEventListener('keydown', (e) => {
        const box = e.target;
        if (e.key !== 'Enter' || !(box instanceof HTMLTextAreaElement) || !box.closest('#your_message')) {
            return;
        }
        if (e.ctrlKey) {
            e.preventDefault();
            e.stopPropagation();
            const send = document.querySelector('button#send_button, #send_button button');
            if (send) {
                send.click();
            }
        } else if (!e.shiftKey) {
            e.preventDefault();
            e.stopPropagation();
            box.setRangeText('\\n', box.selectionStart, box.selectionEnd, 'end');
            // Programmatic edits do not fire 'input'; dispatch it so Gradio's
            // bound value includes the newline before the next send.
            box.dispatchEvent(new Event('input', { bubbles: true }));
        }
    }, true);
}
"""

with gr.Blocks(css=css, js=js) as demo:
    gr.Markdown("# My Chatbot")
    chat_history = gr.State([])
    chat_display = gr.HTML(elem_id="chat-display")
    msg = gr.Textbox(
        label="Your message",
        lines=2,
        max_lines=10,
        placeholder="Type your message here... (Press Ctrl+Enter to send, Enter for new line)",
        elem_id="your_message",
    )
    # Target of the Ctrl+Enter shortcut in `js`; also clickable with the mouse.
    send = gr.Button("Send", elem_id="send_button")
    clear = gr.Button("Clear")
    dropdown = gr.Dropdown(
        [f"Character {i}" for i in range(1, 14)],
        label="Characters",
        info="Select the character that you'd like to speak to",
        value="Character 1",
    )
    api_key = gr.Textbox(type="password", label="OpenAI API Key")

    def user(user_message, history):
        """Record the new user turn with its reply pending, and clear the input box.

        Blank or whitespace-only input is ignored: the textbox text and the
        history are returned unchanged, so no empty request is sent.
        """
        if not user_message.strip():
            return user_message, history, format_history(history)
        history.append([user_message, None])
        return "", history, format_history(history)

    def bot(history, character, api_key):
        """Stream the assistant reply into the last (pending) history entry."""
        # Nothing to answer: `user` ignored blank input and the last turn
        # already has its reply (or there is no history at all).
        if not history or history[-1][1] is not None:
            yield history, format_history(history)
            return
        user_message = history[-1][0]
        for partial in predict(user_message, history[:-1], character, api_key):
            history[-1][1] = partial
            yield history, format_history(history)

    # Enter-submit on the textbox and the Send button run the same pipeline.
    for trigger in (msg.submit, send.click):
        trigger(user, [msg, chat_history], [msg, chat_history, chat_display]).then(
            bot, [chat_history, dropdown, api_key], [chat_history, chat_display]
        )
    # The HTML display takes a string, so reset it to "" rather than [].
    clear.click(lambda: ([], ""), None, [chat_history, chat_display], queue=False)
    # Switching character starts a fresh conversation.
    dropdown.change(lambda _: ([], ""), dropdown, [chat_history, chat_display])

demo.queue()
demo.launch(max_threads=20)