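# Spaces: Runtime error
# NOTE(review): the line above appears to be a status banner captured from the
# hosting page together with the file; it is not part of the program.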
import asyncio
import datetime
import hmac
import json
import logging
import os
from copy import deepcopy
from typing import List, Optional, Tuple
from zoneinfo import ZoneInfo

import gradio as gr
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage
# Process-wide log line format: timestamp, logger name, level, message.
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s:%(message)s')

# Application logger; emits INFO and above.
gradio_logger = logging.getLogger("gradio_app")
gradio_logger.setLevel(logging.INFO)

# The "openai" logger is switched to DEBUG.
logging.getLogger("openai").setLevel(logging.DEBUG)

# Token limit used both as the memory's max_token_limit and as the ceiling
# that prompts are pruned to before being sent.
GPT_3_5_CONTEXT_LENGTH = 4096
def make_template():
    """Build the chat prompt: system message, prior history, new user input.

    The system message embeds a fixed knowledge cutoff and the current date
    in the America/New_York timezone, evaluated at the moment this function
    is called.

    Returns:
        A ``ChatPromptTemplate`` made of the system message, a ``history``
        messages placeholder, and a human message filled from ``input``.
    """
    knowledge_cutoff = "September 2021"
    current_date = datetime.datetime.now(ZoneInfo("America/New_York")).strftime("%Y-%m-%d")
    system_msg = f"You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}"
    human_template = "{input}"
    return ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(system_msg),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template(human_template),
    ])
def reset_textbox():
    """Return a Gradio update that sets the bound component's value to ""."""
    return gr.update(value="")
def auth(username, password):
    """Return True when (username, password) matches an entry in ``creds``.

    ``creds`` is the module-level list of (username, password) pairs.

    Compared with a plain tuple-membership test this:
      * rejects non-string input, and skips configured pairs that contain
        ``None`` (e.g. unset environment variables), so a missing
        configuration can never authenticate anyone;
      * compares with ``hmac.compare_digest`` on UTF-8 bytes and does not
        stop at the first match, so timing does not reveal which part of a
        credential was wrong.

    Args:
        username: Username supplied by the login form.
        password: Password supplied by the login form.

    Returns:
        bool: whether the supplied pair is an allowed credential.
    """
    if not isinstance(username, str) or not isinstance(password, str):
        return False

    supplied_user = username.encode("utf-8")
    supplied_pass = password.encode("utf-8")

    authorized = False
    for known_user, known_pass in creds:
        if known_user is None or known_pass is None:
            # Unconfigured credential slot: never matches.
            continue
        # Evaluate both comparisons unconditionally (no short-circuit).
        user_ok = hmac.compare_digest(supplied_user, known_user.encode("utf-8"))
        pass_ok = hmac.compare_digest(supplied_pass, known_pass.encode("utf-8"))
        authorized |= user_ok & pass_ok
    return authorized
async def respond(
    inp: str,
    state: Optional[Tuple[List,
                          ConversationTokenBufferMemory,
                          ConversationChain]],
    request: gr.Request
):
    """Execute the chat functionality.

    Async generator used as a Gradio event handler. It prunes the prompt to
    ``GPT_3_5_CONTEXT_LENGTH`` tokens, runs the conversation chain with a
    streaming callback, and yields after every received token.

    Args:
        inp: The user's message.
        state: ``(history, memory, chain)`` from the previous turn, or
            ``None`` on the first turn (a fresh memory and chain are created).
        request: The Gradio request; only ``request.username`` is read.

    Yields:
        ``(history, (history, memory, chain))`` where ``history`` is the list
        of ``(user, bot)`` tuples with the last bot reply grown token by token.

    Raises:
        Exception: anything raised while preparing or running the chain is
            logged with its traceback and re-raised.
    """

    def prep_messages(user_msg: str, memory_buffer: List[BaseMessage]) -> Tuple[str, List[BaseMessage]]:
        """Trim the user message and/or history so the prompt fits the limit.

        Returns the (possibly truncated) user message and the (possibly
        shortened) list of history messages.
        """
        messages_to_send = template.format_messages(input=user_msg, history=memory_buffer)
        user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        _, encoding = llm._get_encoding_model()

        # Step 1: a user message that alone exceeds the limit is cut down to
        # the limit minus a 100-token margin.
        while user_msg_token_count > GPT_3_5_CONTEXT_LENGTH:
            gradio_logger.warning(f"Pruning user message due to user message token length of {user_msg_token_count}")
            user_msg = encoding.decode(llm.get_token_ids(user_msg)[:GPT_3_5_CONTEXT_LENGTH - 100])
            messages_to_send = template.format_messages(input=user_msg, history=memory_buffer)
            user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)

        # Step 2: drop the oldest history messages one at a time. The loop is
        # bounded by the history length; previously it kept spinning forever
        # once the history was empty and the prompt was still too long.
        while total_token_count > GPT_3_5_CONTEXT_LENGTH and memory_buffer:
            gradio_logger.warning(f"Pruning memory due to total token length of {total_token_count}")
            memory_buffer = memory_buffer[1:]
            messages_to_send = template.format_messages(input=user_msg, history=memory_buffer)
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)

        # Step 3: history is exhausted but system prompt + user message still
        # overflow; trim the user message once by the number of excess tokens.
        # This is a single pass, so it always terminates.
        if total_token_count > GPT_3_5_CONTEXT_LENGTH:
            gradio_logger.warning(f"Pruning user message due to total token length of {total_token_count}")
            overflow = total_token_count - GPT_3_5_CONTEXT_LENGTH
            token_ids = llm.get_token_ids(user_msg)
            keep = max(len(token_ids) - overflow, 0)
            user_msg = encoding.decode(token_ids[:keep])

        return user_msg, memory_buffer

    try:
        if state is None:
            memory = ConversationTokenBufferMemory(
                llm=llm,
                max_token_limit=GPT_3_5_CONTEXT_LENGTH,
                return_messages=True)
            chain = ConversationChain(memory=memory, prompt=template, llm=llm)
            state = ([], memory, chain)
        history, memory, chain = state
        gradio_logger.info(f"""[{request.username}] STARTING CHAIN""")
        gradio_logger.debug(f"History: {history}")
        gradio_logger.debug(f"User input: {inp}")

        # Write the pruned history back into the memory the chain will read.
        inp, memory.chat_memory.messages = prep_messages(inp, memory.buffer)
        messages_to_send = template.format_messages(input=inp, history=memory.buffer)
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        gradio_logger.debug(f"Messages to send: {messages_to_send}")
        gradio_logger.info(f"Tokens to send: {total_token_count}")

        # Run chain and append input.
        callback = AsyncIteratorCallbackHandler()
        run = asyncio.create_task(chain.apredict(
            input=inp, callbacks=[callback]))
        history.append((inp, ""))
        async for tok in callback.aiter():
            user, bot = history[-1]
            bot += tok
            history[-1] = (user, bot)
            yield history, (history, memory, chain)
        # Surface any exception raised inside the chain task.
        await run

        gradio_logger.info(f"""[{request.username}] ENDING CHAIN""")
        gradio_logger.debug(f"History: {history}")
        # Only serialize the memory when the debug line will actually be emitted.
        if gradio_logger.isEnabledFor(logging.DEBUG):
            gradio_logger.debug(f"Memory: {memory.json()}")

        # NOTE(review): the original literal ended with a trailing comma, so
        # flag_data is a one-element tuple wrapping the dict. That shape is
        # kept here, written explicitly - confirm it is what the flagger
        # expects for the single [chatbot] component it was set up with.
        data_to_flag = (
            {
                "history": deepcopy(history),
                "username": request.username
            },
        )
        gradio_logger.debug(f"Data to flag: {data_to_flag}")
        gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
    except Exception as e:
        gradio_logger.exception(e)
        # Bare raise keeps the original traceback intact.
        raise
# API key is read from the environment; None when the variable is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Shared chat model; streaming=True so tokens arrive through callbacks.
llm = ChatOpenAI(model_name="gpt-3.5-turbo",
                 temperature=1,
                 openai_api_key=OPENAI_API_KEY,
                 max_retries=6,
                 request_timeout=100,
                 streaming=True)

# Prompt template is built once here, so the date embedded in the system
# message is the date at which this module was loaded.
template = make_template()

theme = gr.themes.Soft()

# Single allowed login pair, taken from the USERNAME / PASSWORD environment
# variables (either element is None when its variable is unset).
creds = [(os.getenv("USERNAME"), os.getenv("PASSWORD"))]

gradio_flagger = gr.CSVLogger()
title = "Chat with ChatGPT"
# NOTE(review): the nesting below was reconstructed from syntax because the
# original indentation was lost; statement order is unchanged.
with gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
               theme=theme,
               analytics_enabled=False,
               title=title) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container"):
        # Per-session (history, memory, chain) tuple; starts as None.
        state = gr.State()
        chatbot = gr.Chatbot(label='ChatBot', elem_id="chatbot")
        inputs = gr.Textbox(placeholder="Send a message.",
                            label="Type an input and press Enter")
        b1 = gr.Button(value="Submit", variant="secondary").style(
            full_width=False)

    # Flagged conversations are written under "flagged_data_points".
    gradio_flagger.setup([chatbot], "flagged_data_points")

    # Enter in the textbox and the Submit button both run the chat handler...
    inputs.submit(respond, [inputs, state], [chatbot, state],)
    b1.click(respond, [inputs, state], [chatbot, state],)

    # ...and both clear the textbox.
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

# Queue requests, then launch behind the login check in auth().
demo.queue(
    max_size=99,
    concurrency_count=20,
    api_open=False).launch(
        debug=True,
        auth=auth)