Spaces:
Running
Running
import re | |
import os | |
import panel as pn | |
from io import StringIO | |
from panel.io.mime_render import exec_with_return | |
from llama_index import ( | |
VectorStoreIndex, | |
SimpleDirectoryReader, | |
ServiceContext, | |
StorageContext, | |
load_index_from_storage, | |
) | |
from llama_index.chat_engine import ContextChatEngine | |
from llama_index.embeddings import OpenAIEmbedding | |
from llama_index.llms import OpenAI | |
# System prompt for the chat engine: the model must answer with a ```python fence
# whose last line is `hvplot_obj` (callback relies on that to pick up the code).
SYSTEM_PROMPT = "".join(
    (
        "You are a data visualization pro and expert in HoloViz hvplot + holoviews. ",
        "Your primary goal is to assist the user in editing based on user requests using best practices. ",
        "Simply provide code in code fences (```python). You must have `hvplot_obj` as the last line of code. ",
        "Note, data columns are ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'] and ",
        "hvplot is built on top of holoviews--anything you can do with holoviews, you can do ",
        "with hvplot. First try to use hvplot **kwargs instead of opts, e.g. `legend='top_right'` ",
        "instead of `opts(legend_position='top_right')`. If you need to use opts, you can use ",
        "concise version, e.g. `opts(xlabel='Petal Length')` vs `opts(hv.Opts(xlabel='Petal Length'))`",
    )
)
# Template for each message sent to the chat engine: the user's request followed by
# the current editor code in a ```python fence. Filled via str.format(content=..., code=...).
USER_CONTENT_FORMAT = "\n".join(
    (
        "Request:",
        "{content}",
        "Code:",
        "```python",
        "{code}",
        "```",
    )
)
# Starting code shown in the editor and rendered at startup; its final line is the
# bare `hvplot_obj` expression whose value becomes the displayed plot.
DEFAULT_HVPLOT = "\n".join(
    (
        "import hvplot.pandas",
        "from bokeh.sampledata.iris import flowers",
        "hvplot_obj = flowers.hvplot(x='petal_length', y='petal_width', by='species', kind='scatter')",
        "hvplot_obj",
    )
)
def exception_handler(exc):
    """Ask the LLM to repair a failing plot, giving up once the retry budget is spent.

    Registered as Panel's global ``exception_handler`` and also called by
    ``update_plot`` with captured stderr text.

    Args:
        exc: The exception (or error text) to report back to the model.
    """
    if retries.value <= 0:
        # Budget exhausted: post the error without triggering another LLM response.
        chat_interface.send(f"Can't figure this out: {exc}", respond=False)
        return
    # Spend the retry *before* sending: ``send`` triggers the chat callback, and if a
    # resulting failing plot re-enters this handler, it must already see the reduced
    # budget; otherwise the retry loop is never bounded.
    retries.value -= 1
    chat_interface.send(f"Fix this error:\n```python\n{exc}\n```")
def init_llm(event):
    """(Re)create the cached OpenAI LLM when the API-key input changes.

    Falls back to the ``OPENAI_API_KEY`` environment variable when the input is
    empty; if neither provides a key, the cache is left untouched.

    Args:
        event: param change event; ``event.new`` is the key typed by the user.
    """
    api_key = event.new or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return
    pn.state.cache["llm"] = OpenAI(api_key=api_key)
    # The chat engine is cached separately and was built from the previous LLM;
    # drop it so the next chat request rebuilds it with the new key.
    pn.state.cache.pop("engine", None)
def create_chat_engine(llm):
    """Build a retrieval-augmented chat engine over the local hvPlot docs.

    Loads a persisted vector index from ``persisted/``. If that fails, it indexes
    every ``.md`` file under ``hvplot_docs/`` (recursively) and persists the new
    index there for next time.

    Args:
        llm: llama_index LLM that should answer the chat requests.

    Returns:
        A ``ContextChatEngine`` configured with ``SYSTEM_PROMPT`` and a retriever
        over the index.
    """
    # One service context for loading, building and chatting, so the engine answers
    # with ``llm`` instead of llama_index's default LLM.
    # NOTE(review): OpenAIEmbedding() still resolves its key from the environment;
    # confirm whether it should use the user's key as well.
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=OpenAIEmbedding())
    try:
        storage_context = StorageContext.from_defaults(persist_dir="persisted/")
        index = load_index_from_storage(
            storage_context=storage_context, service_context=service_context
        )
    except Exception:  # e.g. nothing persisted yet: rebuild the index from the docs
        documents = SimpleDirectoryReader(
            input_dir="hvplot_docs", required_exts=[".md"], recursive=True
        ).load_data()
        index = VectorStoreIndex.from_documents(
            documents, service_context=service_context, show_progress=True
        )
        index.storage_context.persist("persisted/")
    return ContextChatEngine.from_defaults(
        retriever=index.as_retriever(),
        service_context=service_context,
        system_prompt=SYSTEM_PROMPT,
        verbose=True,
    )
