from torch import nn
import torch
import numpy as np
from transformers import BertPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertPooler, BertEncoder


class MetaQA_Model(BertPreTrainedModel):
    """BERT-based answer selector over the predictions of several QA agents.

    The model combines two signals for every agent:

    * a per-agent scalar "domain" logit computed from the pooled output by
      one linear head per agent (``list_MoSeN``), and
    * the encoder hidden state at that agent's [RANK] token.

    Both are concatenated and fed to ``ans_sel``, which emits two logits per
    agent.

    Config attributes read here: ``num_agents``, ``hidden_size``,
    ``hidden_dropout_prob`` and, optionally, ``rank_token_id`` (vocabulary id
    of the [RANK] token; defaults to 1).
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = MetaQABertModel(config)
        self.num_agents = config.num_agents
        # Vocabulary id of the [RANK] marker token; 1 is the historical default.
        self.rank_token_id = getattr(config, "rank_token_id", 1)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One scalar head per agent, applied to the pooled output.
        self.list_MoSeN = nn.ModuleList(
            [nn.Linear(config.hidden_size, 1) for _ in range(self.num_agents)]
        )
        # [RANK] hidden state (hidden_size) + the agent's domain logit (1).
        self.input_size_ans_sel = 1 + config.hidden_size
        interm_size = config.hidden_size // 2
        self.ans_sel = nn.Sequential(
            nn.Linear(self.input_size_ans_sel, interm_size),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(interm_size, 2),
        )
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        ans_sc=None,
        agent_sc=None,
    ):
        """Score every agent's answer.

        Args:
            input_ids: token ids of shape (batch_size, seq_len). Required,
                because the [RANK] token positions are located in it; every
                row must contain exactly ``num_agents`` [RANK] tokens.
            attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                forwarded unchanged to ``MetaQABertModel``.
            labels: accepted for interface compatibility; not used here (the
                returned ``loss`` is always ``None``).
            return_dict: return a ``TokenClassifierOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            ans_sc, agent_sc: optional score tensors forwarded to the
                embedding layer.

        Returns:
            ``TokenClassifierOutput`` with ``logits`` of shape
            (batch_size, num_agents, 2), or the tuple
            ``(logits, *extra_encoder_outputs)`` when ``return_dict`` is false.

        Raises:
            ValueError: if ``input_ids`` is missing or a row does not contain
                exactly ``num_agents`` [RANK] tokens.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError(
                "input_ids is required: the [RANK] token positions are located from it"
            )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            ans_sc=ans_sc,
            agent_sc=agent_sc,
        )

        # domain classification
        pooled_output = self.dropout(outputs[1])
        # Stacking on dim=1 yields (batch_size, num_agents, 1) directly.
        domain_logits = torch.stack(
            [head(pooled_output) for head in self.list_MoSeN], dim=1
        )

        # ans classifier
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)

        # select the [RANK] token embeddings
        rank_mask = input_ids == self.rank_token_id
        # Guard against rows with a wrong number of [RANK] tokens: the reshape
        # below would otherwise pair positions with the wrong sample.
        if not torch.all(rank_mask.sum(dim=1) == self.num_agents):
            raise ValueError(
                f"every sequence must contain exactly {self.num_agents} [RANK] tokens "
                f"(token id {self.rank_token_id})"
            )
        # nonzero() is row-major, so column 1 holds the sequence positions
        # grouped per sample: (batch_size, num_agents).
        rank_positions = rank_mask.nonzero()[:, 1].view(-1, self.num_agents)
        batch_index = torch.arange(
            rank_positions.size(0), device=rank_positions.device
        ).unsqueeze(1)
        # Advanced indexing gathers (batch_size, num_agents, hidden_size).
        rank_emb = sequence_output[batch_index, rank_positions]
        rank_emb = self.dropout(rank_emb)
        # rank emb shape = (batch_size, num_agents, hidden_size+1)
        rank_emb = torch.cat((rank_emb, domain_logits), dim=2)
        logits = self.ans_sel(rank_emb)  # (batch_size, num_agents, 2)

        if not return_dict:
            return (logits,) + outputs[2:]

        return TokenClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class MetaQABertModel(BertPreTrainedModel):
    """BERT encoder whose embedding layer also accepts agent/answer scores.

    Uses ``MetaQABertEmbeddings`` together with the stock ``BertEncoder`` and
    ``BertPooler``; ``forward`` additionally takes ``ans_sc`` and ``agent_sc``
    and hands them to the embedding layer.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = MetaQABertEmbeddings(config)  # NEW
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads.

        Args:
            heads_to_prune: mapping ``{layer_index: heads}`` passed per layer
                to that layer's ``attention.prune_heads``.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        ans_sc=None,
        agent_sc=None,
    ):
        """Run embeddings, encoder and (optional) pooler.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``ans_sc`` and ``agent_sc`` are passed through to the embedding layer;
        every other argument is handled as in the encoder/embedding calls
        below.

        Returns:
            ``BaseModelOutputWithPoolingAndCrossAttentions`` or, when
            ``return_dict`` is false, the tuple
            ``(sequence_output, pooled_output, *encoder_outputs[1:])``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            ans_sc=ans_sc,
            agent_sc=agent_sc,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class MetaQABertEmbeddings(nn.Module):
    """Construct the embeddings from word, position, token_type embeddings, and scores from the QA agents."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # Each scalar score is projected to a hidden_size vector and added to
        # the token embedding.
        self.ans_sc_proj = nn.Linear(1, config.hidden_size)
        self.agent_sc_proj = nn.Linear(1, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
            persistent=False,
        )

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
        ans_sc=None,
        agent_sc=None,
    ):
        """Build the input embeddings.

        Args:
            input_ids: token ids, (batch_size, seq_len); ignored for the word
                lookup when ``inputs_embeds`` is given.
            token_type_ids: segment ids; defaults to the all-zero buffer.
            position_ids: position ids; defaults to a slice of the
                ``position_ids`` buffer offset by ``past_key_values_length``.
            inputs_embeds: precomputed word embeddings,
                (batch_size, seq_len, hidden_size).
            past_key_values_length: offset applied to default position ids.
            ans_sc: optional per-token answer scores; unsqueezed on dim 2 and
                projected, so it is expected to be (batch_size, seq_len).
            agent_sc: optional per-token agent scores, same layout as
                ``ans_sc``.

        Returns:
            Tensor of embeddings after LayerNorm and dropout.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        if ans_sc is not None:
            # Cast to the projection's dtype so float64 / half-precision score
            # tensors do not trip nn.Linear's dtype check.
            ans_sc_emb = self.ans_sc_proj(ans_sc.unsqueeze(2).to(self.ans_sc_proj.weight.dtype))
            embeddings += ans_sc_emb
        if agent_sc is not None:
            agent_sc_emb = self.agent_sc_proj(agent_sc.unsqueeze(2).to(self.agent_sc_proj.weight.dtype))
            embeddings += agent_sc_emb

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings