Spaces:
Runtime error
Runtime error
import logging | |
from pathlib import Path | |
from typing import List, Optional, Tuple, Dict | |
import json | |
from dotenv import load_dotenv | |
load_dotenv() | |
from queue import Empty, Queue | |
from threading import Thread | |
import os | |
import gradio as gr | |
from langchain.chat_models import ChatOpenAI | |
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate | |
from langchain.schema import AIMessage, BaseMessage, HumanMessage | |
from js import get_window_url_params | |
from callback import QueueCallback | |
from fastapi import FastAPI, File, UploadFile, Request | |
from fastapi.responses import HTMLResponse, RedirectResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from db import ( | |
User, | |
Chat, | |
create_user, | |
get_client, | |
get_user_by_username, | |
add_chat_by_uid, | |
) | |
MODELS_NAMES = ["gpt-3.5-turbo", "gpt-4"]
DEFAULT_TEMPERATURE = 0.7
# History shown in the gr.Chatbot: one (user_message, bot_reply) pair per turn.
# message_handler appends and updates exactly such tuples.
ChatHistory = List[Tuple[str, str]]
logging.basicConfig(
    format="[%(asctime)s %(levelname)s]: %(message)s", level=logging.INFO
)
# load redis client
client = get_client()
# load up our system prompt (read as UTF-8 explicitly instead of relying on
# the platform's locale encoding)
system_message_prompt = SystemMessagePromptTemplate.from_template(
    Path("prompts/system.prompt").read_text(encoding="utf-8")
)
# for the human, we will just inject the text
human_message_prompt_template = HumanMessagePromptTemplate.from_template("{text}")
# Patient records for the role-play: a list of dicts, each with at least a
# "name" key (used for the dropdown). The misspelled name `patiens` is kept
# because the rest of the module refers to it.
with open("data/patients.json", encoding="utf-8") as f:
    patiens = json.load(f)
patients_names = [el["name"] for el in patiens]
def message_handler(
    chat: Optional[ChatOpenAI],
    message: str,
    chatbot_messages: ChatHistory,
    messages: List[BaseMessage],
) -> Tuple[ChatOpenAI, str, ChatHistory, List[BaseMessage]]:
    """Stream the model's reply to ``message`` into the chat UI.

    This is a generator: every yield re-renders the outputs, so the answer
    appears token by token. The model runs in a worker thread and pushes
    tokens into the queue owned by the chat's first callback.

    Args:
        chat: the per-session ChatOpenAI, or None on the first turn (a default
            streaming chat is then created).
        message: the doctor's text from the input box.
        chatbot_messages: (user, reply) pairs shown in the Chatbot; mutated.
        messages: LangChain message history sent to the model; mutated.

    Yields:
        ``(chat, "", chatbot_messages, messages)`` -- the empty string clears
        the input textbox. The final yield carries the completed history.

    Raises:
        Whatever exception the model call raised in the worker thread is
        re-raised here once the stream has finished.
    """
    if chat is None:
        # in the queue we will store our streamed tokens
        queue = Queue()
        # let's create our default chat
        chat = ChatOpenAI(
            model_name=MODELS_NAMES[0],
            temperature=DEFAULT_TEMPERATURE,
            streaming=True,
            callbacks=[QueueCallback(queue)],
        )
    else:
        # hacky way to get the queue back: it lives on the chat's first callback
        queue = chat.callbacks[0].queue
    job_done = object()
    errors: List[BaseException] = []  # failure raised inside the worker, if any
    logging.info("asking question to GPT")
    messages.append(HumanMessage(content=f"Doctor:{message}"))
    chatbot_messages.append((message, ""))

    def task():
        try:
            chat(messages)
        except Exception as err:  # hand the failure over to the consumer loop
            errors.append(err)
        finally:
            # Always signal completion; otherwise the loop below never ends.
            queue.put(job_done)

    worker = Thread(target=task)
    worker.start()
    # this will hold the content as we generate
    content = ""
    while True:
        try:
            next_token = queue.get(True, timeout=1)
        except Empty:
            continue
        if next_token is job_done:
            break
        content += next_token
        chatbot_messages[-1] = (message, content)
        yield chat, "", chatbot_messages, messages
    worker.join()
    if errors:
        raise errors[0]
    # finally we can add our reply to messages
    messages.append(AIMessage(content=content))
    logging.debug(f"reply = {content}")
    logging.info("Done!")
    # A generator's return value is discarded by Gradio, so publish the final
    # state (including the appended AIMessage) with one last yield.
    yield chat, "", chatbot_messages, messages
def on_clear_click() -> Tuple[str, List, List]:
    """Return cleared values for (message textbox, chatbot history, LLM messages).

    New list objects are built on every call, so callers never share them.
    """
    blank_textbox = ""
    no_chat_history: List = []
    no_llm_messages: List = []
    return blank_textbox, no_chat_history, no_llm_messages
def on_done_click(
    chatbot_messages: ChatHistory, patient: Dict, user: User
) -> Tuple[str, List, List]:
    """Save the finished conversation and reset the chat for the same patient.

    Args:
        chatbot_messages: the (doctor, reply) pairs shown in the Chatbot.
        patient: the currently selected patient record (an entry loaded from
            ``data/patients.json``).
        user: the current user; its ``"uid"`` is passed to ``add_chat_by_uid``.

    Returns:
        ``(textbox, chatbot, messages)``: an empty textbox, an empty chat
        history and an LLM history seeded with the patient's system prompt.
    """
    logging.info(f"Saving chat for user={user}")
    add_chat_by_uid(
        client, Chat(patient=patient, messages=chatbot_messages), user["uid"]
    )
    cleared_message, cleared_chatbot, _ = on_clear_click()
    # Re-seed the system prompt: an empty history would send the next
    # conversation to the model without the patient persona.
    fresh_messages = [system_message_prompt.format(patient=patient)]
    return cleared_message, cleared_chatbot, fresh_messages
def on_apply_settings_click(model_name: str, temperature: float):
    """Create a fresh streaming chat with the chosen settings and clear the UI.

    Args:
        model_name: chat model name as selected in the settings dropdown.
        temperature: sampling temperature from the settings slider.

    Returns:
        ``(chat, "", [], [])`` -- the new chat followed by ``on_clear_click()``'s
        cleared textbox, chatbot history and LLM message history.
    """
    logging.info(
        f"Applying settings: model_name={model_name}, temperature={temperature}"
    )
    # The new chat gets its own brand-new Queue, so there are no stale tokens
    # from a previous chat to drain here. (Queue.empty() would only report
    # emptiness; it never clears anything.)
    chat = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        streaming=True,
        callbacks=[QueueCallback(Queue())],
    )
    return chat, *on_clear_click()
def on_drop_down_change(selected_item, messages):
    """Switch the simulated patient and restart the conversation.

    Args:
        selected_item: the patient name picked in the dropdown.
        messages: the current LLM history; unused, because switching patient
            always starts a new history (kept for the event's input wiring).

    Returns:
        ``(patient_card, patient_state, chatbot, messages)``: the selected
        patient record twice, an empty chat and a history holding only the new
        patient's system prompt.
    """
    index = patients_names.index(selected_item)
    patient = patiens[index]
    # use the module's configured logger rather than bare print()
    logging.info(f"You selected: {selected_item} {index}")
    return patient, patient, [], [system_message_prompt.format(patient=patient)]
def on_demo_load(url_params, request: gr.Request):
    """Resolve the current user when the page loads and build a greeting.

    The username is ``request.username`` when set (the app is launched with
    ``auth``). Otherwise it is the ``username`` URL query parameter, falling
    back to "test". The user is created (via ``create_user``) and then
    re-read from storage.

    Args:
        url_params: query parameters collected in the browser by the
            ``get_window_url_params`` JS hook; may be empty.
        request: the incoming Gradio request.

    Returns:
        ``(user, greeting_markdown)``.
    """
    # tolerate the JS hook delivering nothing instead of a dict
    params = url_params or {}
    username = request.username or params.get("username", "test")
    logging.info(f"Getting user for username={username}")
    create_user(client, User(username=username, uid=None))
    user = get_user_by_username(client, username)
    logging.info(f"User {user}")
    logging.info(f"got url_params: {url_params}")
    # NOTE(review): the trailing "π" looks like a mis-encoded emoji in the
    # copied source; confirm the intended character.
    return user, f"Nice to see you {user['username']} π"
def _on_apply_settings_keep_patient(model_name, temperature, patient):
    """Apply new model settings without dropping the patient's system prompt.

    ``on_apply_settings_click`` resets the LLM message history to ``[]``; the
    next turn would then reach the model without the prompt that makes it play
    ``patient``. Re-seed the history from the selected patient instead.
    """
    new_chat, cleared_message, cleared_chatbot, _ = on_apply_settings_click(
        model_name, temperature
    )
    return (
        new_chat,
        cleared_message,
        cleared_chatbot,
        [system_message_prompt.format(patient=patient)],
    )


