# Author: grg
# Last commit message: "typo fix"
# Revision: f9aa008
from flask import Flask, render_template, request, session, redirect, url_for, send_from_directory, jsonify
from PIL import Image
import io
import base64
import time
import gym
import gym_minigrid
import numpy as np
from gym_minigrid.window import Window
from gym_minigrid.curriculums import SelectedParametersOrRandomCurriculum
from textworld_utils.utils import generate_text_obs
import os
# Flask application for the SocialAI web demo.
app = Flask(__name__)
# Allowed values of the "Env_type" parameter; update_tree() checks the selected
# value against this list and folds every other branch when drawing the tree.
env_types = ["Information_seeking", "Collaboration", "AppleStealing"]
# Maps each label shown in the environment drop-down to the gym id that
# gym.make() is called with when that label is selected.
env_label_to_env_name = {
    "Full SocialAI environment": "SocialAI-SocialAIParamEnv-v1",  # full environment (all)
    "Pointing (Train)": "SocialAI-EPointingHeldoutDoorsTrainInformationSeekingParamEnv-v1",  # Pointing, train variant
    "Pointing (Test)": "SocialAI-EPointingDoorsTestInformationSeekingParamEnv-v1",  # Pointing, test variant
    "Role Reversal Single Role B (Pretrain - experimental)": "SocialAI-MarblePassBCollaborationParamEnv-v1",
    "Role Reversal Single Asocial (Pretrain - control)": "SocialAI-AsocialMarbleCollaborationParamEnv-v1",
    "Role Reversal Group Role B (Pretrain - experimental)": "SocialAI-RoleReversalGroupExperimentalCollaborationParamEnv-v1",
    "Role Reversal Group Asocial (Pretrain - control)": "SocialAI-RoleReversalGroupControlCollaborationParamEnv-v1",
    "Role Reversal Role A (Finetune - test)": "SocialAI-MarblePassACollaborationParamEnv-v1",
    "Imitation (Train)": "SocialAI-EEmulationNoDistrInformationSeekingParamEnv-v1",
    "Imitation (Test)": "SocialAI-EEmulationNoDistrDoorsInformationSeekingParamEnv-v1",
    "AsocialBox (textworld)": "SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
    "ColorBoxes (textworld)": "SocialAI-ColorBoxesLLMCSParamEnv-v1",
    "Language Color (Train)": "SocialAI-ELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
    "Language Feedback (Train)": "SocialAI-ELangFeedbackHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Language Feedback (Test)": "SocialAI-ELangFeedbackDoorsTestInformationSeekingParamEnv-v1",
    "Joint Attention Language Color (Train)": "SocialAI-JAELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Joint Attention Language Color (Test)": "SocialAI-JAELangColorDoorsTestInformationSeekingParamEnv-v1",
    "Apple stealing": "SocialAI-AppleStealingObst_NoParamEnv-v1",
    "Apple stealing (Occlusions)": "SocialAI-AppleStealingObst_MediumParamEnv-v1",
    "Scaffolding (train - scaf_8: Phase 1)": "SocialAI-AELangFeedbackTrainScaffoldingCSParamEnv-v1",
    "Scaffolding/Formats (test)":"SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1",
}
# Labels shown in the environment drop-down, in display order. Entries of the
# form "---- ... ----" are section headers: they have no entry in
# env_label_to_env_name and do not correspond to an environment.
available_env_labels = [
    "Full SocialAI environment",
    "---- Pointing ----",
    "Pointing (Train)",
    "Pointing (Test)",
    "---- Role Reversal ----",
    "Role Reversal Single Role B (Pretrain - experimental)",
    "Role Reversal Single Asocial (Pretrain - control)",
    "Role Reversal Group Role B (Pretrain - experimental)",
    "Role Reversal Group Asocial (Pretrain - control)",
    "Role Reversal Role A (Finetune - test)",
    "---- Imitation ----",
    "Imitation (Train)",
    "Imitation (Test)",
    "---- TextWorld (LLM experiments) ----",
    "AsocialBox (textworld)",
    "ColorBoxes (textworld)",
    "---- Language Color ----",
    "Language Color (Train)",
    "Language Color (Test)",
    "---- Language Feedback ----",
    "Language Feedback (Train)",
    "Language Feedback (Test)",
    "---- Joint Attention Language Color ----",
    "Joint Attention Language Color (Train)",
    "Joint Attention Language Color (Test)",
    "---- Apple Stealing ----",
    "Apple stealing",
    "Apple stealing (Occlusions)",
    "---- Scaffolding/Formats ----",
    "Scaffolding (train - scaf_8: Phase 1)",
    "Scaffolding/Formats (test)",
]
# Every environment in the mapping must be reachable from the drop-down.
# An explicit raise (instead of `assert`) keeps this check active under
# `python -O` and names the offending labels.
_missing_env_labels = [
    label for label in env_label_to_env_name if label not in available_env_labels
]
if _missing_env_labels:
    raise ValueError(
        f"Environment labels missing from available_env_labels: {_missing_env_labels}"
    )
