import os
import json
import numpy
import re
import torch
import torch_ac
import gym

import utils
def get_obss_preprocessor(obs_space, text=None, dialogue_current=None, dialogue_history=None,
                          custom_image_preprocessor=None, custom_image_space_preprocessor=None):
    # Check if obs_space is a plain image space
    if isinstance(obs_space, gym.spaces.Box):
        obs_space = {"image": obs_space.shape}

        def preprocess_obss(obss, device=None):
            assert custom_image_preprocessor is None
            return torch_ac.DictList({
                "image": preprocess_images(obss, device=device)
            })

    # Check if it is a MiniGrid observation space
    elif isinstance(obs_space, gym.spaces.Dict) and list(obs_space.spaces.keys()) == ["image"]:
        # A custom image preprocessor must come with its matching space preprocessor.
        assert (custom_image_preprocessor is None) == (custom_image_space_preprocessor is None)

        image_obs_space = obs_space.spaces["image"].shape
        if custom_image_preprocessor:
            image_obs_space = custom_image_space_preprocessor(image_obs_space)

        obs_space = {"image": image_obs_space, "text": 100}

        # These flags must be specified explicitly in this case.
        if text is None:
            raise ValueError("text argument must be specified.")
        if dialogue_current is None:
            raise ValueError("dialogue_current argument must be specified.")
        if dialogue_history is None:
            raise ValueError("dialogue_history argument must be specified.")

        vocab = Vocabulary(obs_space["text"])

        def preprocess_obss(obss, device=None):
            if custom_image_preprocessor is None:
                D = {
                    "image": preprocess_images([obs["image"] for obs in obss], device=device)
                }
            else:
                D = {
                    "image": custom_image_preprocessor([obs["image"] for obs in obss], device=device)
                }
            if dialogue_current:
                D["utterance"] = preprocess_texts([obs["utterance"] for obs in obss], vocab, device=device)
            if dialogue_history:
                D["utterance_history"] = preprocess_texts([obs["utterance_history"] for obs in obss], vocab, device=device)
            if text:
                D["text"] = preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
            return torch_ac.DictList(D)

        preprocess_obss.vocab = vocab

    else:
        raise ValueError("Unknown observation space: " + str(obs_space))

    return obs_space, preprocess_obss
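
# Illustrative usage sketch (not from the original source); the env id and flag
# values below are assumptions. Any env whose Dict observation space has a single
# "image" key and whose observations carry "mission", "utterance" and
# "utterance_history" entries fits the second branch above.
#
#   env = gym.make("SomeDialogueGridEnv-v0")        # hypothetical env id
#   obs_space, preprocess_obss = get_obss_preprocessor(
#       env.observation_space, text=True, dialogue_current=True, dialogue_history=True)
#   batch = preprocess_obss([env.reset()], device=torch.device("cpu"))
#   # batch.image -> float tensor (1, H, W, C); batch.text -> long tensor (1, L)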
def ride_ref_image_space_preprocessor(image_space):
    # The RIDE-reference preprocessor keeps the image shape unchanged.
    return image_space
def ride_ref_image_preprocessor(images, device=None):
    # Converting to a numpy array first avoids a very slow torch.tensor() call on a list.
    images = numpy.array(images)

    # grid dimensions
    size = images.shape[1]
    assert size == images.shape[2]

    # channels 1 and 2 hold absolute coordinates and must be non-negative
    # assert images[:, :, :, 1].max() <= size
    # assert images[:, :, :, 2].max() <= size
    assert images[:, :, :, 1].min() >= 0
    assert images[:, :, :, 2].min() >= 0

    # # channel 4 (door state) only takes the values 0, 1, 2
    # assert all([e in set([0, 1, 2]) for e in numpy.unique(images[:, :, :, 4].reshape(-1))])

    # only keep (object id, color, state): zero out the absolute-coordinate channels
    images[:, :, :, 1] *= 0
    images[:, :, :, 2] *= 0

    assert images.shape[-1] == 5
    return torch.tensor(images, device=device, dtype=torch.float)
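
# Sketch of the effect on a dummy batch (illustrative; the (2, 7, 7, 5) shape is an
# assumption matching a typical symbolic MiniGrid-style view):
#
#   imgs = numpy.zeros((2, 7, 7, 5), dtype=numpy.int64)
#   imgs[:, :, :, 1] = 3; imgs[:, :, :, 2] = 5      # fake absolute coordinates
#   out = ride_ref_image_preprocessor(imgs)
#   # out[:, :, :, 1] and out[:, :, :, 2] are all zeros; other channels are untouched.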
def preprocess_images(images, device=None):
    # Converting to a numpy array first avoids a very slow torch.tensor() call on a list.
    images = numpy.array(images)
    return torch.tensor(images, device=device, dtype=torch.float)
def preprocess_texts(texts, vocab, device=None):
    var_indexed_texts = []
    max_text_len = 0

    for text in texts:
        # lowercase and keep only alphabetic tokens
        tokens = re.findall("([a-z]+)", text.lower())
        var_indexed_text = numpy.array([vocab[token] for token in tokens])
        var_indexed_texts.append(var_indexed_text)
        max_text_len = max(len(var_indexed_text), max_text_len)

    # right-pad every sequence with 0 up to the longest text in the batch
    indexed_texts = numpy.zeros((len(texts), max_text_len))

    for i, indexed_text in enumerate(var_indexed_texts):
        indexed_texts[i, :len(indexed_text)] = indexed_text

    return torch.tensor(indexed_texts, device=device, dtype=torch.long)
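
# Example (illustrative, not from the original source): two missions of different
# lengths are tokenized and right-padded to the same width.
#
#   vocab = Vocabulary(max_size=100)
#   batch = preprocess_texts(["go to the red door", "open door"], vocab)
#   # batch.shape == (2, 5); the second row ends with three padding zeros.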
class Vocabulary:
    """A mapping from tokens to ids with a capacity of `max_size` words.
    It can be saved in a `vocab.json` file."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.vocab = {}

    def load_vocab(self, vocab):
        self.vocab = vocab

    def __getitem__(self, token):
        if token not in self.vocab:
            if len(self.vocab) >= self.max_size:
                raise ValueError("Maximum vocabulary capacity reached")
            # ids start at 1; 0 is reserved for padding in preprocess_texts
            self.vocab[token] = len(self.vocab) + 1
        return self.vocab[token]
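
# Minimal smoke test (added as an illustration, not part of the original module):
# checks that ids are assigned lazily starting at 1, that 0 stays free for padding,
# and that the capacity limit is enforced.
if __name__ == "__main__":
    vocab = Vocabulary(max_size=3)
    assert (vocab["red"], vocab["blue"], vocab["red"]) == (1, 2, 1)
    assert vocab["green"] == 3
    try:
        vocab["yellow"]
    except ValueError:
        pass  # capacity of 3 reached, as expected
    else:
        raise AssertionError("expected ValueError once the vocabulary is full")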