hustcw commited on
Commit
77fb417
·
1 Parent(s): f5d2195

add clap modeling

Browse files
Files changed (3) hide show
  1. clap_modeling.py +229 -0
  2. config.json +3 -1
  3. tokenizer_config.json +3 -0
clap_modeling.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2024 Hustcw
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from typing import Optional
27
+ import torch.nn.functional as F
28
+
29
+ from transformers.models.roformer.modeling_roformer import (
30
+ RoFormerEmbeddings,
31
+ RoFormerModel,
32
+ RoFormerEncoder,
33
+ RoFormerLayer,
34
+ RoFormerAttention,
35
+ RoFormerIntermediate,
36
+ RoFormerOutput,
37
+ RoFormerSelfAttention,
38
+ RoFormerPreTrainedModel
39
+ )
40
+
41
+ from transformers.models.mpnet.modeling_mpnet import MPNetModel
42
+
43
+ from transformers import MPNetTokenizerFast, BatchEncoding
44
+
45
+ class AsmTokenizer(MPNetTokenizerFast):
46
+
47
+ @property
48
+ def pad_token_type_id(self) -> int:
49
+ """
50
+ `int`: Id of the padding token type in the vocabulary.
51
+ """
52
+ return self.pad_token_id
53
+
54
+ def tokenize_function(self, function):
55
+ total_len = 0
56
+ tokenized_functions = {"token": [], "instr": []}
57
+ for key, value in function.items():
58
+ tokens = self.tokenize(value.replace(',', ''), max_length=20, truncation=True, add_special_tokens=False) # set max token for a instruction
59
+ instr_index = "INSTR" + key
60
+ instructions = [instr_index] * len(tokens)
61
+ tokenized_functions["token"].extend(tokens)
62
+ tokenized_functions["instr"].extend(instructions)
63
+ total_len += len(tokens)
64
+ if total_len > self.model_max_length:
65
+ tokenized_functions['token'] = tokenized_functions['token'][:self.model_max_length]
66
+ tokenized_functions['instr'] = tokenized_functions['instr'][:self.model_max_length]
67
+ break
68
+ return tokenized_functions
69
+
70
+ def encode_function(self, function):
71
+ tokenized_functions = self.tokenize_function(function)
72
+ token_ids = self.convert_tokens_to_ids(tokenized_functions["token"])
73
+ instr_ids = self.convert_tokens_to_ids(tokenized_functions["instr"])
74
+ return BatchEncoding({
75
+ "input_ids": token_ids,
76
+ "attention_mask": [1] * len(token_ids),
77
+ "token_type_ids": instr_ids,
78
+ })
79
+
80
+ @property
81
+ def vocab_size(self) -> int:
82
+ return len(self.vocab)
83
+
84
+ class JRoFormerEmbeddings(RoFormerEmbeddings):
85
+ """Construct the embeddings from word and token_type embeddings."""
86
+
87
+ def __init__(self, config):
88
+ super().__init__(config)
89
+ self.word_embeddings = nn.Embedding(
90
+ config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
91
+ )
92
+ self.token_type_embeddings = self.word_embeddings
93
+
94
+
95
+ class JRoFormerSelfAttention(RoFormerSelfAttention):
96
+ def __init__(self, config):
97
+ super().__init__(config)
98
+ self.query = nn.Linear(
99
+ config.hidden_size, self.all_head_size, bias=config.use_bias
100
+ )
101
+ self.key = nn.Linear(
102
+ config.hidden_size, self.all_head_size, bias=config.use_bias
103
+ )
104
+ self.value = nn.Linear(
105
+ config.hidden_size, self.all_head_size, bias=config.use_bias
106
+ )
107
+
108
+
109
+ class JRoFormerAttention(RoFormerAttention):
110
+ def __init__(self, config):
111
+ super().__init__(config)
112
+ self.self = JRoFormerSelfAttention(config)
113
+
114
+
115
+ class JRoFormerLayer(RoFormerLayer):
116
+ def __init__(self, config):
117
+ super().__init__(config)
118
+ self.attention = JRoFormerAttention(config)
119
+ self.is_decoder = config.is_decoder
120
+ self.add_cross_attention = config.add_cross_attention
121
+ if self.add_cross_attention:
122
+ if not self.is_decoder:
123
+ raise ValueError(
124
+ f"{self} should be used as a decoder model if cross attention is added"
125
+ )
126
+ self.crossattention = RoFormerAttention(config)
127
+ self.intermediate = RoFormerIntermediate(config)
128
+ self.output = RoFormerOutput(config)
129
+
130
+
131
+ class JRoFormerEncoder(RoFormerEncoder):
132
+ def __init__(self, config):
133
+ super().__init__(config)
134
+ self.layer = nn.ModuleList(
135
+ [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
136
+ )
137
+
138
+
139
+ class JRoFormerModel(RoFormerModel):
140
+ def __init__(self, config):
141
+ super().__init__(config)
142
+ self.config = config
143
+ self.embeddings = JRoFormerEmbeddings(config)
144
+
145
+ if config.embedding_size != config.hidden_size:
146
+ self.embeddings_project = nn.Linear(
147
+ config.embedding_size, config.hidden_size
148
+ )
149
+
150
+ self.encoder = JRoFormerEncoder(config)
151
+
152
+ # Initialize weights and apply final processing
153
+ self.post_init()
154
+
155
+ class AsmEncoder(RoFormerPreTrainedModel):
156
+ def __init__(self, config):
157
+ super().__init__(config)
158
+ self.config = config
159
+ self.jroformer = JRoFormerModel(config)
160
+ self.projection = nn.Linear(config.hidden_size, config.hidden_size)
161
+
162
+ def forward(
163
+ self,
164
+ input_ids: Optional[torch.LongTensor] = None,
165
+ attention_mask: Optional[torch.FloatTensor] = None,
166
+ token_type_ids: Optional[torch.LongTensor] = None,
167
+ head_mask: Optional[torch.FloatTensor] = None,
168
+ inputs_embeds: Optional[torch.FloatTensor] = None,
169
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
170
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
171
+ output_attentions: Optional[bool] = None,
172
+ output_hidden_states: Optional[bool] = None,
173
+ return_dict: Optional[bool] = None,
174
+ ):
175
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
176
+
177
+ outputs = self.jroformer(
178
+ input_ids,
179
+ attention_mask=attention_mask,
180
+ token_type_ids=token_type_ids,
181
+ head_mask=head_mask,
182
+ inputs_embeds=inputs_embeds,
183
+ encoder_hidden_states=encoder_hidden_states,
184
+ encoder_attention_mask=encoder_attention_mask,
185
+ output_attentions=output_attentions,
186
+ output_hidden_states=output_hidden_states,
187
+ return_dict=return_dict,
188
+ )
189
+
190
+ token_embeddings = outputs[0]
191
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
192
+ asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
193
+ asm_embedding = self.projection(asm_embedding)
194
+ asm_embedding = F.normalize(asm_embedding, p=2, dim=1)
195
+
196
+ return asm_embedding
197
+
198
+ class TextEncoder(MPNetModel):
199
+ def __init__(self, config, add_pooling_layer=True):
200
+ super().__init__(config, add_pooling_layer=add_pooling_layer)
201
+
202
+ def forward(
203
+ self,
204
+ input_ids: Optional[torch.LongTensor] = None,
205
+ attention_mask: Optional[torch.FloatTensor] = None,
206
+ position_ids: Optional[torch.LongTensor] = None,
207
+ head_mask: Optional[torch.FloatTensor] = None,
208
+ inputs_embeds: Optional[torch.FloatTensor] = None,
209
+ output_attentions: Optional[bool] = None,
210
+ output_hidden_states: Optional[bool] = None,
211
+ return_dict: Optional[bool] = None,
212
+ **kwargs,
213
+ ):
214
+ output = super().forward(
215
+ input_ids=input_ids,
216
+ attention_mask=attention_mask,
217
+ position_ids=position_ids,
218
+ head_mask=head_mask,
219
+ inputs_embeds=inputs_embeds,
220
+ output_attentions=output_attentions,
221
+ output_hidden_states=output_hidden_states,
222
+ return_dict=return_dict,
223
+ **kwargs,
224
+ )
225
+ token_embeddings = output[0]
226
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
227
+ text_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
228
+ text_embedding = F.normalize(text_embedding, p=2, dim=1)
229
+ return text_embedding
config.json CHANGED
@@ -1,8 +1,10 @@
1
  {
2
- "_name_or_path": "./models/asm-encoder",
3
  "architectures": [
4
  "AsmEncoder"
5
  ],
 
 
 
6
  "attention_probs_dropout_prob": 0.1,
7
  "embedding_size": 768,
8
  "hidden_act": "gelu",
 
1
  {
 
2
  "architectures": [
3
  "AsmEncoder"
4
  ],
5
+ "auto_map": {
6
+ "AutoModel": "clap_modeling.AsmEncoder"
7
+ },
8
  "attention_probs_dropout_prob": 0.1,
9
  "embedding_size": 768,
10
  "hidden_act": "gelu",
tokenizer_config.json CHANGED
@@ -1,4 +1,7 @@
1
  {
 
 
 
2
  "added_tokens_decoder": {
3
  "0": {
4
  "content": "<s>",
 
1
  {
2
+ "auto_map": {
3
+ "AutoTokenizer": ["clap_modeling.AsmTokenizer", null]
4
+ },
5
  "added_tokens_decoder": {
6
  "0": {
7
  "content": "<s>",