# Module-level demo state. Route handlers below rebind these names through
# `global` declarations; at module scope no `global` statement is needed.
env_label = next(iter(env_label_to_env_name))  # first label of the mapping
env_name = env_label_to_env_name[env_label]
# Gym ids of the "(textworld)" environments listed above.
# NOTE(review): not referenced by the routes in this file; confirm before removing.
textworld_envs = ["SocialAI-AsocialBoxInformationSeekingParamEnv-v1", "SocialAI-ColorBoxesLLMCSParamEnv-v1"]
# Passed as `mask_unobserved=` to env.render(); presumably hides cells the agent
# cannot observe — confirm against the env's render implementation.
mask_unobserved = False
# When True, the speech bubble shows a textual rendering of the observation
# instead of the conversation (see create_bubble_text).
textual_observations = False
env = gym.make(env_name)
obs, info = env.reset(with_info=True)
def get_parameter_options(env):
    """Return the parameter options exposed by ``env``.

    Delegates to ``env.get_potential_params()`` and returns its result
    unchanged; the index page passes it to the template as
    ``parameter_options``.
    """
    potential_params = env.get_potential_params()
    return potential_params
def create_bubble_text(obs, info, full_conversation, textual_observations):
    """Build the text displayed in the agent's speech bubble.

    If ``textual_observations`` is truthy, the text is a "Textual observation"
    header followed by ``generate_text_obs(obs, info)``; otherwise it is
    ``full_conversation``. In both cases the result is shortened by
    ``format_bubble_text`` before being returned.
    """
    raw_text = (
        "Textual observation\n\n" + generate_text_obs(obs, info)
        if textual_observations
        else full_conversation
    )
    return format_bubble_text(raw_text)
def update_tree():
    """Redraw the parameter-tree image for the current environment.

    Reads the parameters selected for ``env.current_env``, checks that the
    selected ``Env_type`` is one of ``env_types``, folds every other
    ``Env_type`` branch, and has ``env.parameter_tree`` draw itself to
    ``./web_demo/static/current_tree`` (relative to the working directory).
    """
    current_params = env.current_env.parameters
    print("sel param:", current_params)
    current_type = current_params["Env_type"]
    assert current_type in env_types, f"Env_type {current_type} not in {env_types}"
    collapsed = [other for other in env_types if other != current_type]
    env.parameter_tree.draw_tree(
        filename="./web_demo/static/current_tree",
        ignore_labels=["Num_of_colors"],
        selected_parameters=current_params,
        folded_nodes=collapsed,
    )


