codys12 commited on
Commit
e24c49a
·
1 Parent(s): 8c7e775

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +136 -0
  2. topics.json +62 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import json
4
+ import uuid
5
+
6
+ import gradio as gr
7
+ import spaces
8
+
9
+ API_URL = os.environ.get("API_URL", "default_api_url_if_not_set")
10
+ BEARER_TOKEN = os.environ.get("BEARER_TOKEN", "default_token_if_not_set")
11
+ headers = {
12
+ "Authorization": f"Bearer {BEARER_TOKEN}",
13
+ "Content-Type": "application/json"
14
+ }
15
+
16
+ # Define a function to load topics from a JSON file
17
+ def load_topics(filename):
18
+ try:
19
+ with open(filename, 'r') as file:
20
+ data = json.load(file)
21
+ return data
22
+ except FileNotFoundError:
23
+ print(f"Error: The file {filename} was not found.")
24
+ return {}
25
+ except json.JSONDecodeError:
26
+ print("Error: Failed to decode JSON.")
27
+ return {}
28
+
29
+ # Path to your JSON file
30
+ topics_json_path = 'topics.json'
31
+
32
+ # Call the function and store the topics
33
+ topics = load_topics(topics_json_path)
34
+
35
+ userdata = dict()
36
+
37
+ def query(payload):
38
+ response = requests.post(API_URL, headers=headers, json=payload)
39
+ json = response.json()
40
+ return json
41
+
42
+
43
+ @spaces.GPU
44
+ def generate(
45
+ message: str,
46
+ chat_history: list[tuple[str, str]],
47
+ system_prompt: str,
48
+ max_new_tokens: int = 1024,
49
+ temperature: float = 0.6,
50
+ top_p: float = 0.9,
51
+ top_k: int = 50,
52
+ repetition_penalty: float = 1.2,
53
+ user_id: uuid.UUID = uuid.uuid4()
54
+ ) -> str:
55
+ user_id_str = user_id.hex
56
+ if user_id_str not in userdata or "topicId" not in userdata[user_id_str]:
57
+ userdata[user_id_str] = {"topicId": "0", "topic_flag": False}
58
+ topic = topics[userdata[user_id_str]["topicId"]]
59
+ result = query({
60
+ "inputs":"" ,
61
+ "message":message,
62
+ "chat_history":chat_history,
63
+ "system_prompt":system_prompt,
64
+ "instruction": topic["instruction"],
65
+ "conclusions": topic["conclusions"],
66
+ "context": topic["context"],
67
+ "max_new_tokens":max_new_tokens,
68
+ "temperature":temperature,
69
+ "top_p":top_p,
70
+ "top_k":top_k,
71
+ "repetition_penalty":repetition_penalty,
72
+ })
73
+
74
+ conclusion = result.get("conclusion")
75
+ if conclusion is not None:
76
+ next_topic_id = topic["conclusionAction"][conclusion]["next"]
77
+ extra = topic["conclusionAction"][conclusion]["extra"]
78
+ userdata[user_id_str]["topicId"] = next_topic_id
79
+ userdata[user_id_str]["topic_flag"] = True
80
+ return result.get("generated_text") + "\n" + extra + "\n" + topics[next_topic_id]["primer"]
81
+
82
+ return result.get("generated_text")
83
+
84
+ def update(chatbot_state):
85
+ # Check if the user_id exists in userdata, if not, create a default entry
86
+ if user_id.value.hex not in userdata:
87
+ userdata[user_id.value.hex] = {"topic_flag": False}
88
+
89
+ # Now you can safely get the topic_flag value
90
+ user_topic_flag = userdata[user_id.value.hex].get("topic_flag", False)
91
+
92
+ # If topic_flag is True, reset it to False and return the primer
93
+ if user_topic_flag:
94
+ userdata[user_id.value.hex]["topic_flag"] = False
95
+ return [[None, topics[userdata[user_id.value.hex]["topicId"]]["primer"]]]
96
+
97
+ # Return the original chatbot_state if topic_flag is not True
98
+ return chatbot_state
99
+
100
+
101
+ # Create Gradio interface components (inputs)
102
+ system_prompt_input = gr.Textbox(label="System prompt")
103
+ max_new_tokens_input = gr.Slider(minimum=1, maximum=2048, value=50, step=1, label="Max New Tokens")
104
+ temperature_input = gr.Slider(minimum=0.1, maximum=4.0, step=0.1, value=0.6, label="Temperature")
105
+ top_p_input = gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.9, label="Top-p")
106
+ top_k_input = gr.Slider(minimum=1, maximum=1000, step=1, value=50, label="Top-k")
107
+ repetition_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition Penalty")
108
+ user_id = gr.State(uuid.uuid4())
109
+
110
+ chat_interface = gr.ChatInterface(
111
+ fn=generate,
112
+ chatbot=gr.Chatbot([[None, topics["0"]["primer"]]]),
113
+ additional_inputs=[
114
+ system_prompt_input,
115
+ max_new_tokens_input,
116
+ temperature_input,
117
+ top_p_input,
118
+ top_k_input,
119
+ repetition_penalty_input,
120
+ user_id
121
+ ],
122
+ stop_btn=gr.Button("Stop"),
123
+ examples=[
124
+ #
125
+ ],
126
+ )
127
+
128
+
129
+ with gr.Blocks(css="style.css") as demo:
130
+
131
+ chat_interface.render()
132
+ chat_interface.submit_btn.click(update, inputs=chat_interface.chatbot_state, outputs=chat_interface.chatbot_state)
133
+ chat_interface.textbox.input(update, inputs=chat_interface.chatbot_state, outputs=chat_interface.chatbot_state)
134
+
135
+ if __name__ == "__main__":
136
+ demo.queue(max_size=20).launch(debug=True)
topics.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": {
3
+ "instruction": "Determine if the user is a parent or caregiver who has Autism concerns about their child. Do not ask follow up quesions, just clasify what they answer.",
4
+ "primer": "Are you a parent or gaurdian with Autism concerns about your child?",
5
+ "conclusions": [
6
+ [
7
+ "[YES]",
8
+ "The user is a caregiver or gaurdian with concerns"
9
+ ],
10
+ [
11
+ "[NO]",
12
+ "The user is not a caregive or guardian with concerns"
13
+ ]
14
+ ],
15
+ "conclusionAction": {
16
+ "[YES]": {
17
+ "next":"1",
18
+ "extra":"Autism Spectrum Disorder (ASD) is a developmental disorder characterized by a range of signs and symptoms that can vary widely among individuals. Common indicators include challenges with social interaction, repetitive behaviors, and difficulty in communication. Some individuals may show signs early in infancy, while others may develop more typically and begin showing signs later. The pathway to diagnosis often starts with parental or caregiver concerns, followed by professional evaluations which may include developmental screenings and comprehensive diagnostic assessments by specialists. Once diagnosed, treatment plans are tailored to the individual's needs, often involving behavioral therapies, educational interventions, and sometimes medication. Following diagnosis, ongoing support and therapy are essential to help individuals with ASD maximize their potential in social, educational, and occupational settings."
19
+ },
20
+ "[NO]": {
21
+ "next":"2",
22
+ "extra":"This resource is intended to help with ASD screenings."
23
+ }
24
+ },
25
+ "context": [],
26
+ "contextSource": []
27
+ },
28
+ "1": {
29
+ "instruction": "Determine if the user has had their child screened for ASD",
30
+ "primer": "Has your child been screened for ASD?",
31
+ "conclusions": [
32
+ [
33
+ "[YES]",
34
+ "The user's child has been screened for ASD"
35
+ ],
36
+ [
37
+ "[NO]",
38
+ "The user's child has not been screened for ASD"
39
+ ]
40
+ ],
41
+ "conclusionAction": {
42
+ "[YES]": {
43
+ "next":"1",
44
+ "extra":""
45
+ },
46
+ "[NO]": {
47
+ "next":"2",
48
+ "extra":""
49
+ }
50
+ },
51
+ "context": [],
52
+ "contextSource": []
53
+ },
54
+ "2": {
55
+ "instruction": "End the conversation",
56
+ "primer": "Have a good day!",
57
+ "conclusions": [],
58
+ "conclusionAction": {},
59
+ "context": [],
60
+ "contextSource": []
61
+ }
62
+ }