import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import torch_ac

from utils.other import init_params


class MMMemoryMultiHeadedACModel(nn.Module, torch_ac.RecurrentACModel):
    """Recurrent actor-critic model with one primitive-action head and
    several utterance heads.

    The image observation is encoded by a small CNN, optionally concatenated
    with a GRU encoding of the dialogue, passed through an LSTM cell (the
    recurrent memory), and then fed to:

    * ``actor``  -- logits over the primitive actions plus one extra "talk"
      action (the last index, stored in ``talk_action``);
    * ``talker`` -- one head per entry of ``action_space.nvec[1:]``;
    * ``critic`` -- a scalar state value.

    This class requires ``use_memory=True`` and rejects ``use_text=True``;
    the constructor raises ``ValueError`` otherwise.
    """

    def __init__(self, obs_space, action_space, use_memory=False,
                 use_text=False, use_dialogue=False):
        """Build the sub-networks.

        Args:
            obs_space: Mapping read via ``obs_space["image"][0]`` and
                ``obs_space["image"][1]`` (the two image dimensions used to
                size the CNN output) and, when text/dialogue is enabled,
                ``obs_space["text"]`` (the word-embedding vocabulary size).
            action_space: Multi-dimensional action space exposing ``shape``
                and ``nvec`` (presumably a gym ``MultiDiscrete`` -- confirm
                against the caller). ``nvec[0]`` is the number of primitive
                actions; ``nvec[1:]`` are the sizes of the utterance heads.
            use_memory: Must be ``True`` for this model.
            use_text: Must be ``False`` for this model.
            use_dialogue: Whether to encode ``obs.dialogue`` with a GRU.

        Raises:
            ValueError: If ``use_memory`` is false, ``use_text`` is true, or
                the action space has an empty shape.
        """
        super().__init__()

        # Decide which components are enabled
        self.use_text = use_text
        self.use_dialogue = use_dialogue
        self.use_memory = use_memory

        if not self.use_memory:
            raise ValueError("You should not be using this model. Use MultiHeadedACModel instead")

        if self.use_text:
            raise ValueError("You should not use text but dialogue.")

        # The action space must be multi-dimensional (primitive + utterance).
        if action_space.shape == ():
            raise ValueError("The action space is not multi modal. Use ACModel instead.")

        # One extra primitive action is appended to represent "talk".
        self.n_primitive_actions = action_space.nvec[0] + 1  # for talk
        self.talk_action = int(self.n_primitive_actions) - 1
        self.n_utterance_actions = action_space.nvec[1:]

        # Define image embedding
        self.image_conv = nn.Sequential(
            nn.Conv2d(3, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU()
        )
        n = obs_space["image"][0]
        m = obs_space["image"][1]
        # Spatial size after conv(2x2) -> pool(2x2) -> conv(2x2) -> conv(2x2),
        # times the 64 output channels.
        self.image_embedding_size = ((n-1)//2-2)*((m-1)//2-2)*64

        if self.use_text or self.use_dialogue:
            self.word_embedding_size = 32
            self.word_embedding = nn.Embedding(obs_space["text"], self.word_embedding_size)

        # Define text embedding
        # (use_text=True is rejected above, so this branch is currently
        # unreachable; it is kept for parity with the forward pass.)
        if self.use_text:
            self.text_embedding_size = 128
            self.text_rnn = nn.GRU(self.word_embedding_size, self.text_embedding_size, batch_first=True)

        # Define dialogue embedding
        if self.use_dialogue:
            self.dialogue_embedding_size = 128
            self.dialogue_rnn = nn.GRU(self.word_embedding_size, self.dialogue_embedding_size, batch_first=True)

        # Resize image embedding
        self.embedding_size = self.image_embedding_size

        if self.use_text:
            self.embedding_size += self.text_embedding_size

        if self.use_dialogue:
            self.embedding_size += self.dialogue_embedding_size

        if self.use_memory:
            self.memory_rnn = nn.LSTMCell(self.embedding_size, self.embedding_size)

        # Define actor's model
        self.actor = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, self.n_primitive_actions)
        )

        # One utterance head per entry of action_space.nvec[1:].
        self.talker = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.embedding_size, 64),
                nn.Tanh(),
                nn.Linear(64, n)
            ) for n in self.n_utterance_actions])

        # Define critic's model
        self.critic = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        # Initialize parameters correctly
        self.apply(init_params)

    @property
    def memory_size(self):
        """Width of the flat memory tensor: the two LSTM states concatenated."""
        return 2*self.semi_memory_size

    @property
    def semi_memory_size(self):
        """Width of one LSTM state, equal to ``embedding_size``."""
        return self.embedding_size

    def forward(self, obs, memory):
        """Compute action distributions, the state value and the new memory.

        Args:
            obs: Object exposing ``image`` and, depending on the enabled
                components, ``text`` / ``dialogue``. ``image`` is a 4-D
                tensor whose last axis is moved to axis 1 before the CNN,
                so that axis must hold the 3 input channels.
            memory: 2-D tensor whose columns are split at
                ``semi_memory_size`` into the two states passed to the LSTM
                cell.

        Returns:
            A tuple ``(dist, value, memory)`` where ``dist`` is a list with
            the primitive-action ``Categorical`` followed by one
            ``Categorical`` per utterance head, ``value`` is the critic
            output with its last axis squeezed, and ``memory`` is the
            updated, re-concatenated LSTM state.
        """
        # Move the last axis to position 1 (channels-first for Conv2d).
        x = obs.image.transpose(1, 3).transpose(2, 3)
        x = self.image_conv(x)

        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1)

        embedding = x

        if self.use_text:
            embed_text = self._get_embed_text(obs.text)
            embedding = torch.cat((embedding, embed_text), dim=1)

        if self.use_dialogue:
            embed_dial = self._get_embed_dialogue(obs.dialogue)
            embedding = torch.cat((embedding, embed_dial), dim=1)

        if self.use_memory:
            hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(embedding, hidden)
            # The first LSTM state is used as the embedding for all heads.
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)

        x = self.actor(embedding)
        primitive_actions_dist = Categorical(logits=F.log_softmax(x, dim=1))

        x = self.critic(embedding)
        value = x.squeeze(1)

        utterance_actions_dists = [
            Categorical(logits=F.log_softmax(
                tal(embedding),
                dim=1,
            )) for tal in self.talker
        ]

        dist = [primitive_actions_dist] + utterance_actions_dists

        return dist, value, memory

    def sample_action(self, dist):
        """Sample once from every distribution in ``dist`` and stack the
        samples along dim 1 (one column per head)."""
        return torch.stack([d.sample() for d in dist], dim=1)

    def calculate_log_probs(self, dist, action):
        """Return the per-head log-probabilities of ``action``.

        Column ``i`` of ``action`` is scored under ``dist[i]``; the results
        are stacked along dim 1.
        """
        return torch.stack([d.log_prob(action[:, i]) for i, d in enumerate(dist)], dim=1)

    def calculate_action_masks(self, action):
        """Return a boolean mask marking which action columns are in effect.

        The primitive column is always ``True``; every utterance column is
        ``True`` only on rows whose primitive action equals ``talk_action``.
        The number of utterance columns follows the number of talker heads
        rather than being fixed at two.

        Args:
            action: 2-D tensor with the primitive action in column 0 and one
                column per utterance head.

        Returns:
            A detached boolean tensor with the same shape as ``action``.
        """
        talk_mask = action[:, 0] == self.talk_action
        columns = [torch.ones_like(talk_mask)] + [talk_mask] * len(self.talker)
        mask = torch.stack(columns, dim=1).detach()

        # Sanity check: the action must have one column per head.
        assert action.shape == mask.shape

        return mask

    def construct_final_action(self, action):
        """Blank out (with NaN) the action components that are not used.

        Rows whose primitive action is not ``talk_action`` keep column 0 and
        get NaN in every utterance column; rows that talk get NaN in
        column 0 and keep the utterance columns. The number of utterance
        columns follows the number of talker heads rather than being fixed
        at two.

        Args:
            action: 2-D array-like indexable as ``action[:, 0]`` with one
                column per head (whether callers pass a NumPy array or a
                tensor is not visible here -- confirm against the caller).

        Returns:
            ``nan_mask * action``, where ``nan_mask`` is a float64 NumPy
            array of ones and NaNs.
        """
        act_mask = action[:, 0] != self.talk_action

        n_utterance_heads = len(self.talker)
        keep_primitive = np.array([1.0] + [np.nan] * n_utterance_heads)
        keep_utterance = np.array([np.nan] + [1.0] * n_utterance_heads)

        nan_mask = np.array([
            keep_primitive if t else keep_utterance for t in act_mask
        ])

        action = nan_mask*action

        return action

    def _get_embed_text(self, text):
        """Embed ``text`` word-by-word and return the GRU's final hidden
        state of the last layer."""
        _, hidden = self.text_rnn(self.word_embedding(text))
        return hidden[-1]

    def _get_embed_dialogue(self, dial):
        """Embed ``dial`` word-by-word and return the GRU's final hidden
        state of the last layer."""
        _, hidden = self.dialogue_rnn(self.word_embedding(dial))
        return hidden[-1]