def callback(content: str, user: str, instance: pn.chat.ChatInterface): | |
if "llm" not in pn.state.cache: | |
yield "Need to set OpenAI API key first" | |
return | |
if "engine" not in pn.state.cache: | |
engine = pn.state.cache["engine"] = create_chat_engine(pn.state.cache["llm"]) | |
else: | |
engine = pn.state.cache["engine"] | |
# new user contents | |
user_content = USER_CONTENT_FORMAT.format( | |
content=content, code=code_editor.value | |
) | |
# send user content to chat engine | |
agent_response = engine.stream_chat(user_content) | |
message = None | |
for chunk in agent_response.response_gen: | |
message = instance.stream(chunk, message=message, user="OpenAI") | |
# extract code | |
llm_matches = re.findall(r"```python\n(.*)\n```", message.object, re.DOTALL) | |
if llm_matches: | |
llm_code = llm_matches[0] | |
if llm_code.splitlines()[-1].strip() != "hvplot_obj": | |
llm_code += "\nhvplot_obj" | |
code_editor.value = llm_code | |
retries.value = 2 | |
def update_plot(event):
    """Execute the edited code and show its result in the plot pane.

    Anything the code writes to stderr is passed to ``exception_handler``, which
    may ask the LLM for a fix.

    Args:
        event: param change event from the code editor; ``event.new`` is the code.
    """
    with StringIO() as stderr_capture:
        hvplot_pane.object = exec_with_return(event.new, stderr=stderr_capture)
        captured = stderr_capture.getvalue()
    if captured:
        exception_handler(captured)
pn.extension(
    "codeeditor",
    sizing_mode="stretch_width",
    exception_handler=exception_handler,
)

# -- Widgets and panes --------------------------------------------------------

# Optional user-supplied key; init_llm falls back to OPENAI_API_KEY when it is empty.
api_key_input = pn.widgets.PasswordInput(
    placeholder="".join(
        ("Currently subsidized by Andrew, ", "but you can also pass your own OpenAI API Key")
    ),
)

chat_interface = pn.chat.ChatInterface(
    callback=callback,
    callback_exception="verbose",
    height=650,
    show_button_name=False,
    show_clear=False,
    show_undo=False,
    message_params={"show_reaction_icons": False, "show_copy_icon": False},
)

# Rendered plot; starts out showing the result of executing DEFAULT_HVPLOT.
hvplot_pane = pn.pane.HoloViews(exec_with_return(DEFAULT_HVPLOT), sizing_mode="stretch_both")

# Editable plot source, seeded with DEFAULT_HVPLOT.
code_editor = pn.widgets.CodeEditor(
    value=DEFAULT_HVPLOT, language="python", sizing_mode="stretch_both"
)

# Hidden counter of automatic "Fix this error" round-trips still allowed.
retries = pn.widgets.IntInput(value=2, visible=False)
# NOTE(review): not referenced anywhere else in this file; possibly dead.
error = pn.widgets.StaticText(visible=False)
# -- Wiring -------------------------------------------------------------------
# Re-create the LLM when the key changes; re-render the plot when the code changes.
api_key_input.param.watch(init_llm, "value")
code_editor.param.watch(update_plot, "value")
# Fire once at startup with the (initially empty) key so init_llm can fall back to
# the OPENAI_API_KEY environment variable without any user input.
api_key_input.param.trigger("value")

# -- Layout -------------------------------------------------------------------
tabs = pn.Tabs(("Plot", hvplot_pane), ("Code", code_editor))
sidebar, main = [api_key_input, chat_interface], [tabs]
template = pn.template.FastListTemplate(
    title="Chat with Plot",
    sidebar=sidebar,
    sidebar_width=600,
    main=main,
    main_layout=None,
    accent_base_color="#fd7000",
    header_background="#fd7000",
)
template.servable()