Spaces:
Running
Running
Adding textual observations to demo
Browse files- web_demo/app.py +25 -5
web_demo/app.py
CHANGED
@@ -9,6 +9,8 @@ import gym_minigrid
|
|
9 |
import numpy as np
|
10 |
from gym_minigrid.window import Window
|
11 |
|
|
|
|
|
12 |
import os
|
13 |
|
14 |
app = Flask(__name__)
|
@@ -46,11 +48,27 @@ global env_label
|
|
46 |
env_label = list(env_label_to_env_name.keys())[0]
|
47 |
env_name = env_label_to_env_name[env_label]
|
48 |
|
|
|
|
|
|
|
49 |
global mask_unobserved
|
50 |
mask_unobserved = False
|
51 |
|
52 |
env = gym.make(env_name)
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def update_tree():
|
55 |
selected_parameters = env.current_env.parameters
|
56 |
selected_env_type = selected_parameters["Env_type"]
|
@@ -116,10 +134,9 @@ def set_mask_unobserved():
|
|
116 |
def update_image():
|
117 |
action_name = request.form.get('action')
|
118 |
|
119 |
-
|
120 |
if action_name == 'done':
|
121 |
# reset the env and update the tree image
|
122 |
-
obs = env.reset()
|
123 |
update_tree()
|
124 |
|
125 |
else:
|
@@ -145,21 +162,24 @@ def update_image():
|
|
145 |
|
146 |
obs, reward, done, info = env.step(action)
|
147 |
|
|
|
148 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
149 |
image_data = np_img_to_base64(image)
|
150 |
|
151 |
-
|
152 |
-
bubble_text = format_bubble_text(env.current_env.full_conversation)
|
153 |
|
154 |
return jsonify({'image_data': image_data, "bubble_text": bubble_text})
|
155 |
|
156 |
|
|
|
157 |
@app.route('/', methods=['GET', 'POST'])
|
158 |
def index():
|
|
|
159 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
160 |
image_data = np_img_to_base64(image)
|
161 |
|
162 |
-
bubble_text = format_bubble_text(env.current_env.full_conversation)
|
|
|
163 |
|
164 |
available_env_labels = env_label_to_env_name.keys()
|
165 |
|
|
|
9 |
import numpy as np
|
10 |
from gym_minigrid.window import Window
|
11 |
|
12 |
+
from textworld_utils.utils import generate_text_obs
|
13 |
+
|
14 |
import os
|
15 |
|
16 |
app = Flask(__name__)
|
|
|
48 |
env_label = list(env_label_to_env_name.keys())[0]
|
49 |
env_name = env_label_to_env_name[env_label]
|
50 |
|
51 |
+
|
52 |
+
textworld_envs = ["SocialAI-AsocialBoxInformationSeekingParamEnv-v1", "SocialAI-ColorBoxesLLMCSParamEnv-v1"]
|
53 |
+
|
54 |
global mask_unobserved
|
55 |
mask_unobserved = False
|
56 |
|
57 |
env = gym.make(env_name)
|
58 |
|
59 |
+
|
60 |
+
def create_bubble_text(env_name, obs, info, full_conversation, textworld_envs):
|
61 |
+
if env_name in textworld_envs:
|
62 |
+
text_obs = generate_text_obs(obs, info)
|
63 |
+
# bubble_text = "Textworld state:\n" + text_obs
|
64 |
+
bubble_text = text_obs
|
65 |
+
|
66 |
+
else:
|
67 |
+
bubble_text = format_bubble_text(full_conversation)
|
68 |
+
|
69 |
+
return bubble_text
|
70 |
+
|
71 |
+
|
72 |
def update_tree():
|
73 |
selected_parameters = env.current_env.parameters
|
74 |
selected_env_type = selected_parameters["Env_type"]
|
|
|
134 |
def update_image():
|
135 |
action_name = request.form.get('action')
|
136 |
|
|
|
137 |
if action_name == 'done':
|
138 |
# reset the env and update the tree image
|
139 |
+
obs, info = env.reset(with_info=True)
|
140 |
update_tree()
|
141 |
|
142 |
else:
|
|
|
162 |
|
163 |
obs, reward, done, info = env.step(action)
|
164 |
|
165 |
+
|
166 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
167 |
image_data = np_img_to_base64(image)
|
168 |
|
169 |
+
bubble_text = create_bubble_text(env_name, obs, info, env.current_env.full_conversation, textworld_envs)
|
|
|
170 |
|
171 |
return jsonify({'image_data': image_data, "bubble_text": bubble_text})
|
172 |
|
173 |
|
174 |
+
|
175 |
@app.route('/', methods=['GET', 'POST'])
|
176 |
def index():
|
177 |
+
obs, info = env.reset(with_info=True)
|
178 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
179 |
image_data = np_img_to_base64(image)
|
180 |
|
181 |
+
# bubble_text = format_bubble_text(env.current_env.full_conversation)
|
182 |
+
bubble_text = create_bubble_text(env_name, obs, info, env.current_env.full_conversation, textworld_envs)
|
183 |
|
184 |
available_env_labels = env_label_to_env_name.keys()
|
185 |
|