# SocialAISchool/utils/format.py
import os
import json
import numpy
import re
import torch
import torch_ac
import gym
import utils
def get_obss_preprocessor(obs_space, text=None, dialogue_current=None, dialogue_history=None, custom_image_preprocessor=None, custom_image_space_preprocessor=None):
# Check if obs_space is an image space
if isinstance(obs_space, gym.spaces.Box):
obs_space = {"image": obs_space.shape}
        def preprocess_obss(obss, device=None):
            # custom image preprocessors are only supported for Dict observation spaces
            assert custom_image_preprocessor is None
            return torch_ac.DictList({
                "image": preprocess_images(obss, device=device)
            })
# Check if it is a MiniGrid observation space
elif isinstance(obs_space, gym.spaces.Dict) and list(obs_space.spaces.keys()) == ["image"]:
        # either both custom image preprocessors must be given, or neither
        assert (custom_image_preprocessor is None) == (custom_image_space_preprocessor is None)

        image_obs_space = obs_space.spaces["image"].shape
        if custom_image_preprocessor:
            image_obs_space = custom_image_space_preprocessor(image_obs_space)

        obs_space = {"image": image_obs_space, "text": 100}
        # the text and dialogue flags must be specified for Dict observation spaces
        if text is None:
            raise ValueError("The `text` argument must be specified.")
        if dialogue_current is None:
            raise ValueError("The `dialogue_current` argument must be specified.")
        if dialogue_history is None:
            raise ValueError("The `dialogue_history` argument must be specified.")
vocab = Vocabulary(obs_space["text"])
        def preprocess_obss(obss, device=None):
            # preprocess images, using the custom image preprocessor if one was given
            if custom_image_preprocessor is None:
                D = {
                    "image": preprocess_images([obs["image"] for obs in obss], device=device)
                }
            else:
                D = {
                    "image": custom_image_preprocessor([obs["image"] for obs in obss], device=device)
                }
if dialogue_current:
D["utterance"] = preprocess_texts([obs["utterance"] for obs in obss], vocab, device=device)
if dialogue_history:
D["utterance_history"] = preprocess_texts([obs["utterance_history"] for obs in obss], vocab, device=device)
if text:
D["text"] = preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
return torch_ac.DictList(D)
preprocess_obss.vocab = vocab
else:
raise ValueError("Unknown observation space: " + str(obs_space))
return obs_space, preprocess_obss
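
# Usage sketch (hedged): assumes an env whose Dict observation space contains a
# single "image" key, as the SocialAI envs in this repo do; the env id and flag
# values below are purely illustrative.
#
#   env = gym.make("SocialAI-SomeEnv-v1")  # hypothetical env id
#   obs_space, preprocess_obss = get_obss_preprocessor(
#       env.observation_space, text=True,
#       dialogue_current=True, dialogue_history=False)
#   batch = preprocess_obss([env.reset()], device="cpu")
#   batch.image, batch.utterance, batch.text  # padded tensors in a DictList
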
def ride_ref_image_space_preprocessor(image_space):
    # the ride_ref preprocessor leaves the image space shape unchanged
    return image_space
def ride_ref_image_preprocessor(images, device=None):
    # converting a list of arrays to a tensor is very slow in PyTorch,
    # so first convert to a single numpy array
images = numpy.array(images)
    # grid dimensions: observations must be square (size x size) grids
    size = images.shape[1]
    assert size == images.shape[2]

    # channels 1 and 2 hold absolute coordinates, so they must be non-negative
    assert images[:, :, :, 1].min() >= 0
    assert images[:, :, :, 2].min() >= 0
    # keep only the (object id, color, state) channels: zero out the absolute
    # coordinate channels
    images[:, :, :, 1] *= 0
    images[:, :, :, 2] *= 0
assert images.shape[-1] == 5
return torch.tensor(images, device=device, dtype=torch.float)
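
# Minimal shape check (illustrative): a batch of two 5x5 grids with 5 channels
# comes back as a float tensor with the coordinate channels zeroed out.
#
#   dummy = numpy.zeros((2, 5, 5, 5), dtype=numpy.int64)
#   out = ride_ref_image_preprocessor(dummy)
#   assert out.shape == (2, 5, 5, 5) and out[..., 1].sum() == 0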
def preprocess_images(images, device=None):
    # converting a list of arrays to a tensor is very slow in PyTorch,
    # so first convert to a single numpy array
images = numpy.array(images)
return torch.tensor(images, device=device, dtype=torch.float)
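
# Illustrative: three 7x7x3 observations stacked into one float batch tensor.
#
#   imgs = [numpy.zeros((7, 7, 3)) for _ in range(3)]
#   preprocess_images(imgs).shape  # -> torch.Size([3, 7, 7, 3])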
def preprocess_texts(texts, vocab, device=None):
    var_indexed_texts = []
    max_text_len = 0

    # tokenize each text into lowercase words and map them to vocabulary ids
    for text in texts:
        tokens = re.findall("([a-z]+)", text.lower())
        var_indexed_text = numpy.array([vocab[token] for token in tokens])
        var_indexed_texts.append(var_indexed_text)
        max_text_len = max(len(var_indexed_text), max_text_len)

    # right-pad every sequence with 0 (the padding id) up to the batch maximum
    indexed_texts = numpy.zeros((len(texts), max_text_len))
    for i, indexed_text in enumerate(var_indexed_texts):
        indexed_texts[i, :len(indexed_text)] = indexed_text

    return torch.tensor(indexed_texts, device=device, dtype=torch.long)
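
# Illustrative: two missions padded to the longest tokenized length; ids are
# assigned on first sight by the Vocabulary, and 0 marks padding.
#
#   vocab = Vocabulary(max_size=100)
#   batch = preprocess_texts(["go to the red door", "open the door"], vocab)
#   batch  # -> tensor([[1, 2, 3, 4, 5], [6, 3, 5, 0, 0]])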
class Vocabulary:
"""A mapping from tokens to ids with a capacity of `max_size` words.
It can be saved in a `vocab.json` file."""
def __init__(self, max_size):
self.max_size = max_size
self.vocab = {}
def load_vocab(self, vocab):
self.vocab = vocab
def __getitem__(self, token):
        if token not in self.vocab:
            if len(self.vocab) >= self.max_size:
                raise ValueError("Maximum vocabulary capacity reached")
            # ids start at 1; 0 is reserved for padding in preprocess_texts
            self.vocab[token] = len(self.vocab) + 1
return self.vocab[token]
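

# Hedged sketch of the `vocab.json` round trip mentioned in the docstring; the
# file name comes from the docstring, everything else below is illustrative.
if __name__ == "__main__":
    vocab = Vocabulary(max_size=100)
    ids = [vocab[tok] for tok in ["open", "the", "door"]]  # -> [1, 2, 3]
    with open("vocab.json", "w") as f:
        json.dump(vocab.vocab, f)

    restored = Vocabulary(max_size=100)
    with open("vocab.json") as f:
        restored.load_vocab(json.load(f))
    assert restored["door"] == vocab["door"]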