# Draw the tree once at import time for the initially selected environment.
update_tree()
def np_img_to_base64(np_image, quality=70):
    """Encode an image array as a base64 JPEG string.

    Args:
        np_image: array accepted by ``PIL.Image.fromarray``; the callers in
            this module pass the output of ``env.render('rgb_array', ...)``.
        quality: JPEG quality passed to Pillow (defaults to 70, the value
            previously hard-coded).

    Returns:
        The JPEG bytes, base64-encoded and decoded to a ``str``.
    """
    image = Image.fromarray(np_image)
    # JPEG has no alpha channel or palette; Pillow raises when saving such
    # modes as JPEG, so flatten them to RGB first. RGB input is unaffected.
    if image.mode in ("RGBA", "LA", "P"):
        image = image.convert("RGB")
    with io.BytesIO() as buffer:
        image.save(buffer, "JPEG", quality=quality)
        # getvalue() returns the whole buffer regardless of position; no seek needed.
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
def format_bubble_text(text):
    """Shorten ``text`` for display in the speech bubble.

    Text of at most 10 lines comes back unchanged. Longer text is reduced to
    its first line, a ``"...."`` marker and its last 8 lines (10 lines total).
    """
    rows = text.split("\n")
    if len(rows) <= 10:
        return "\n".join(rows)
    return "\n".join([rows[0], "....", *rows[-8:]])
@app.route('/set_env_params', methods=['POST'])
def set_env_params():
    """Reset the current environment with user-selected parameters.

    Expects a JSON object mapping parameter-node ids to value-node ids of
    ``env.parameter_tree``. The ids are resolved with ``get_node_for_id`` and
    handed to a ``SelectedParametersOrRandomCurriculum`` for the reset, after
    which the parameter-tree image is redrawn.

    Returns:
        ``{"success": True}`` with status 200, or ``{"success": False, ...}``
        with status 400 when the body is not a JSON object.
    """
    global obs, info
    selected_params_ids = request.get_json(silent=True)
    # A missing/unparsable body or a non-object payload (e.g. a JSON list) is
    # rejected here instead of failing on `.items()` with a 500.
    if not isinstance(selected_params_ids, dict):
        return jsonify({"success": False, "error": "expected a JSON object of parameter ids"}), 400
    tree = env.parameter_tree
    selected_parameters = {
        tree.get_node_for_id(param_id): tree.get_node_for_id(value_id)
        for param_id, value_id in selected_params_ids.items()
    }
    curriculum = SelectedParametersOrRandomCurriculum(selected_parameters)
    obs, info = env.reset(with_info=True, ACL=curriculum)
    update_tree()  # redraw the tree for the newly selected parameters
    return jsonify({"success": True}), 200
@app.route('/set_env', methods=['POST'])
def set_env():
    """Switch to the environment chosen in the drop-down form.

    Reads ``env_label`` from the submitted form, creates and resets the
    matching gym environment, redraws the parameter tree and redirects to the
    index page. A missing label, or one with no environment (such as the
    "---- ... ----" section headers), leaves the current environment untouched.
    """
    global env_name, env_label, env, obs, info
    new_label = request.form.get('env_label')
    new_name = env_label_to_env_name.get(new_label)
    if new_name is None:
        # Reject unknown labels before touching global state, so the displayed
        # label can never disagree with the running environment.
        return redirect(url_for('index'))
    # Build and reset the new env first; the globals are only rebound once
    # both steps have succeeded.
    new_env = gym.make(new_name)
    new_obs, new_info = new_env.reset(with_info=True)
    env_label, env_name, env = new_label, new_name, new_env
    obs, info = new_obs, new_info
    update_tree()  # redraw the tree for the new environment
    return redirect(url_for('index'))
@app.route('/set_mask_unobserved', methods=['POST'])
def set_mask_unobserved():
    """Update the unobserved-cell mask flag and return the re-rendered frame.

    The form field ``mask_unobserved`` counts as enabled only when it is exactly
    the string ``'true'``. Responds with ``{'image_data': <base64 JPEG>}``.
    """
    global mask_unobserved
    flag = request.form.get('mask_unobserved')
    mask_unobserved = (flag == 'true')
    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    return jsonify({'image_data': np_img_to_base64(frame)})
