hustcw commited on
Commit
4e6097c
·
1 Parent(s): d2f441b

update modeling

Browse files
Files changed (1) hide show
  1. clap_modeling.py +29 -1
clap_modeling.py CHANGED
@@ -66,7 +66,7 @@ class AsmTokenizer(MPNetTokenizerFast):
66
  tokenized_functions['instr'] = tokenized_functions['instr'][:self.model_max_length]
67
  break
68
  return tokenized_functions
69
-
70
  def encode_function(self, function):
71
  tokenized_functions = self.tokenize_function(function)
72
  token_ids = self.convert_tokens_to_ids(tokenized_functions["token"])
@@ -76,6 +76,34 @@ class AsmTokenizer(MPNetTokenizerFast):
76
  "attention_mask": [1] * len(token_ids),
77
  "token_type_ids": instr_ids,
78
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  @property
81
  def vocab_size(self) -> int:
 
66
  tokenized_functions['instr'] = tokenized_functions['instr'][:self.model_max_length]
67
  break
68
  return tokenized_functions
69
+
70
  def encode_function(self, function):
71
  tokenized_functions = self.tokenize_function(function)
72
  token_ids = self.convert_tokens_to_ids(tokenized_functions["token"])
 
76
  "attention_mask": [1] * len(token_ids),
77
  "token_type_ids": instr_ids,
78
  })
79
+
80
+ def __call__(self, functions, **kwargs):
81
+ if len(functions) == 0:
82
+ return BatchEncoding({
83
+ "input_ids": [],
84
+ "attention_mask": [],
85
+ "token_type_ids": [],
86
+ })
87
+ if not isinstance(functions, list):
88
+ raise ValueError("functions must be a list of dict")
89
+ elif not isinstance(functions[0], dict):
90
+ raise ValueError("functions must be a list of dict")
91
+ else:
92
+ batch_encode_result = {
93
+ "input_ids": [],
94
+ "attention_mask": [],
95
+ "token_type_ids": [],
96
+ }
97
+ for function in functions:
98
+ tokenized_functions = self.tokenize_function(function)
99
+ token_ids = self.convert_tokens_to_ids(tokenized_functions["token"])
100
+ instr_ids = self.convert_tokens_to_ids(tokenized_functions["instr"])
101
+ attention_mask = [1] * len(token_ids)
102
+ batch_encode_result["input_ids"].append(token_ids)
103
+ batch_encode_result["attention_mask"].append(attention_mask)
104
+ batch_encode_result["token_type_ids"].append(instr_ids)
105
+ batch_encoding = BatchEncoding(batch_encode_result)
106
+ return self.pad(batch_encoding, **kwargs)
107
 
108
  @property
109
  def vocab_size(self) -> int: