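"""OscePal: a Gradio chat app that role-plays patients from data/patients.json,
streams ChatOpenAI replies through a queue-backed callback, and saves finished
chats to Redis.
"""
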
import logging
from pathlib import Path
from typing import Generator, List, Optional, Tuple
import json
from dotenv import load_dotenv

load_dotenv()

from queue import Empty, Queue
from threading import Thread
import os
import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import AIMessage, BaseMessage, HumanMessage
from js import get_window_url_params
from callback import QueueCallback
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from db import (
    User,
    Chat,
    create_user,
    get_client,
    get_user_by_username,
    add_chat_by_uid,
)

MODELS_NAMES = ["gpt-3.5-turbo", "gpt-4"]
DEFAULT_TEMPERATURE = 0.7

ChatHistory = List[Tuple[str, str]]  # (user message, assistant reply) pairs, as gr.Chatbot expects

logging.basicConfig(
    format="[%(asctime)s %(levelname)s]: %(message)s", level=logging.INFO
)
# load redis client
client = get_client()
# load up our system prompt
system_message_prompt = SystemMessagePromptTemplate.from_template(
    Path("prompts/system.prompt").read_text()
)
# for the human, we will just inject the text
human_message_prompt_template = HumanMessagePromptTemplate.from_template("{text}")

with open("data/patients.json") as f:
    patients = json.load(f)

patients_names = [el["name"] for el in patients]


def message_handler(
    chat: Optional[ChatOpenAI],
    message: str,
    chatbot_messages: ChatHistory,
    messages: List[BaseMessage],
) -> Generator[Tuple[ChatOpenAI, str, ChatHistory, List[BaseMessage]], None, None]:
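    """Stream the model's reply token by token, yielding partial history updates."""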
    if chat is None:
        # in the queue we will store our streamed tokens
        queue = Queue()
        # let's create our default chat
        chat = ChatOpenAI(
            model_name=MODELS_NAMES[0],
            temperature=DEFAULT_TEMPERATURE,
            streaming=True,
            callbacks=[QueueCallback(queue)],
        )
    else:
        # reuse the queue attached to the existing chat's callback
        queue = chat.callbacks[0].queue
    # sentinel pushed onto the queue once generation has finished
    job_done = object()

    logging.info("asking question to GPT")
    # let's add the messages to our stuff
    messages.append(HumanMessage(content=message))
    chatbot_messages.append((message, ""))
    # this is a little wrapper we need cuz we have to add the job_done
    def task():
        chat(messages)
        queue.put(job_done)

    # now let's start a thread and run the generation inside it
    t = Thread(target=task)
    t.start()
    # accumulates the reply as tokens arrive
    content = ""
    # consume tokens from the queue until the job_done sentinel arrives
    while True:
        try:
            next_token = queue.get(True, timeout=1)
            if next_token is job_done:
                break
            content += next_token
            chatbot_messages[-1] = (message, content)
            yield chat, "", chatbot_messages, messages
        except Empty:
            continue
    # finally, record the full reply in the message history
    messages.append(AIMessage(content=content))
    logging.debug(f"reply = {content}")
    logging.info("Done!")
    # this handler is a generator, so yield (not return) the final state
    yield chat, "", chatbot_messages, messages


def on_clear_click() -> Tuple[str, List, List]:
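    """Clear the chat input, chatbot display, and message history."""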
    return "", [], []


def on_done_click(
    chatbot_messages: ChatHistory, patient: str, user: User
) -> Tuple[str, List, List]:
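    """Persist the finished chat for the current user, then clear the UI."""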
    logging.info(f"Saving chat for user={user}")
    add_chat_by_uid(
        client, Chat(patient=patient, messages=chatbot_messages), user["uid"]
    )
    return on_clear_click()


def on_apply_settings_click(model_name: str, temperature: float):
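    """Rebuild the chat model with the chosen settings and reset the conversation."""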
    logging.info(
        f"Applying settings: model_name={model_name}, temperature={temperature}"
    )
    chat = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        streaming=True,
        callbacks=[QueueCallback(Queue())],
    )
    # note: Queue.empty() only reports emptiness and never drains anything;
    # the queue above is freshly created, so there is nothing stale to clear
    return chat, *on_clear_click()


def on_drop_down_change(selected_item, messages):
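    """Switch the active patient and restart the conversation with its system prompt."""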
    index = patients_names.index(selected_item)
    patient = patients[index]
    messages = [system_message_prompt.format(patient=patient)]
    logging.info(f"Selected patient: {selected_item} (index={index})")
    return patient, patient, [], messages


def on_demo_load(url_params):
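    """Create (or fetch) the user named in the URL params and return a greeting."""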
    username = url_params.get("username", "test")
    logging.info(f"Getting user for username={username}")
    create_user(client, User(username=username, uid=None))
    user = get_user_by_username(client, username)
    logging.info(f"User {user}")
    print(f"got url_params: {url_params}")
    return user, f"Nice to see you {user['username']} 👋"


url_params = gr.JSON({}, visible=False, label="URL Params")
# some css why not, "borrowed" from https://huggingface.co/spaces/ysharma/Gradio-demo-streaming/blob/main/app.py
with gr.Blocks(
    css="""#col_container {width: 700px; margin-left: auto; margin-right: auto;}
                #chatbot {height: 400px; overflow: auto;}"""
) as demo:
    # per-session state so multiple users can use the app at the same time
    messages = gr.State([system_message_prompt.format(patient=patients[0])])
    # one chat object per session so each user has their own callbacks and queue
    chat = gr.State(None)
    user = gr.State(None)
    patient = gr.State(patients[0])
    # see here https://github.com/gradio-app/gradio/discussions/2949#discussioncomment-5278991
    url_params.render()

    with gr.Column(elem_id="col_container"):
        gr.Markdown("# Welcome to OscePal! 👨‍⚕️🧑‍⚕️")
        welcome_markdown = gr.Markdown("")

        demo.load(
            fn=on_demo_load,
            inputs=[url_params],
            outputs=[user, welcome_markdown],
            _js=get_window_url_params,
        )

        chatbot = gr.Chatbot()
        with gr.Column():
            message = gr.Textbox(label="chat input")
            message.submit(
                message_handler,
                [chat, message, chatbot, messages],
                [chat, message, chatbot, messages],
                queue=True,
            )
            submit = gr.Button("Send Message", variant="primary")
            submit.click(
                message_handler,
                [chat, message, chatbot, messages],
                [chat, message, chatbot, messages],
            )
            
        with gr.Row():
            with gr.Column():
                js = "(x) => confirm('Press a button!')"

                done = gr.Button("Done",  variant="stop")
                done.click(
                    on_done_click,
                    [chatbot, patient, user],
                    [message, chatbot, messages],
                )
            with gr.Accordion("Settings", open=False):
                model_name = gr.Dropdown(
                    choices=MODELS_NAMES, value=MODELS_NAMES[0], label="model"
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_TEMPERATURE,
                    step=0.1,
                    label="temperature",
                    interactive=True,
                )
                apply_settings = gr.Button("Apply")
                apply_settings.click(
                    on_apply_settings_click,
                    [model_name, temperature],
                    [chat, message, chatbot, messages],
                )
        with gr.Column():
            dropdown = gr.Dropdown(
                choices=patients_names,
                value=patients_names[0],
                interactive=True,
                label="Patient",
            )

            patient_card = gr.JSON(patient.value, visible=True, label="Patient card")
            dropdown.change(
                fn=on_drop_down_change,
                inputs=[dropdown, messages],
                outputs=[patient_card, patient, chatbot, messages],
            )


    # app = FastAPI()
    # os.makedirs("static", exist_ok=True)
    # app.mount("/static", StaticFiles(directory="static"), name="static")
    # templates = Jinja2Templates(directory="templates")
    # @app.get("/", response_class=HTMLResponse)
    # async def home(request: Request):
    #     return templates.TemplateResponse(
    #         "home.html", {"request": request, "videos": []})

    # enable queuing so the streaming generator handlers can emit incremental updates
    demo.queue()

    # gradio_app = gr.routes.App.create_app(demo)
    # app.mount("/gradio", gradio_app)