@app.route('/set_textual_observations', methods=['POST'])
def set_textual_observations():
    """Toggle textual observations and return the refreshed bubble text.

    The form field ``textual_observations`` counts as enabled only when it is
    exactly ``'true'``. Responds with ``{"bubble_text": ...}`` built from the
    current observation and conversation.
    """
    global textual_observations
    flag = request.form.get('textual_observations')
    textual_observations = (flag == 'true')
    conversation = env.current_env.full_conversation
    text = create_bubble_text(obs, info, conversation, textual_observations)
    return jsonify({"bubble_text": text})
# UI action names whose first action slot comes straight from env.actions.
_PRIMITIVE_ACTIONS = ("left", "right", "forward", "toggle")


def _build_action(action_name):
    """Translate a UI action name into the 3-slot action list passed to env.step.

    ``np.nan`` marks an unused slot. ``"speak"`` fills the second and third
    slots from the ``template``/``word`` form fields via
    ``env.grammar.get_action``; the movement/toggle names fill only the first
    slot; ``"noop"`` and any unrecognized name yield an all-``nan`` action.
    """
    if action_name == "speak":
        temp_ind, word_ind = env.grammar.get_action(
            request.form.get('template'), request.form.get('word')
        )
        return [np.nan, temp_ind, word_ind]
    if action_name in _PRIMITIVE_ACTIONS:
        return [int(getattr(env.actions, action_name)), np.nan, np.nan]
    return [np.nan, np.nan, np.nan]


@app.route('/perform_action', methods=['POST'])
def perform_action():
    """Apply one UI action to the environment and return the new state.

    The form field ``action`` selects the action. ``'done'`` resets the
    episode and redraws the parameter tree instead of stepping; any other
    value is translated by ``_build_action`` and passed to ``env.step``.

    Returns:
        JSON with ``image_data`` (base64 JPEG frame), ``success``
        (``info["success"]``), ``done`` and ``bubble_text``.
    """
    global obs, info
    action_name = request.form.get('action')
    if action_name == 'done':
        obs, info = env.reset(with_info=True)
        done = False
        update_tree()
    else:
        obs, _reward, done, info = env.step(_build_action(action_name))
    image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    image_data = np_img_to_base64(image)
    bubble_text = create_bubble_text(obs, info, env.current_env.full_conversation, textual_observations)
    return jsonify({
        'image_data': image_data,
        # NOTE(review): assumes reset() info also carries "success" — confirm.
        'success': info["success"],
        # Cast defensively: jsonify cannot serialize a NumPy bool, should the
        # env return one for `done`.
        'done': bool(done),
        'bubble_text': bubble_text,
    })
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render ``index.html`` for the current environment and UI settings.

    The template receives the current frame (base64 JPEG), the bubble text,
    the UI toggles, the environment labels, the speak-action grammar and the
    environment's parameter options and current parameters.
    """
    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    encoded_frame = np_img_to_base64(frame)
    bubble = create_bubble_text(obs, info, env.current_env.full_conversation, textual_observations)
    grammar = env.grammar
    context = dict(
        image_data=encoded_frame,
        bubble_text=bubble,
        mask_unobserved=mask_unobserved,
        timestamp=time.time(),
        available_env_labels=available_env_labels,
        current_env_label=env_label,
        grammar_templates=grammar.templates,
        grammar_words=grammar.things,
        parameter_options=get_parameter_options(env),
        current_parameters=env.current_params,
    )
    return render_template('index.html', **context)
if __name__ == '__main__':
    # The Werkzeug debugger lets a browser execute arbitrary Python code, so it
    # must not be on by default for a server bound to 0.0.0.0. Opt in for local
    # development with FLASK_DEBUG=1 (or "true"/"yes").
    app.run(
        host='0.0.0.0',
        port=7860,
        debug=os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes'),
    )