Upload model
- config.json +4 -0
- config.py +33 -0
- model.py +158 -0
config.json
CHANGED
@@ -3,6 +3,10 @@
   "architectures": [
     "DenoSentModel"
   ],
+  "auto_map": {
+    "AutoConfig": "config.DenoSentConfig",
+    "AutoModel": "model.DenoSentModel"
+  },
   "contrastive_temp": 0.05,
   "contrastive_weight": 5.0,
   "decoder_noise_dropout": 0.825,
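
The new auto_map entries tell the transformers Auto classes where to find the custom code shipped with this repo (config.py and model.py). A minimal loading sketch, assuming a placeholder repo id and trust_remote_code=True (required for Hub repos that define their own model code):

from transformers import AutoConfig, AutoModel

# "user/denosent" is a placeholder; substitute the actual Hub repo id.
# trust_remote_code=True lets the repo's config.py / model.py be used to build
# DenoSentConfig and DenoSentModel, as declared in auto_map above.
config = AutoConfig.from_pretrained("user/denosent", trust_remote_code=True)
model = AutoModel.from_pretrained("user/denosent", trust_remote_code=True)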
config.py
ADDED
from transformers import PretrainedConfig
from typing import Optional


class DenoSentConfig(PretrainedConfig):
    def __init__(self,
                 encoder_name_or_path: Optional[str] = None,
                 hidden_size: Optional[int] = 768,
                 max_length: Optional[int] = 32,
                 decoder_num_heads: Optional[int] = 1,
                 decoder_num_layers: Optional[int] = 16,
                 decoder_noise_dropout: Optional[float] = 0.825,
                 pooler: Optional[str] = 'mask',
                 do_contrastive: Optional[bool] = False,
                 do_generative: Optional[bool] = False,
                 prompt_format: Optional[str] = '[X] means [MASK]',
                 contrastive_weight: Optional[float] = 1.0,
                 generative_weight: Optional[float] = 1.0,
                 contrastive_temp: Optional[float] = 0.05,
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder_name_or_path = encoder_name_or_path
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.decoder_num_heads = decoder_num_heads
        self.decoder_num_layers = decoder_num_layers
        self.decoder_noise_dropout = decoder_noise_dropout
        self.pooler = pooler
        self.do_contrastive = do_contrastive
        self.do_generative = do_generative
        self.prompt_format = prompt_format
        self.contrastive_weight = contrastive_weight
        self.generative_weight = generative_weight
        self.contrastive_temp = contrastive_temp
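
For reference, the configuration can also be built directly in code; a minimal sketch (the encoder checkpoint and output directory are assumptions, chosen because hidden_size defaults to 768):

from config import DenoSentConfig

# bert-base-uncased is an assumed encoder checkpoint matching hidden_size=768.
config = DenoSentConfig(
    encoder_name_or_path="bert-base-uncased",
    pooler="mask",
    do_contrastive=True,
    do_generative=True,
)
config.save_pretrained("./denosent-config")  # writes a config.json with these fields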
model.py
ADDED
from transformers import AutoTokenizer, BertForMaskedLM, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, TransformerDecoder, TransformerDecoderLayer

from typing import Optional

import wandb
import numpy as np


class DenoSentModel(PreTrainedModel):
    """Sentence-embedding model: a BERT encoder trained with a denoising
    (generative) reconstruction objective and an optional contrastive objective."""

    def __init__(self, config):
        super().__init__(config)
        self.pooler = config.pooler
        self.sent_embedding_projector = nn.Linear(config.hidden_size, config.hidden_size)
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.decoder_num_heads,
                                    batch_first=True, dropout=0.1),
            num_layers=config.decoder_num_layers)
        self.decoder_noise_dropout = nn.Dropout(config.decoder_noise_dropout)
        self.sim = nn.CosineSimilarity(dim=-1)
        self.init_weights()
        # Load a masked-LM checkpoint; keep its LM head as the decoder's prediction head
        # and its bare BERT as the sentence encoder.
        self.tokenizer = AutoTokenizer.from_pretrained(config.encoder_name_or_path)
        self.encoder = BertForMaskedLM.from_pretrained(config.encoder_name_or_path)
        self.prediction_head = self.encoder.cls
        self.encoder = self.encoder.bert
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def encode(self, sentences, batch_size=32, **kwargs):
        """Returns a list of embeddings for the given sentences.

        Args:
            sentences (`List[str]`): List of sentences to encode
            batch_size (`int`): Batch size for the encoding

        Returns:
            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
        """
        self.eval()
        all_embeddings = []
        # Sort by length so that each batch contains similarly sized sentences.
        length_sorted_idx = np.argsort([len(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        prompt_length = 0
        if self.config.pooler == 'mask':
            # Wrap each (truncated) sentence in the prompt template and pool from the [MASK] token.
            prompt_length = len(self.tokenizer(self.config.prompt_format, add_special_tokens=False)['input_ids'])
            sentences_sorted = self.tokenizer.batch_decode(
                self.tokenizer(sentences_sorted, padding=True, truncation=True,
                               max_length=self.config.max_length, return_tensors='pt').input_ids,
                skip_special_tokens=True)
            sentences_sorted = [self.config.prompt_format.replace('[X]', s).replace('[MASK]', self.tokenizer.mask_token)
                                for s in sentences_sorted]
        for start_index in range(0, len(sentences), batch_size):
            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
            inputs = self.tokenizer(sentences_batch, padding='max_length', truncation=True,
                                    return_tensors="pt", max_length=self.config.max_length + prompt_length)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                encoder_outputs = self.encoder(**inputs, output_hidden_states=True, output_attentions=True, return_dict=True)
                last_hidden_state = encoder_outputs.last_hidden_state
                if self.config.pooler == 'cls':
                    embeddings = last_hidden_state[:, 0, :]
                elif self.config.pooler == 'mean':
                    embeddings = (last_hidden_state * inputs['attention_mask'].unsqueeze(-1)).sum(1) / inputs['attention_mask'].sum(-1).unsqueeze(-1)
                elif self.config.pooler == 'mask':
                    embeddings = last_hidden_state[inputs['input_ids'] == self.tokenizer.mask_token_id]
                else:
                    raise NotImplementedError()
                all_embeddings.extend(embeddings.cpu().numpy())
        # Restore the original sentence order.
        all_embeddings = torch.tensor(np.array([all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]))
        return all_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        positive_input_ids: Optional[torch.LongTensor] = None,
        positive_attention_mask: Optional[torch.LongTensor] = None,
        negative_input_ids: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        global_step: Optional[int] = None,
        max_steps: Optional[int] = None,
    ):
        batch_size = input_ids.size(0)
        # Encode anchors together with their positives (and hard negatives, if given).
        if negative_input_ids is not None:
            encoder_input_ids = torch.cat([input_ids, positive_input_ids, negative_input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask, negative_attention_mask], dim=0).to(self.device)
        elif positive_input_ids is not None:
            encoder_input_ids = torch.cat([input_ids, positive_input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask], dim=0).to(self.device)
        elif self.config.do_contrastive:
            # No explicit positives: duplicate the batch so dropout provides the positive view.
            encoder_input_ids = torch.cat([input_ids, input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, attention_mask], dim=0).to(self.device)
        elif self.config.do_generative and not self.config.do_contrastive:
            encoder_input_ids = input_ids.to(self.device)
            encoder_attention_mask = attention_mask.to(self.device)
        else:
            raise NotImplementedError()
        encoder_outputs = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask,
                                       return_dict=True, output_hidden_states=True, output_attentions=True)
        # Pool a sentence embedding from the encoder output ([CLS] token, mean, or [MASK] position).
        if self.pooler == 'cls':
            sent_embedding = encoder_outputs.last_hidden_state[:, 0, :]
        elif self.pooler == 'mean':
            sent_embedding = ((encoder_outputs.last_hidden_state * encoder_attention_mask.unsqueeze(-1)).sum(1) / encoder_attention_mask.sum(-1).unsqueeze(-1))
        elif self.pooler == 'mask':
            sent_embedding = encoder_outputs.last_hidden_state[encoder_input_ids == self.tokenizer.mask_token_id]
        else:
            raise NotImplementedError()
        sent_embedding = sent_embedding.unsqueeze(1)
        sent_embedding = self.sent_embedding_projector(sent_embedding)

        if self.config.do_generative:
            # Denoising objective: reconstruct the (positive) input tokens from the sentence
            # embedding, conditioning the decoder on a dropout-noised copy of the input embeddings.
            if positive_input_ids is not None:
                tgt = encoder_outputs.hidden_states[0][batch_size:2 * batch_size].detach()
                tgt_key_padding_mask = (positive_input_ids == self.tokenizer.pad_token_id)
                labels = positive_input_ids
            else:
                tgt = encoder_outputs.hidden_states[0][:batch_size].detach()
                tgt_key_padding_mask = (input_ids == self.tokenizer.pad_token_id)
                labels = input_ids
            tgt = self.decoder_noise_dropout(tgt)
            decoder_outputs = self.decoder(tgt=tgt, memory=sent_embedding[:batch_size], tgt_mask=None, tgt_key_padding_mask=tgt_key_padding_mask)
            logits = self.prediction_head(decoder_outputs)
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            generative_loss = loss_fct(logits.view(-1, self.encoder.config.vocab_size), labels.view(-1))
            wandb.log({'train/generative_loss': generative_loss})

        if self.config.do_contrastive:
            # In-batch InfoNCE: cosine similarities between anchors and positives
            # (plus hard negatives, if provided), scaled by the temperature.
            positive_sim = self.sim(sent_embedding[:batch_size], sent_embedding[batch_size:2 * batch_size].transpose(0, 1))
            cos_sim = positive_sim
            if negative_attention_mask is not None:
                negative_sim = self.sim(sent_embedding[:batch_size], sent_embedding[2 * batch_size:].transpose(0, 1))
                cos_sim = torch.cat([positive_sim, negative_sim], dim=1)
            cos_sim = cos_sim / self.config.contrastive_temp
            contrastive_labels = torch.arange(batch_size, dtype=torch.long, device=self.device)
            contrastive_loss = nn.CrossEntropyLoss()(cos_sim, contrastive_labels)
            wandb.log({'train/contrastive_loss': contrastive_loss.item()})
        # Combine the enabled objectives into the training loss.
        logits = None
        loss = 0
        if self.config.do_contrastive:
            loss += self.config.contrastive_weight * contrastive_loss
        if self.config.do_generative:
            loss += self.config.generative_weight * generative_loss
        wandb.log({'train/loss': loss})
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
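
Once loaded, the encode method above returns one pooled embedding per sentence, which can be compared with cosine similarity. A usage sketch with a placeholder repo id:

import torch
from transformers import AutoModel

# "user/denosent" is a placeholder repo id for this upload.
model = AutoModel.from_pretrained("user/denosent", trust_remote_code=True)

sentences = ["A man is playing a guitar.", "Someone plays an instrument."]
embeddings = model.encode(sentences, batch_size=2)  # tensor of shape (2, hidden_size)
similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0)
print(f"cosine similarity: {similarity.item():.3f}")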