# Spaces: Runtime error
# Runtime error
import os

from dotenv import load_dotenv

# Load variables from a local .env file FIRST. load_dotenv() does not override
# variables that are already set (override=False by default), so it must run
# before the placeholder key below is installed -- otherwise an OPENAI_API_KEY
# defined in .env would be silently ignored in favour of 'none'.
load_dotenv()

# Init with fake key when no OpenAI key is configured anywhere.
# NOTE(review): presumably this keeps the imports below from failing at import
# time when no key exists yet -- confirm against chemcrow/langchain.
os.environ.setdefault('OPENAI_API_KEY', 'none')

import pandas as pd
import streamlit as st
from IPython.core.display import HTML
from PIL import Image
from langchain.callbacks import wandb_tracing_enabled

from chemcrow.agents import ChemCrow, make_tools
from chemcrow.frontend.streamlit_callback_handler import \
    StreamlitCallbackHandlerChem
from utils import oai_key_isvalid
# Short alias for Streamlit's session state, used throughout this script.
ss = st.session_state
ss.prompt = None

icon = Image.open('assets/logo0.png')
st.set_page_config(
    page_title="ChemCrow",
    page_icon=icon,
)

# Set width of sidebar (applies only while the sidebar is expanded).
# The <style> element is explicitly closed so the raw-HTML block ends here
# instead of running to the end of the markdown body.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
    min-width: 450px;
    max-width: 450px;
}
</style>
""",
    unsafe_allow_html=True,
)
# Build the ChemCrow agent executor.
# The OpenAI key is resolved exactly as on_api_key_change validates it: the
# sidebar input when non-empty, otherwise the OPENAI_API_KEY environment
# variable. This keeps the key that is validated and the key that is used
# in agreement.
agent = ChemCrow(
    model='gpt-4',
    temp=0.1,
    openai_api_key=ss.get('api_key') or os.getenv('OPENAI_API_KEY'),
    api_keys={
        # Raises if the secret is not configured for this deployment.
        'rxn4chem': st.secrets['RXN4CHEM_API_KEY'],
    },
).agent_executor
tools = agent.tools

# Two-column table (tool name, description) displayed in the sidebar.
tool_list = pd.Series(
    {f"✅ {t.name}": t.description for t in tools}
).reset_index()
tool_list.columns = ['Tool', 'Description']
def on_api_key_change():
    """Callback for the API-key text input: warn when the key is invalid.

    The candidate key is the one typed into the sidebar or, if that is empty,
    the OPENAI_API_KEY environment variable. Nothing is written when
    oai_key_isvalid accepts it.
    """
    candidate = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
    key_ok = oai_key_isvalid(candidate)
    if key_ok:
        return
    st.write("Please input a valid OpenAI API key.")
# Example questions offered as one-click buttons in the sidebar
# (two per column; see the sidebar section below).
pre_prompts = [
    'What is the molecular weight of sugar',
    'Can I safely mix caffeine and sodium hydroxide?',
    'How is safinamide synthesized?',
    'How similar is morphine to heroin?',
]
def run_prompt(prompt):
    """Echo *prompt* as a user chat message and answer it with the agent.

    Inside an assistant chat message, the agent runs with a
    StreamlitCallbackHandlerChem attached (rendering into a fresh container)
    and within langchain's wandb_tracing_enabled() context. The final
    response is then written below.
    """
    st.chat_message("user").write(prompt)

    assistant_box = st.chat_message("assistant")
    with assistant_box:
        handler = StreamlitCallbackHandlerChem(
            st.container(),
            max_thought_containers=4,
            collapse_completed_thoughts=False,
            output_placeholder=ss,
        )
        with wandb_tracing_enabled():
            answer = agent.run(prompt, callbacks=[handler])
        st.write(answer)
# sidebar
with st.sidebar:
    chemcrow_logo = Image.open('assets/chemcrow-logo-bold-new.png')
    st.image(chemcrow_logo)

    # OpenAI API key entry; on_api_key_change runs whenever the value changes.
    st.markdown('Input your OpenAI API key.')
    st.text_input(
        'OpenAI API key',
        type='password',
        key='api_key',
        on_change=on_api_key_change,
        label_visibility="collapsed",
    )

    # Example prompts: first two in the left column, last two in the right.
    # Each button calls run_prompt with its own label as the prompt.
    st.markdown('# What can I ask?')
    cols = st.columns(2)
    for column, column_prompts in zip(cols, (pre_prompts[:2], pre_prompts[2:])):
        with column:
            for example in column_prompts:
                st.button(example, on_click=run_prompt, args=(example,))

    # Table of the agent's tools.
    st.markdown(f"# {len(tool_list)} available tools")
    st.dataframe(
        tool_list,
        use_container_width=True,
        hide_index=True,
        height=200,
    )
prompt = None

# Free-form questions typed into the chat box are sent straight to the agent.
user_input = st.chat_input()
if user_input:
    run_prompt(user_input)