"""Gradio chat app that streams gpt-3.5-turbo completions through LangChain."""
import asyncio
import datetime
import logging
import os
from copy import deepcopy
from typing import List, Optional, Tuple
from zoneinfo import ZoneInfo

import gradio as gr
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage

logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
gradio_logger = logging.getLogger("gradio_app")
gradio_logger.setLevel(logging.INFO)
logging.getLogger("openai").setLevel(logging.DEBUG)

GPT_3_5_CONTEXT_LENGTH = 4096


def make_template() -> ChatPromptTemplate:
    """Build the chat prompt: system message, conversation history, user input."""
    knowledge_cutoff = "September 2021"
    current_date = datetime.datetime.now(
        ZoneInfo("America/New_York")
    ).strftime("%Y-%m-%d")
    system_msg = (
        "You are ChatGPT, a large language model trained by OpenAI. "
        "Answer as concisely as possible. "
        f"Knowledge cutoff: {knowledge_cutoff} "
        f"Current date: {current_date}"
    )
    human_template = "{input}"
    return ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_msg),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template(human_template),
        ]
    )


def reset_textbox():
    return gr.update(value="")


def auth(username, password):
    return (username, password) in creds


async def respond(
    inp: str,
    state: Optional[Tuple[List, ConversationTokenBufferMemory, ConversationChain]],
    request: gr.Request,
):
    """Execute the chat functionality."""

    def prep_messages(
        user_msg: str, memory_buffer: List[BaseMessage]
    ) -> Tuple[str, List[BaseMessage]]:
        """Prune the user message and history until the prompt fits the context window."""
        messages_to_send = template.format_messages(
            input=user_msg, history=memory_buffer
        )
        user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        _, encoding = llm._get_encoding_model()
        # First, truncate the user message itself if it alone exceeds the window,
        # leaving a 100-token margin for the rest of the prompt.
        while user_msg_token_count > GPT_3_5_CONTEXT_LENGTH:
            gradio_logger.warning(
                f"Pruning user message due to user message token length of {user_msg_token_count}"
            )
            user_msg = encoding.decode(
                llm.get_token_ids(user_msg)[: GPT_3_5_CONTEXT_LENGTH - 100]
            )
            messages_to_send = template.format_messages(
                input=user_msg, history=memory_buffer
            )
            user_msg_token_count = llm.get_num_tokens_from_messages(
                [messages_to_send[-1]]
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        # Then drop history messages, oldest first, until the full prompt fits.
        while total_token_count > GPT_3_5_CONTEXT_LENGTH:
            gradio_logger.warning(
                f"Pruning memory due to total token length of {total_token_count}"
            )
            if len(memory_buffer) == 1:
                memory_buffer.pop(0)
                continue
            memory_buffer = memory_buffer[1:]
            messages_to_send = template.format_messages(
                input=user_msg, history=memory_buffer
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        return user_msg, memory_buffer

    try:
        # First turn of a session: create fresh memory and chain.
        if state is None:
            memory = ConversationTokenBufferMemory(
                llm=llm, max_token_limit=GPT_3_5_CONTEXT_LENGTH, return_messages=True
            )
            chain = ConversationChain(memory=memory, prompt=template, llm=llm)
            state = ([], memory, chain)
        history, memory, chain = state
        gradio_logger.info(f"[{request.username}] STARTING CHAIN")
        gradio_logger.debug(f"History: {history}")
        gradio_logger.debug(f"User input: {inp}")
        inp, memory.chat_memory.messages = prep_messages(inp, memory.buffer)
        messages_to_send = template.format_messages(input=inp, history=memory.buffer)
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        gradio_logger.debug(f"Messages to send: {messages_to_send}")
        gradio_logger.info(f"Tokens to send: {total_token_count}")
        # Run the chain in the background and stream tokens into the last
        # history entry as they arrive, yielding partial results to the UI.
        callback = AsyncIteratorCallbackHandler()
        run = asyncio.create_task(chain.apredict(input=inp, callbacks=[callback]))
        history.append((inp, ""))
        async for tok in callback.aiter():
            user, bot = history[-1]
            bot += tok
            history[-1] = (user, bot)
            yield history, (history, memory, chain)
        await run
        gradio_logger.info(f"[{request.username}] ENDING CHAIN")
        gradio_logger.debug(f"History: {history}")
        gradio_logger.debug(f"Memory: {memory.json()}")
        # CSVLogger.flag expects one value per component registered in setup();
        # a single chatbot component is registered, so pass a one-element tuple.
        data_to_flag = (
            {
                "history": deepcopy(history),
                "username": request.username,
            },
        )
        gradio_logger.debug(f"Data to flag: {data_to_flag}")
        gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
    except Exception as e:
        gradio_logger.exception(e)
        raise e


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=1,
    openai_api_key=OPENAI_API_KEY,
    max_retries=6,
    request_timeout=100,
    streaming=True,
)
template = make_template()

theme = gr.themes.Soft()
creds = [(os.getenv("USERNAME"), os.getenv("PASSWORD"))]
gradio_flagger = gr.CSVLogger()
title = "Chat with ChatGPT"

with gr.Blocks(
    css="""#col_container { margin-left: auto; margin-right: auto;}
           #chatbot {height: 520px; overflow: auto;}""",
    theme=theme,
    analytics_enabled=False,
    title=title,
) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container"):
        state = gr.State()
        chatbot = gr.Chatbot(label="ChatBot", elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Send a message.", label="Type an input and press Enter"
        )
        b1 = gr.Button(value="Submit", variant="secondary").style(full_width=False)

    gradio_flagger.setup([chatbot], "flagged_data_points")

    inputs.submit(respond, [inputs, state], [chatbot, state])
    b1.click(respond, [inputs, state], [chatbot, state])
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

demo.queue(max_size=99, concurrency_count=20, api_open=False).launch(
    debug=True, auth=auth
)
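
# --- Usage sketch (assumptions: the file name app.py is hypothetical; the
# env var names are the ones read via os.getenv() above) ---
#
#   export OPENAI_API_KEY="sk-..."   # key passed to ChatOpenAI
#   export USERNAME="demo"           # the single login accepted by auth()
#   export PASSWORD="demo"
#   python app.py                    # serves the Gradio UI (port 7860 by default)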