Upload 2 files
Browse files- app.py +136 -0
- topics.json +62 -0
app.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
import json
|
4 |
+
import uuid
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import spaces
|
8 |
+
|
9 |
+
API_URL = os.environ.get("API_URL", "default_api_url_if_not_set")
|
10 |
+
BEARER_TOKEN = os.environ.get("BEARER_TOKEN", "default_token_if_not_set")
|
11 |
+
headers = {
|
12 |
+
"Authorization": f"Bearer {BEARER_TOKEN}",
|
13 |
+
"Content-Type": "application/json"
|
14 |
+
}
|
15 |
+
|
16 |
+
# Define a function to load topics from a JSON file
|
17 |
+
def load_topics(filename):
|
18 |
+
try:
|
19 |
+
with open(filename, 'r') as file:
|
20 |
+
data = json.load(file)
|
21 |
+
return data
|
22 |
+
except FileNotFoundError:
|
23 |
+
print(f"Error: The file {filename} was not found.")
|
24 |
+
return {}
|
25 |
+
except json.JSONDecodeError:
|
26 |
+
print("Error: Failed to decode JSON.")
|
27 |
+
return {}
|
28 |
+
|
29 |
+
# Path to your JSON file
|
30 |
+
topics_json_path = 'topics.json'
|
31 |
+
|
32 |
+
# Call the function and store the topics
|
33 |
+
topics = load_topics(topics_json_path)
|
34 |
+
|
35 |
+
userdata = dict()
|
36 |
+
|
37 |
+
def query(payload):
|
38 |
+
response = requests.post(API_URL, headers=headers, json=payload)
|
39 |
+
json = response.json()
|
40 |
+
return json
|
41 |
+
|
42 |
+
|
43 |
+
@spaces.GPU
|
44 |
+
def generate(
|
45 |
+
message: str,
|
46 |
+
chat_history: list[tuple[str, str]],
|
47 |
+
system_prompt: str,
|
48 |
+
max_new_tokens: int = 1024,
|
49 |
+
temperature: float = 0.6,
|
50 |
+
top_p: float = 0.9,
|
51 |
+
top_k: int = 50,
|
52 |
+
repetition_penalty: float = 1.2,
|
53 |
+
user_id: uuid.UUID = uuid.uuid4()
|
54 |
+
) -> str:
|
55 |
+
user_id_str = user_id.hex
|
56 |
+
if user_id_str not in userdata or "topicId" not in userdata[user_id_str]:
|
57 |
+
userdata[user_id_str] = {"topicId": "0", "topic_flag": False}
|
58 |
+
topic = topics[userdata[user_id_str]["topicId"]]
|
59 |
+
result = query({
|
60 |
+
"inputs":"" ,
|
61 |
+
"message":message,
|
62 |
+
"chat_history":chat_history,
|
63 |
+
"system_prompt":system_prompt,
|
64 |
+
"instruction": topic["instruction"],
|
65 |
+
"conclusions": topic["conclusions"],
|
66 |
+
"context": topic["context"],
|
67 |
+
"max_new_tokens":max_new_tokens,
|
68 |
+
"temperature":temperature,
|
69 |
+
"top_p":top_p,
|
70 |
+
"top_k":top_k,
|
71 |
+
"repetition_penalty":repetition_penalty,
|
72 |
+
})
|
73 |
+
|
74 |
+
conclusion = result.get("conclusion")
|
75 |
+
if conclusion is not None:
|
76 |
+
next_topic_id = topic["conclusionAction"][conclusion]["next"]
|
77 |
+
extra = topic["conclusionAction"][conclusion]["extra"]
|
78 |
+
userdata[user_id_str]["topicId"] = next_topic_id
|
79 |
+
userdata[user_id_str]["topic_flag"] = True
|
80 |
+
return result.get("generated_text") + "\n" + extra + "\n" + topics[next_topic_id]["primer"]
|
81 |
+
|
82 |
+
return result.get("generated_text")
|
83 |
+
|
84 |
+
def update(chatbot_state):
|
85 |
+
# Check if the user_id exists in userdata, if not, create a default entry
|
86 |
+
if user_id.value.hex not in userdata:
|
87 |
+
userdata[user_id.value.hex] = {"topic_flag": False}
|
88 |
+
|
89 |
+
# Now you can safely get the topic_flag value
|
90 |
+
user_topic_flag = userdata[user_id.value.hex].get("topic_flag", False)
|
91 |
+
|
92 |
+
# If topic_flag is True, reset it to False and return the primer
|
93 |
+
if user_topic_flag:
|
94 |
+
userdata[user_id.value.hex]["topic_flag"] = False
|
95 |
+
return [[None, topics[userdata[user_id.value.hex]["topicId"]]["primer"]]]
|
96 |
+
|
97 |
+
# Return the original chatbot_state if topic_flag is not True
|
98 |
+
return chatbot_state
|
99 |
+
|
100 |
+
|
101 |
+
# Create Gradio interface components (inputs)
|
102 |
+
system_prompt_input = gr.Textbox(label="System prompt")
|
103 |
+
max_new_tokens_input = gr.Slider(minimum=1, maximum=2048, value=50, step=1, label="Max New Tokens")
|
104 |
+
temperature_input = gr.Slider(minimum=0.1, maximum=4.0, step=0.1, value=0.6, label="Temperature")
|
105 |
+
top_p_input = gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.9, label="Top-p")
|
106 |
+
top_k_input = gr.Slider(minimum=1, maximum=1000, step=1, value=50, label="Top-k")
|
107 |
+
repetition_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition Penalty")
|
108 |
+
user_id = gr.State(uuid.uuid4())
|
109 |
+
|
110 |
+
chat_interface = gr.ChatInterface(
|
111 |
+
fn=generate,
|
112 |
+
chatbot=gr.Chatbot([[None, topics["0"]["primer"]]]),
|
113 |
+
additional_inputs=[
|
114 |
+
system_prompt_input,
|
115 |
+
max_new_tokens_input,
|
116 |
+
temperature_input,
|
117 |
+
top_p_input,
|
118 |
+
top_k_input,
|
119 |
+
repetition_penalty_input,
|
120 |
+
user_id
|
121 |
+
],
|
122 |
+
stop_btn=gr.Button("Stop"),
|
123 |
+
examples=[
|
124 |
+
#
|
125 |
+
],
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
with gr.Blocks(css="style.css") as demo:
|
130 |
+
|
131 |
+
chat_interface.render()
|
132 |
+
chat_interface.submit_btn.click(update, inputs=chat_interface.chatbot_state, outputs=chat_interface.chatbot_state)
|
133 |
+
chat_interface.textbox.input(update, inputs=chat_interface.chatbot_state, outputs=chat_interface.chatbot_state)
|
134 |
+
|
135 |
+
if __name__ == "__main__":
|
136 |
+
demo.queue(max_size=20).launch(debug=True)
|
topics.json
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"0": {
|
3 |
+
"instruction": "Determine if the user is a parent or caregiver who has Autism concerns about their child. Do not ask follow up quesions, just clasify what they answer.",
|
4 |
+
"primer": "Are you a parent or gaurdian with Autism concerns about your child?",
|
5 |
+
"conclusions": [
|
6 |
+
[
|
7 |
+
"[YES]",
|
8 |
+
"The user is a caregiver or gaurdian with concerns"
|
9 |
+
],
|
10 |
+
[
|
11 |
+
"[NO]",
|
12 |
+
"The user is not a caregive or guardian with concerns"
|
13 |
+
]
|
14 |
+
],
|
15 |
+
"conclusionAction": {
|
16 |
+
"[YES]": {
|
17 |
+
"next":"1",
|
18 |
+
"extra":"Autism Spectrum Disorder (ASD) is a developmental disorder characterized by a range of signs and symptoms that can vary widely among individuals. Common indicators include challenges with social interaction, repetitive behaviors, and difficulty in communication. Some individuals may show signs early in infancy, while others may develop more typically and begin showing signs later. The pathway to diagnosis often starts with parental or caregiver concerns, followed by professional evaluations which may include developmental screenings and comprehensive diagnostic assessments by specialists. Once diagnosed, treatment plans are tailored to the individual's needs, often involving behavioral therapies, educational interventions, and sometimes medication. Following diagnosis, ongoing support and therapy are essential to help individuals with ASD maximize their potential in social, educational, and occupational settings."
|
19 |
+
},
|
20 |
+
"[NO]": {
|
21 |
+
"next":"2",
|
22 |
+
"extra":"This resource is intended to help with ASD screenings."
|
23 |
+
}
|
24 |
+
},
|
25 |
+
"context": [],
|
26 |
+
"contextSource": []
|
27 |
+
},
|
28 |
+
"1": {
|
29 |
+
"instruction": "Determine if the user has had their child screened for ASD",
|
30 |
+
"primer": "Has your child been screened for ASD?",
|
31 |
+
"conclusions": [
|
32 |
+
[
|
33 |
+
"[YES]",
|
34 |
+
"The user's child has been screened for ASD"
|
35 |
+
],
|
36 |
+
[
|
37 |
+
"[NO]",
|
38 |
+
"The user's child has not been screened for ASD"
|
39 |
+
]
|
40 |
+
],
|
41 |
+
"conclusionAction": {
|
42 |
+
"[YES]": {
|
43 |
+
"next":"1",
|
44 |
+
"extra":""
|
45 |
+
},
|
46 |
+
"[NO]": {
|
47 |
+
"next":"2",
|
48 |
+
"extra":""
|
49 |
+
}
|
50 |
+
},
|
51 |
+
"context": [],
|
52 |
+
"contextSource": []
|
53 |
+
},
|
54 |
+
"2": {
|
55 |
+
"instruction": "End the conversation",
|
56 |
+
"primer": "Have a good day!",
|
57 |
+
"conclusions": [],
|
58 |
+
"conclusionAction": {},
|
59 |
+
"context": [],
|
60 |
+
"contextSource": []
|
61 |
+
}
|
62 |
+
}
|