Iisakki Rotko
commited on
Commit
•
f228ce0
1
Parent(s):
1c7799f
feat: working with assistant, UI, and others
Browse files- icons/explore.svg +1 -0
- wanderlust.py +145 -62
icons/explore.svg
ADDED
wanderlust.py
CHANGED
@@ -2,7 +2,11 @@ import json
|
|
2 |
import os
|
3 |
|
4 |
import ipyleaflet
|
5 |
-
import
|
|
|
|
|
|
|
|
|
6 |
|
7 |
import solara
|
8 |
|
@@ -17,11 +21,11 @@ center = solara.reactive(center_default)
|
|
17 |
markers = solara.reactive([])
|
18 |
|
19 |
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
|
20 |
-
openai
|
21 |
model = "gpt-4-1106-preview"
|
22 |
|
23 |
|
24 |
-
|
25 |
{
|
26 |
"type": "function",
|
27 |
"function": {
|
@@ -94,17 +98,15 @@ functions = {
|
|
94 |
|
95 |
|
96 |
def ai_call(tool_call):
|
97 |
-
function = tool_call
|
98 |
-
name = function
|
99 |
-
arguments = json.loads(function
|
100 |
return_value = functions[name](**arguments)
|
101 |
-
|
102 |
-
"
|
103 |
-
"
|
104 |
-
"name": tool_call["function"]["name"],
|
105 |
-
"content": return_value,
|
106 |
}
|
107 |
-
return
|
108 |
|
109 |
|
110 |
@solara.component
|
@@ -129,39 +131,66 @@ def Map():
|
|
129 |
@solara.component
|
130 |
def ChatInterface():
|
131 |
prompt = solara.use_reactive("")
|
|
|
|
|
|
|
|
|
132 |
|
133 |
def add_message(value: str):
|
134 |
if value == "":
|
135 |
return
|
136 |
-
messages.set(messages.value + [{"role": "user", "content": value}])
|
137 |
prompt.set("")
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
return
|
142 |
-
|
143 |
-
|
144 |
-
completion = openai.ChatCompletion.create(
|
145 |
-
model=model,
|
146 |
-
messages=messages.value,
|
147 |
-
# Add function calling
|
148 |
-
tools=function_descriptions,
|
149 |
-
tool_choice="auto",
|
150 |
-
)
|
151 |
-
|
152 |
-
output = completion.choices[0].message
|
153 |
-
print("received", output)
|
154 |
try:
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
def handle_message(message):
|
162 |
print("handle", message)
|
163 |
messages = []
|
164 |
-
if message
|
165 |
tools_calls = message.get("tool_calls", [])
|
166 |
for tool_call in tools_calls:
|
167 |
messages.append(ai_call(tool_call))
|
@@ -173,38 +202,71 @@ def ChatInterface():
|
|
173 |
handle_message(message)
|
174 |
|
175 |
solara.use_effect(handle_initial, [])
|
176 |
-
result = solara.use_thread(ask, dependencies=[messages.value])
|
177 |
with solara.Column(
|
178 |
-
style={
|
|
|
|
|
|
|
|
|
|
|
179 |
classes=["chat-interface"],
|
180 |
):
|
181 |
if len(messages.value) > 0:
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
else:
|
|
|
|
|
|
|
|
|
194 |
solara.Preformatted(
|
195 |
repr(message),
|
196 |
classes=["chat-message", "assistant-message"],
|
197 |
)
|
198 |
-
|
199 |
-
pass # no need to display
|
200 |
-
else:
|
201 |
-
solara.Preformatted(
|
202 |
-
repr(message), classes=["chat-message", "assistant-message"]
|
203 |
-
)
|
204 |
-
# solara.Text(message, classes=["chat-message"])
|
205 |
with solara.Column():
|
206 |
solara.InputText(
|
207 |
-
label="Ask your ",
|
208 |
value=prompt,
|
209 |
style={"flex-grow": "1"},
|
210 |
on_value=add_message,
|
@@ -234,26 +296,47 @@ def Page():
|
|
234 |
messages.set(json.load(f))
|
235 |
reset_ui()
|
236 |
|
237 |
-
with solara.Column(
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
ChatInterface().key(f"chat-{reset_counter}")
|
244 |
-
with solara.Column(style={"width": "
|
245 |
Map() # .key(f"map-{reset_counter}")
|
246 |
|
247 |
solara.Style(
|
248 |
"""
|
249 |
.jupyter-widgets.leaflet-widgets{
|
250 |
height: 100%;
|
|
|
251 |
}
|
252 |
.solara-autorouter-content{
|
253 |
display: flex;
|
254 |
flex-direction: column;
|
255 |
justify-content: stretch;
|
256 |
}
|
|
|
|
|
|
|
|
|
|
|
257 |
"""
|
258 |
)
|
259 |
|
|
|
2 |
import os
|
3 |
|
4 |
import ipyleaflet
|
5 |
+
from openai import OpenAI, NotFoundError
|
6 |
+
from openai.types.beta import Thread
|
7 |
+
from openai.types.beta.threads import Run
|
8 |
+
|
9 |
+
import time
|
10 |
|
11 |
import solara
|
12 |
|
|
|
21 |
markers = solara.reactive([])
|
22 |
|
23 |
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
|
24 |
+
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
25 |
model = "gpt-4-1106-preview"
|
26 |
|
27 |
|
28 |
+
tools = [
|
29 |
{
|
30 |
"type": "function",
|
31 |
"function": {
|
|
|
98 |
|
99 |
|
100 |
def ai_call(tool_call):
|
101 |
+
function = tool_call.function
|
102 |
+
name = function.name
|
103 |
+
arguments = json.loads(function.arguments)
|
104 |
return_value = functions[name](**arguments)
|
105 |
+
tool_outputs = {
|
106 |
+
"tool_call_id": tool_call.id,
|
107 |
+
"output": return_value,
|
|
|
|
|
108 |
}
|
109 |
+
return tool_outputs
|
110 |
|
111 |
|
112 |
@solara.component
|
|
|
131 |
@solara.component
|
132 |
def ChatInterface():
|
133 |
prompt = solara.use_reactive("")
|
134 |
+
run_id: solara.Reactive[str] = solara.use_reactive(None)
|
135 |
+
|
136 |
+
thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
|
137 |
+
print("thread id:", thread.id)
|
138 |
|
139 |
def add_message(value: str):
|
140 |
if value == "":
|
141 |
return
|
|
|
142 |
prompt.set("")
|
143 |
+
new_message = openai.beta.threads.messages.create(
|
144 |
+
thread_id=thread.id, content=value, role="user"
|
145 |
+
)
|
146 |
+
messages.set([*messages.value, new_message])
|
147 |
+
run_id.value = openai.beta.threads.runs.create(
|
148 |
+
thread_id=thread.id,
|
149 |
+
assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
|
150 |
+
tools=tools,
|
151 |
+
).id
|
152 |
+
print("Run id:", run_id.value)
|
153 |
+
|
154 |
+
def poll():
|
155 |
+
if not run_id.value:
|
156 |
return
|
157 |
+
completed = False
|
158 |
+
while not completed:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
try:
|
160 |
+
run = openai.beta.threads.runs.retrieve(
|
161 |
+
run_id.value, thread_id=thread.id
|
162 |
+
) # When run is complete
|
163 |
+
print("run", run.status)
|
164 |
+
except NotFoundError:
|
165 |
+
print("run not found (Yet)")
|
166 |
+
continue
|
167 |
+
if run.status == "requires_action":
|
168 |
+
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
169 |
+
tool_output = ai_call(tool_call)
|
170 |
+
openai.beta.threads.runs.submit_tool_outputs(
|
171 |
+
thread_id=thread.id,
|
172 |
+
run_id=run_id.value,
|
173 |
+
tool_outputs=[tool_output],
|
174 |
+
)
|
175 |
+
if run.status == "completed":
|
176 |
+
messages.set(
|
177 |
+
[
|
178 |
+
*messages.value,
|
179 |
+
openai.beta.threads.messages.list(thread.id).data[0],
|
180 |
+
]
|
181 |
+
)
|
182 |
+
run_id.set(None)
|
183 |
+
completed = True
|
184 |
+
time.sleep(0.1)
|
185 |
+
retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id)
|
186 |
+
messages.set(retrieved_messages.data)
|
187 |
+
|
188 |
+
result = solara.use_thread(poll, dependencies=[run_id.value])
|
189 |
|
190 |
def handle_message(message):
|
191 |
print("handle", message)
|
192 |
messages = []
|
193 |
+
if message.role == "assistant":
|
194 |
tools_calls = message.get("tool_calls", [])
|
195 |
for tool_call in tools_calls:
|
196 |
messages.append(ai_call(tool_call))
|
|
|
202 |
handle_message(message)
|
203 |
|
204 |
solara.use_effect(handle_initial, [])
|
205 |
+
# result = solara.use_thread(ask, dependencies=[messages.value])
|
206 |
with solara.Column(
|
207 |
+
style={
|
208 |
+
"height": "100%",
|
209 |
+
"width": "38vw",
|
210 |
+
"justify-content": "center",
|
211 |
+
"background": "linear-gradient(0deg, transparent 75%, white 100%);",
|
212 |
+
},
|
213 |
classes=["chat-interface"],
|
214 |
):
|
215 |
if len(messages.value) > 0:
|
216 |
+
# The height works effectively as `min-height`, since flex will grow the container to fill the available space
|
217 |
+
with solara.Column(
|
218 |
+
style={
|
219 |
+
"flex-grow": "1",
|
220 |
+
"overflow-y": "auto",
|
221 |
+
"height": "100px",
|
222 |
+
"flex-direction": "column-reverse",
|
223 |
+
}
|
224 |
+
):
|
225 |
+
for message in reversed(messages.value):
|
226 |
+
with solara.Row(style={"align-items": "flex-start"}):
|
227 |
+
if message.role == "user":
|
228 |
+
solara.Text(
|
229 |
+
message.content[0].text.value,
|
230 |
+
classes=["chat-message", "user-message"],
|
231 |
+
)
|
232 |
+
assert len(message.content) == 1
|
233 |
+
elif message.role == "assistant":
|
234 |
+
if message.content[0].text.value:
|
235 |
+
solara.v.Icon(
|
236 |
+
children=["mdi-compass-outline"],
|
237 |
+
style_="padding-top: 10px;",
|
238 |
+
)
|
239 |
+
solara.Markdown(message.content[0].text.value)
|
240 |
+
elif message.content.tool_calls:
|
241 |
+
solara.v.Icon(
|
242 |
+
children=["mdi-map"],
|
243 |
+
style_="padding-top: 10px;",
|
244 |
+
)
|
245 |
+
solara.Markdown("*Calling map functions*")
|
246 |
+
else:
|
247 |
+
solara.v.Icon(
|
248 |
+
children=["mdi-compass-outline"],
|
249 |
+
style_="padding-top: 10px;",
|
250 |
+
)
|
251 |
+
solara.Preformatted(
|
252 |
+
repr(message),
|
253 |
+
classes=["chat-message", "assistant-message"],
|
254 |
+
)
|
255 |
+
elif message["role"] == "tool":
|
256 |
+
pass # no need to display
|
257 |
else:
|
258 |
+
solara.v.Icon(
|
259 |
+
children=["mdi-compass-outline"],
|
260 |
+
style_="padding-top: 10px;",
|
261 |
+
)
|
262 |
solara.Preformatted(
|
263 |
repr(message),
|
264 |
classes=["chat-message", "assistant-message"],
|
265 |
)
|
266 |
+
# solara.Text(message, classes=["chat-message"])
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
with solara.Column():
|
268 |
solara.InputText(
|
269 |
+
label="Ask your question here",
|
270 |
value=prompt,
|
271 |
style={"flex-grow": "1"},
|
272 |
on_value=add_message,
|
|
|
296 |
messages.set(json.load(f))
|
297 |
reset_ui()
|
298 |
|
299 |
+
with solara.Column(
|
300 |
+
style={
|
301 |
+
"height": "95vh",
|
302 |
+
"justify-content": "center",
|
303 |
+
"padding": "45px 50px 75px 50px",
|
304 |
+
},
|
305 |
+
gap="5vh",
|
306 |
+
):
|
307 |
+
with solara.Row(justify="space-between"):
|
308 |
+
with solara.Row(gap="10px", style={"align-items": "center"}):
|
309 |
+
solara.v.Icon(children=["mdi-compass-rose"], size="36px")
|
310 |
+
solara.HTML(
|
311 |
+
tag="h2",
|
312 |
+
unsafe_innerHTML="Wanderlust",
|
313 |
+
style={"display": "inline-block"},
|
314 |
+
)
|
315 |
+
# with solara.Row(gap="10px"):
|
316 |
+
# solara.Button("Save", on_click=save)
|
317 |
+
# solara.Button("Load", on_click=load)
|
318 |
+
# solara.Button("Soft reset", on_click=reset_ui)
|
319 |
+
with solara.Row(justify="space-between", style={"flex-grow": "1"}):
|
320 |
ChatInterface().key(f"chat-{reset_counter}")
|
321 |
+
with solara.Column(style={"width": "50vw", "justify-content": "center"}):
|
322 |
Map() # .key(f"map-{reset_counter}")
|
323 |
|
324 |
solara.Style(
|
325 |
"""
|
326 |
.jupyter-widgets.leaflet-widgets{
|
327 |
height: 100%;
|
328 |
+
border-radius: 20px;
|
329 |
}
|
330 |
.solara-autorouter-content{
|
331 |
display: flex;
|
332 |
flex-direction: column;
|
333 |
justify-content: stretch;
|
334 |
}
|
335 |
+
.v-toolbar__title{
|
336 |
+
display: flex;
|
337 |
+
align-items: center;
|
338 |
+
column-gap: 0.5rem;
|
339 |
+
}
|
340 |
"""
|
341 |
)
|
342 |
|