url_params = gr.JSON({}, visible=False, label="URL Params")
# some css why not, "borrowed" from https://huggingface.co/spaces/ysharma/Gradio-demo-streaming/blob/main/app.py
with gr.Blocks(
    css="""#col_container {width: 700px; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}"""
) as demo:
    # here we keep our state so multiple user can use the app at the same time!
    messages = gr.State([system_message_prompt.format(patient=patiens[0])])
    # same thing for the chat, we want one chat per use so callbacks are unique I guess
    chat = gr.State(None)
    user = gr.State(None)
    patient = gr.State(patiens[0])
    # see here https://github.com/gradio-app/gradio/discussions/2949#discussioncomment-5278991
    url_params.render()
    with gr.Column(elem_id="col_container"):
        # Escapes for U+1F468 ZWJ U+2695 VS16 (man health worker) and
        # U+1F9D1 ZWJ U+2695 VS16 (health worker), reconstructed from
        # mis-encoded text -- NOTE(review): confirm against the original.
        gr.Markdown(
            "# Welcome to OscePal! "
            "\U0001F468\u200D\u2695\uFE0F\U0001F9D1\u200D\u2695\uFE0F"
        )
        welcome_markdown = gr.Markdown("")
        demo.load(
            fn=on_demo_load,
            inputs=[url_params],
            outputs=[user, welcome_markdown],
            _js=get_window_url_params,
        )
        # elem_id matches the "#chatbot" rule in the css above; without it
        # the height/overflow styling targets no element.
        chatbot = gr.Chatbot(elem_id="chatbot")
        with gr.Column():
            message = gr.Textbox(label="chat input")
            message.submit(
                message_handler,
                [chat, message, chatbot, messages],
                [chat, message, chatbot, messages],
                queue=True,
            )
            submit = gr.Button("Send Message", variant="primary")
            submit.click(
                message_handler,
                [chat, message, chatbot, messages],
                [chat, message, chatbot, messages],
            )
        with gr.Row():
            with gr.Column():
                done = gr.Button("Done", variant="stop")
                done.click(
                    on_done_click,
                    [chatbot, patient, user],
                    [message, chatbot, messages],
                )
                with gr.Accordion("Settings", open=False):
                    model_name = gr.Dropdown(
                        choices=MODELS_NAMES, value=MODELS_NAMES[0], label="model"
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_TEMPERATURE,
                        step=0.1,
                        label="temperature",
                        interactive=True,
                    )
                    apply_settings = gr.Button("Apply")
                    # pass the patient too so the system prompt survives the reset
                    apply_settings.click(
                        _on_apply_settings_keep_patient,
                        [model_name, temperature, patient],
                        [chat, message, chatbot, messages],
                    )
            with gr.Column():
                # patients_names is already built at module level from patiens
                dropdown = gr.Dropdown(
                    choices=patients_names,
                    value=patients_names[0],
                    interactive=True,
                    label="Patient",
                )
                patient_card = gr.JSON(patient.value, visible=True, label="Patient card")
                dropdown.change(
                    fn=on_drop_down_change,
                    inputs=[dropdown, messages],
                    outputs=[patient_card, patient, chatbot, messages],
                )
# app = FastAPI() | |
# os.makedirs("static", exist_ok=True) | |
# app.mount("/static", StaticFiles(directory="static"), name="static") | |
# templates = Jinja2Templates(directory="templates") | |
# @app.get("/", response_class=HTMLResponse) | |
# async def home(request: Request): | |
# return templates.TemplateResponse( | |
# "home.html", {"request": request, "videos": []}) | |
def auth_handler(username: str, password: str) -> bool:
    """Gradio login check: accept any username together with the shared password.

    The expected password is read from the ``GRADIO_PASSWORD`` environment
    variable on every call. If it is unset, every login is rejected and an
    error is logged.

    Args:
        username: ignored; all users share the same password.
        password: the password typed on the login page.

    Returns:
        True when ``password`` equals ``GRADIO_PASSWORD``, otherwise False.
    """
    import hmac  # constant-time comparison, avoids timing side channels

    expected = os.environ.get("GRADIO_PASSWORD")
    if expected is None:
        logging.error("GRADIO_PASSWORD is not set; rejecting login")
        return False
    if not isinstance(password, str):
        return False
    # compare bytes: compare_digest rejects non-ASCII str arguments
    return hmac.compare_digest(password.encode("utf-8"), expected.encode("utf-8"))
# Turn on Gradio's request queue (message.submit also opts in with queue=True)
# and serve the app with auth_handler as the login check.
demo.queue().launch(auth=auth_handler)
# gradio_app = gr.routes.App.create_app(demo) | |
# app.mount("/gradio", gradio_app) | |