File size: 3,321 Bytes
97cdf09
c5a1bbd
e4ff6d1
b0a8223
f7e31d3
b0a8223
97cdf09
 
 
 
a9beafd
04eb2a2
97cdf09
 
167e68e
97cdf09
 
 
 
 
453acb7
97cdf09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04eb2a2
 
 
f3e35a0
 
 
 
04eb2a2
 
 
 
 
 
 
 
97cdf09
167e68e
 
 
03099f1
167e68e
 
453acb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97cdf09
 
 
 
 
 
 
167e68e
 
 
 
 
 
 
97cdf09
453acb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97cdf09
453acb7
04eb2a2
 
 
 
453acb7
04eb2a2
be84066
 
97cdf09
453acb7
97cdf09
453acb7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os

# Load variables from a local .env file *before* installing the placeholder
# key. load_dotenv() does not override variables that already exist, so if
# the placeholder were set first, an OPENAI_API_KEY defined in .env would
# never be picked up.
from dotenv import load_dotenv

load_dotenv()

# Init with fake key: a placeholder so the imports below can initialise when
# no real key is configured (presumably they read OPENAI_API_KEY at import
# time — TODO confirm). A user-supplied key is entered in the sidebar later.
os.environ.setdefault('OPENAI_API_KEY', 'none')

import pandas as pd
import streamlit as st
from IPython.core.display import HTML
from PIL import Image
from langchain.callbacks import wandb_tracing_enabled
from chemcrow.agents import ChemCrow, make_tools
from chemcrow.frontend.streamlit_callback_handler import \
    StreamlitCallbackHandlerChem
from utils import oai_key_isvalid

from dotenv import load_dotenv

load_dotenv()

# Short alias for Streamlit's session state; the ``prompt`` entry is cleared
# every time this script executes.
ss = st.session_state
ss['prompt'] = None

# Browser-tab title and favicon for the app.
icon = Image.open('assets/logo0.png')
st.set_page_config(page_title="ChemCrow", page_icon=icon)

# Fix the width of the sidebar at 450px while it is expanded. Raw HTML/CSS is
# injected via markdown, so the <style> element must be properly closed.
st.markdown(
    """
    <style>
    [data-testid="stSidebar"][aria-expanded="true"]{
        min-width: 450px;
        max-width: 450px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

# Build the ChemCrow agent on gpt-4 (temperature 0.1). The OpenAI key is read
# from session state (None if the 'api_key' entry does not exist yet); the
# RXN4Chem key is read from Streamlit secrets.
agent = ChemCrow(
    model='gpt-4',
    temp=0.1,
    openai_api_key=ss.get('api_key'),
    api_keys={'rxn4chem': st.secrets['RXN4CHEM_API_KEY']},
).agent_executor

tools = agent.tools

# Two-column table of the agent's tools for the sidebar. Going through a dict
# first means a repeated tool name yields a single row (last description wins).
_tool_descriptions = {f"✅ {t.name}": t.description for t in tools}
tool_list = pd.DataFrame(
    list(_tool_descriptions.items()),
    columns=['Tool', 'Description'],
)

def on_api_key_change():
    """Tell the user when the OpenAI key being used fails validation.

    Registered as the ``on_change`` callback of the sidebar key input. When
    that input is empty, the ``OPENAI_API_KEY`` environment variable is
    checked instead.
    """
    candidate = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
    if oai_key_isvalid(candidate):
        return
    st.write("Please input a valid OpenAI API key.")

# Example questions offered as one-click buttons in the sidebar.
pre_prompts = list((
    'What is the molecular weight of sugar',
    'Can I safely mix caffeine and sodium hydroxide?',
    'How is safinamide synthesized?',
    'How similar is morphine to heroin?',
))


def run_prompt(prompt):
    """Post *prompt* as a user chat message and write the agent's reply.

    Inside an assistant chat message, a StreamlitCallbackHandlerChem bound to
    a fresh container receives the agent's callbacks. The call to
    ``agent.run`` and the final ``st.write`` happen inside
    ``wandb_tracing_enabled()``.
    """
    st.chat_message("user").write(prompt)

    assistant_box = st.chat_message("assistant")
    with assistant_box:
        thoughts = st.container()
        handler = StreamlitCallbackHandlerChem(
            thoughts,
            max_thought_containers=4,
            collapse_completed_thoughts=False,
            output_placeholder=ss,
        )
        with wandb_tracing_enabled():
            answer = agent.run(prompt, callbacks=[handler])
            st.write(answer)


# sidebar
with st.sidebar:
    chemcrow_logo = Image.open('assets/chemcrow-logo-bold-new.png')
    st.image(chemcrow_logo)

    # Input OpenAI api key
    st.markdown('Input your OpenAI API key.')
    st.text_input(
        'OpenAI API key',
        type='password',
        key='api_key',
        on_change=on_api_key_change,
        label_visibility="collapsed",
    )

    # Display prompt examples as buttons spread over two columns, filling the
    # first column before the second (four prompts -> 2 + 2, as before).
    # Passing the prompt through ``args`` binds each button to its own text,
    # and every entry in ``pre_prompts`` gets a button automatically.
    st.markdown('# What can I ask?')
    cols = st.columns(2)
    for i, example in enumerate(pre_prompts):
        # Integer scaling maps prompt index -> column index in [0, len(cols)).
        with cols[i * len(cols) // len(pre_prompts)]:
            st.button(example, on_click=run_prompt, args=(example,))

    # Display available tools
    st.markdown(f"# {len(tool_list)} available tools")
    st.dataframe(
        tool_list,
        use_container_width=True,
        hide_index=True,
        height=200,
    )



prompt = None

# Chat box at the bottom of the page; a submitted message is sent to the agent.
user_input = st.chat_input()
if user_input:
    run_prompt(user_input)