import torch
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from higher.patch import monkeypatch as make_functional  # unused in this file; kept for the caller that applies the predicted shifts


class ConditionedParameter(torch.nn.Module):
    """Predicts a shift for a single model parameter, conditioned on the
    LSTM-encoded edit description and on the parameter's gradient."""

    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        super().__init__()
        self.parameter_shape = parameter.shape

        if len(self.parameter_shape) == 2:
            # condition_dim is the size of the conditioning vector produced by
            # the LSTM; a linear head learns to map it back to the parameter
            # dimensions (e.g. 768) so the pieces of the update can be assembled.
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(
                        hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1
                    )
                ),
            )
        elif len(self.parameter_shape) == 1:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1)
                ),
            )
        else:
            raise RuntimeError(
                "Only 1D and 2D parameters are supported, got shape {}".format(
                    tuple(self.parameter_shape)
                )
            )

        self.max_scale = max_scale

    def forward(self, inputs, grad):
        if len(self.parameter_shape) == 2:
            (
                conditioner_cola,
                conditioner_rowa,
                conditioner_colb,
                conditioner_rowb,
                conditioner_norm,
            ) = self.conditioners(inputs).split(
                [
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    1,
                ],
                dim=-1,
            )

            # Outer products so the predicted matrices match the
            # (rows x cols) shape of the target parameter.
            a = conditioner_rowa.softmax(-1).T @ conditioner_cola
            b = conditioner_rowb.softmax(-1).T @ conditioner_colb

        elif len(self.parameter_shape) == 1:
            a, b, conditioner_norm = self.conditioners(inputs).split(
                [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
            )
        else:
            raise RuntimeError(
                "Only 1D and 2D parameters are supported, got shape {}".format(
                    tuple(self.parameter_shape)
                )
            )

        return (
            self.max_scale
            * torch.mean(conditioner_norm.sigmoid(), dim=0).squeeze()  # with several conditioning examples we simply take the mean
            * (grad * a.squeeze() + b.squeeze())
        )


class LSTMConditioner(torch.nn.Module):
    """Encodes the token sequence describing the edit into a single
    conditioning vector: embedding -> bidirectional LSTM -> feed-forward."""

    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_dim,
            embedding_dim=embedding_dim,
            padding_idx=0,
            _weight=embedding_init,
        )
        self.lstm = PytorchSeq2VecWrapper(
            torch.nn.LSTM(
                input_size=embedding_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=True,
                batch_first=True,
            )
        )
        self.linear = FeedForward(
            input_dim=hidden_dim * 2,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        return self.linear(self.lstm(self.embedding(inputs), masks))  # 1, 64


class OneShotLearner(torch.nn.Module):
    """Hypernetwork that, given an edit description (and optionally the
    gradients of the editing loss), predicts a shift for every parameter
    whose name appears in include_set."""

    def __init__(
        self,
        model,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=128,
        condition_dim=1024,
        include_set={},
        max_scale=1e-3,
        embedding_init=None,
    ):
        super().__init__()

        # Map each editable parameter name to the name of its conditioner
        # module (dots are not allowed in ModuleDict keys).
        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_")
            for n, p in model.named_parameters()
            if n in include_set
        }

        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in model.named_parameters()
                if n in include_set
            }
        )

        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        condition = self.condition(inputs, masks)  # conditioning vector produced by the LSTM
        return {
            p: self.conditioners[self.param2conditioner_map[p]](
                condition,
                grad=grads[p] if grads else None,
            )
            for p, c in self.param2conditioner_map.items()
        }
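

# --- Usage sketch (not part of the original module) -------------------------
# A minimal, hedged example of how OneShotLearner might be wired to a Hugging
# Face BERT encoder. The model name, the include_set entries, and the dummy
# inputs/gradients below are illustrative assumptions, not taken from the
# original training pipeline; in a real run the gradients would come from the
# loss on the fact being edited, and the predicted shifts would be applied to
# a functional copy of the model (e.g. via the make_functional import above).

if __name__ == "__main__":
    from transformers import BertModel, BertTokenizerFast  # assumed dependency

    bert = BertModel.from_pretrained("bert-base-uncased")
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    # Hypothetical choice: only edit one feed-forward weight and its bias.
    include_set = {
        "encoder.layer.11.intermediate.dense.weight",
        "encoder.layer.11.intermediate.dense.bias",
    }

    learner = OneShotLearner(
        bert,
        vocab_dim=bert.config.vocab_size,
        embedding_dim=bert.config.hidden_size,
        include_set=include_set,
        max_scale=1e-3,
        embedding_init=bert.embeddings.word_embeddings.weight.data.clone(),
    )

    # Encode the textual description of the edit (e.g. question + new answer).
    batch = tokenizer(["The capital of France is Rome."], return_tensors="pt")
    inputs, masks = batch["input_ids"], batch["attention_mask"].bool()

    # Placeholder gradients with the right shapes; a real run would use
    # torch.autograd.grad of the editing loss w.r.t. the selected parameters.
    grads = {
        n: torch.zeros_like(p)
        for n, p in bert.named_parameters()
        if n in include_set
    }

    param_shifts = learner(inputs, masks, grads=grads)
    for name, shift in param_shifts.items():
        # Each shift has the same shape as the parameter it would update.
        print(name, tuple(shift.shape))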