File size: 2,553 Bytes
4930b0c
 
 
 
 
 
 
 
 
 
 
3b2cf81
 
 
 
4930b0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b2cf81
4930b0c
 
 
 
 
 
 
 
3b2cf81
 
 
 
4930b0c
 
 
 
 
 
 
3b2cf81
 
 
 
 
 
 
4930b0c
e110d82
3b2cf81
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import html

import streamlit as st
import streamlit.components.v1 as components

from scienceworld import ScienceWorldEnv

description = """
[Project Page](https://sciworld.apps.allenai.org) | [ArXiv Paper](https://arxiv.org/abs/2203.07540) | [Github Repo](https://github.com/allenai/ScienceWorld)
"""
st.title("ScienceWorld Demo")
st.markdown(description)

# Apply custom CSS (read with an explicit encoding so it does not depend on
# the platform's default locale).
with open('style.css', encoding='utf-8') as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Streamlit re-executes this whole script on every interaction, so state that
# must survive between reruns lives in st.session_state.
env = st.session_state.get("env")
if env is None:
    # Build the environment only the first time in this session; later reruns
    # reuse the stored instance.
    env = ScienceWorldEnv("")
    st.session_state["env"] = env

# Latest step result, written by the initial task load and by the step() callback.
obs = st.session_state.get("obs")
infos = st.session_state.get("infos")

# Transcript of (action, observation) pairs. Callbacks mutate this list in
# place, so the same list object is kept in session state.
history = st.session_state.get("history")
if history is None:
    history = []
    st.session_state["history"] = history

def clear_history():
    """Empty the shared ``history`` transcript in place.

    The list object itself is kept (only its contents are removed), so every
    reference to it, including the one stored in session state, sees the change.
    """
    del history[:]


with st.sidebar:
    st.title("ScienceWorld Demo")
    st.markdown(description)
    # Switching tasks empties the transcript, which triggers a reload below.
    task = st.selectbox("Task:", env.getTaskNames(), on_change=clear_history)

# An empty transcript means a fresh episode: load the selected task (variation 0),
# reset it, remember the result, and seed the transcript with the task
# description followed by the initial observation.
if not history:
    env.load(task, 0, "")
    obs, infos = env.reset()
    st.session_state["obs"], st.session_state["infos"] = obs, infos
    history.extend([
        ("", env.getTaskDescription()),
        ("look around", obs),
    ])

def step():
    """Apply the action picked in the "action" selectbox (its on_change callback).

    A regular action is sent to ``env``. The resulting ``(action, observation)``
    pair is appended to ``history``, and the latest observation/info are stored
    in session state. The "reset" option is not sent to the environment; it only
    clears ``history``, so the next rerun reloads the task and overwrites the
    stored observation/info anyway.
    """
    act = st.session_state.action
    if not act:
        return

    if act == "reset":
        clear_history()
    else:
        obs, _reward, _done, infos = env.step(act)
        history.append((act, obs))
        st.session_state["obs"] = obs
        st.session_state["infos"] = infos

    # Put the selectbox back on the blank option. Otherwise the chosen action
    # stays selected, and picking it again would not fire on_change.
    st.session_state.action = ""


solved = infos['score'] == 100

with st.sidebar:
    st.warning(env.getTaskDescription())
    st.success(f"Score:  {infos['score']}")

    # Once the task is solved, the only choice offered is to start over.
    valid_actions = [""] + sorted(infos["valid"])
    if solved:
        valid_actions = ["", "reset"]

# Render the transcript: each entry is (action issued, observation returned).
for act, obs in history:
    if act:
        st.write("> " + act)

    if obs:
        st.info(obs.replace('\n\t', '\n- '))

act = st.selectbox('Action:', options=valid_actions, index=0, on_change=step, key="action")

st.warning(f"Current score:  {infos['score']} out of 100")

if solved:
    with st.sidebar:
        st.balloons()

    st.success("Congratulations! You have completed the task.")


# Auto scroll at the bottom of the page.
# The latest observation is embedded so the component's HTML differs between
# steps, presumably so Streamlit rebuilds the iframe and the scroll script runs
# again. It is HTML-escaped because observations are plain text, not markup.
# The scroll target is looked up once and guarded in case it is absent.
components.html(
    f"""
    <p>{html.escape(str(st.session_state.obs))}</p>
    <script>
        const main = window.parent.document.querySelector('section.main');
        if (main) {{
            main.scrollTo(0, main.scrollHeight);
        }}
    </script>
    """,
    height=0,
)