Jean Garcia-Gathright commited on
Commit
a02c788
·
1 Parent(s): 4150cb0

added ernie files

Browse files
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
- import transformers
3
- import tensorflow
4
 
5
  def greet(name):
6
  return "Hello " + name + "!!"
 
1
  import gradio as gr
2
+ from ernie.ernie import SentenceClassifier
3
+ from ernie import helper
4
 
5
  def greet(name):
6
  return "Hello " + name + "!!"
app.py~ ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ def greet(name):
4
+ return "Hello " + name + "!!"
5
+
6
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ iface.launch()
ernie/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from .ernie import * # noqa: F401, F403
5
+ from tensorflow.python.client import device_lib
6
+ import logging
7
+
8
+ __version__ = '1.0.1'
9
+
10
+ logging.getLogger().setLevel(logging.WARNING)
11
+ logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
12
+ logging.basicConfig(
13
+ format='%(asctime)-15s [%(levelname)s] %(message)s',
14
+ datefmt='%Y-%m-%d %H:%M:%S'
15
+ )
16
+
17
+
18
+ def _get_cpu_name():
19
+ import cpuinfo
20
+ cpu_info = cpuinfo.get_cpu_info()
21
+ cpu_name = f"{cpu_info['brand_raw']}, {cpu_info['count']} vCores"
22
+ return cpu_name
23
+
24
+
25
+ def _get_gpu_name():
26
+ gpu_name = \
27
+ device_lib\
28
+ .list_local_devices()[3]\
29
+ .physical_device_desc\
30
+ .split(',')[1]\
31
+ .split('name:')[1]\
32
+ .strip()
33
+ return gpu_name
34
+
35
+
36
+ device_name = _get_cpu_name()
37
+ device_type = 'CPU'
38
+
39
+ try:
40
+ device_name = _get_gpu_name()
41
+ device_type = 'GPU'
42
+ except IndexError:
43
+ # Detect TPU
44
+ pass
45
+
46
+ logging.info(f'ernie v{__version__}')
47
+ logging.info(f'target device: [{device_type}] {device_name}\n')
ernie/aggregation_strategies.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from statistics import mean
5
+
6
+
7
+ class AggregationStrategy:
8
+ def __init__(
9
+ self,
10
+ method,
11
+ max_items=None,
12
+ top_items=True,
13
+ sorting_class_index=1
14
+ ):
15
+ self.method = method
16
+ self.max_items = max_items
17
+ self.top_items = top_items
18
+ self.sorting_class_index = sorting_class_index
19
+
20
+ def aggregate(self, softmax_tuples):
21
+ softmax_dicts = []
22
+ for softmax_tuple in softmax_tuples:
23
+ softmax_dict = {}
24
+ for i, probability in enumerate(softmax_tuple):
25
+ softmax_dict[i] = probability
26
+ softmax_dicts.append(softmax_dict)
27
+
28
+ if self.max_items is not None:
29
+ softmax_dicts = sorted(
30
+ softmax_dicts,
31
+ key=lambda x: x[self.sorting_class_index],
32
+ reverse=self.top_items
33
+ )
34
+ if self.max_items < len(softmax_dicts):
35
+ softmax_dicts = softmax_dicts[:self.max_items]
36
+
37
+ softmax_list = []
38
+ for key in softmax_dicts[0].keys():
39
+ softmax_list.append(self.method(
40
+ [probabilities[key] for probabilities in softmax_dicts]))
41
+ softmax_tuple = tuple(softmax_list)
42
+ return softmax_tuple
43
+
44
+
45
+ class AggregationStrategies:
46
+ Mean = AggregationStrategy(method=mean)
47
+ MeanTopFiveBinaryClassification = AggregationStrategy(
48
+ method=mean,
49
+ max_items=5,
50
+ top_items=True,
51
+ sorting_class_index=1
52
+ )
53
+ MeanTopTenBinaryClassification = AggregationStrategy(
54
+ method=mean,
55
+ max_items=10,
56
+ top_items=True,
57
+ sorting_class_index=1
58
+ )
59
+ MeanTopFifteenBinaryClassification = AggregationStrategy(
60
+ method=mean,
61
+ max_items=15,
62
+ top_items=True,
63
+ sorting_class_index=1
64
+ )
65
+ MeanTopTwentyBinaryClassification = AggregationStrategy(
66
+ method=mean,
67
+ max_items=20,
68
+ top_items=True,
69
+ sorting_class_index=1
70
+ )
ernie/ernie.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ AutoModel,
9
+ AutoConfig,
10
+ TFAutoModelForSequenceClassification,
11
+ )
12
+ from tensorflow import keras
13
+ from sklearn.model_selection import train_test_split
14
+ import logging
15
+ import time
16
+ from .models import Models, ModelsByFamily # noqa: F401
17
+ from .split_strategies import ( # noqa: F401
18
+ SplitStrategy,
19
+ SplitStrategies,
20
+ RegexExpressions
21
+ )
22
+ from .aggregation_strategies import ( # noqa: F401
23
+ AggregationStrategy,
24
+ AggregationStrategies
25
+ )
26
+ from .helper import (
27
+ get_features,
28
+ softmax,
29
+ remove_dir,
30
+ make_dir,
31
+ copy_dir
32
+ )
33
+
34
+ AUTOSAVE_PATH = './ernie-autosave/'
35
+
36
+
37
+ def clean_autosave():
38
+ remove_dir(AUTOSAVE_PATH)
39
+
40
+
41
+ class SentenceClassifier:
42
+ def __init__(self,
43
+ model_name=Models.BertBaseUncased,
44
+ model_path=None,
45
+ max_length=64,
46
+ labels_no=2,
47
+ tokenizer_kwargs=None,
48
+ model_kwargs=None):
49
+ self._loaded_data = False
50
+ self._model_path = None
51
+
52
+ if model_kwargs is None:
53
+ model_kwargs = {}
54
+ model_kwargs['num_labels'] = labels_no
55
+
56
+ if tokenizer_kwargs is None:
57
+ tokenizer_kwargs = {}
58
+ tokenizer_kwargs['max_len'] = max_length
59
+
60
+ if model_path is not None:
61
+ self._load_local_model(model_path)
62
+ else:
63
+ self._load_remote_model(model_name, tokenizer_kwargs, model_kwargs)
64
+
65
+ @property
66
+ def model(self):
67
+ return self._model
68
+
69
+ @property
70
+ def tokenizer(self):
71
+ return self._tokenizer
72
+
73
+ def load_dataset(self,
74
+ dataframe=None,
75
+ validation_split=0.1,
76
+ random_state=None,
77
+ stratify=None,
78
+ csv_path=None,
79
+ read_csv_kwargs=None):
80
+
81
+ if dataframe is None and csv_path is None:
82
+ raise ValueError
83
+
84
+ if csv_path is not None:
85
+ dataframe = pd.read_csv(csv_path, **read_csv_kwargs)
86
+
87
+ sentences = list(dataframe[dataframe.columns[0]])
88
+ labels = dataframe[dataframe.columns[1]].values
89
+
90
+ (
91
+ training_sentences,
92
+ validation_sentences,
93
+ training_labels,
94
+ validation_labels
95
+ ) = train_test_split(
96
+ sentences,
97
+ labels,
98
+ test_size=validation_split,
99
+ shuffle=True,
100
+ random_state=random_state,
101
+ stratify=stratify
102
+ )
103
+
104
+ self._training_features = get_features(
105
+ self._tokenizer, training_sentences, training_labels)
106
+
107
+ self._training_size = len(training_sentences)
108
+
109
+ self._validation_features = get_features(
110
+ self._tokenizer,
111
+ validation_sentences,
112
+ validation_labels
113
+ )
114
+ self._validation_split = len(validation_sentences)
115
+
116
+ logging.info(f'training_size: {self._training_size}')
117
+ logging.info(f'validation_split: {self._validation_split}')
118
+
119
+ self._loaded_data = True
120
+
121
+ def fine_tune(self,
122
+ epochs=4,
123
+ learning_rate=2e-5,
124
+ epsilon=1e-8,
125
+ clipnorm=1.0,
126
+ optimizer_function=keras.optimizers.Adam,
127
+ optimizer_kwargs=None,
128
+ loss_function=keras.losses.SparseCategoricalCrossentropy,
129
+ loss_kwargs=None,
130
+ accuracy_function=keras.metrics.SparseCategoricalAccuracy,
131
+ accuracy_kwargs=None,
132
+ training_batch_size=32,
133
+ validation_batch_size=64,
134
+ **kwargs):
135
+ if not self._loaded_data:
136
+ raise Exception('Data has not been loaded.')
137
+
138
+ if optimizer_kwargs is None:
139
+ optimizer_kwargs = {
140
+ 'learning_rate': learning_rate,
141
+ 'epsilon': epsilon,
142
+ 'clipnorm': clipnorm
143
+ }
144
+ optimizer = optimizer_function(**optimizer_kwargs)
145
+
146
+ if loss_kwargs is None:
147
+ loss_kwargs = {'from_logits': True}
148
+ loss = loss_function(**loss_kwargs)
149
+
150
+ if accuracy_kwargs is None:
151
+ accuracy_kwargs = {'name': 'accuracy'}
152
+ accuracy = accuracy_function(**accuracy_kwargs)
153
+
154
+ self._model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])
155
+
156
+ training_features = self._training_features.shuffle(
157
+ self._training_size).batch(training_batch_size).repeat(-1)
158
+ validation_features = self._validation_features.batch(
159
+ validation_batch_size)
160
+
161
+ training_steps = self._training_size // training_batch_size
162
+ if training_steps == 0:
163
+ training_steps = self._training_size
164
+ logging.info(f'training_steps: {training_steps}')
165
+
166
+ validation_steps = self._validation_split // validation_batch_size
167
+ if validation_steps == 0:
168
+ validation_steps = self._validation_split
169
+ logging.info(f'validation_steps: {validation_steps}')
170
+
171
+ for i in range(epochs):
172
+ self._model.fit(training_features,
173
+ epochs=1,
174
+ validation_data=validation_features,
175
+ steps_per_epoch=training_steps,
176
+ validation_steps=validation_steps,
177
+ **kwargs)
178
+
179
+ # The fine-tuned model does not have the same input interface
180
+ # after being exported and loaded again.
181
+ self._reload_model()
182
+
183
+ def predict_one(
184
+ self,
185
+ text,
186
+ split_strategy=None,
187
+ aggregation_strategy=None
188
+ ):
189
+ return next(
190
+ self.predict([text],
191
+ batch_size=1,
192
+ split_strategy=split_strategy,
193
+ aggregation_strategy=aggregation_strategy))
194
+
195
+ def predict(
196
+ self,
197
+ texts,
198
+ batch_size=32,
199
+ split_strategy=None,
200
+ aggregation_strategy=None
201
+ ):
202
+ if split_strategy is None:
203
+ yield from self._predict_batch(texts, batch_size)
204
+
205
+ else:
206
+ if aggregation_strategy is None:
207
+ aggregation_strategy = AggregationStrategies.Mean
208
+
209
+ split_indexes = [0]
210
+ sentences = []
211
+ for text in texts:
212
+ new_sentences = split_strategy.split(text, self.tokenizer)
213
+ if not new_sentences:
214
+ continue
215
+ split_indexes.append(split_indexes[-1] + len(new_sentences))
216
+ sentences.extend(new_sentences)
217
+
218
+ predictions = list(self._predict_batch(sentences, batch_size))
219
+ for i, split_index in enumerate(split_indexes[:-1]):
220
+ stop_index = split_indexes[i + 1]
221
+ yield aggregation_strategy.aggregate(
222
+ predictions[split_index:stop_index]
223
+ )
224
+
225
+ def dump(self, path):
226
+ if self._model_path:
227
+ copy_dir(self._model_path, path)
228
+ else:
229
+ self._dump(path)
230
+
231
+ def _dump(self, path):
232
+ make_dir(path)
233
+ make_dir(path + '/tokenizer')
234
+ self._model.save_pretrained(path)
235
+ self._tokenizer.save_pretrained(path + '/tokenizer')
236
+ self._config.save_pretrained(path + '/tokenizer')
237
+
238
+ def _predict_batch(self, sentences: list, batch_size: int):
239
+ sentences_number = len(sentences)
240
+ if batch_size > sentences_number:
241
+ batch_size = sentences_number
242
+
243
+ for i in range(0, sentences_number, batch_size):
244
+ input_ids_list = []
245
+ attention_mask_list = []
246
+
247
+ stop_index = i + batch_size
248
+ stop_index = stop_index if stop_index < sentences_number \
249
+ else sentences_number
250
+
251
+ for j in range(i, stop_index):
252
+ features = self._tokenizer.encode_plus(
253
+ sentences[j],
254
+ add_special_tokens=True,
255
+ max_length=self._tokenizer.model_max_length
256
+ )
257
+ input_ids, _, attention_mask = (
258
+ features['input_ids'],
259
+ features['token_type_ids'],
260
+ features['attention_mask']
261
+ )
262
+
263
+ input_ids = self._list_to_padded_array(features['input_ids'])
264
+ attention_mask = self._list_to_padded_array(
265
+ features['attention_mask'])
266
+
267
+ input_ids_list.append(input_ids)
268
+ attention_mask_list.append(attention_mask)
269
+
270
+ input_dict = {
271
+ 'input_ids': np.array(input_ids_list),
272
+ 'attention_mask': np.array(attention_mask_list)
273
+ }
274
+ logit_predictions = self._model.predict_on_batch(input_dict)
275
+ yield from (
276
+ [softmax(logit_prediction)
277
+ for logit_prediction in logit_predictions[0]]
278
+ )
279
+
280
+ def _list_to_padded_array(self, items):
281
+ array = np.array(items)
282
+ padded_array = np.zeros(self._tokenizer.model_max_length, dtype=np.int)
283
+ padded_array[:array.shape[0]] = array
284
+ return padded_array
285
+
286
+ def _get_temporary_path(self, name=''):
287
+ return f'{AUTOSAVE_PATH}{name}/{int(round(time.time() * 1000))}'
288
+
289
+ def _reload_model(self):
290
+ self._model_path = self._get_temporary_path(
291
+ name=self._get_model_family())
292
+ self._dump(self._model_path)
293
+ self._load_local_model(self._model_path)
294
+
295
+ def _load_local_model(self, model_path):
296
+ try:
297
+ self._tokenizer = AutoTokenizer.from_pretrained(
298
+ model_path + '/tokenizer')
299
+ self._config = AutoConfig.from_pretrained(
300
+ model_path + '/tokenizer')
301
+
302
+ # Old models didn't use to have a tokenizer folder
303
+ except OSError:
304
+ self._tokenizer = AutoTokenizer.from_pretrained(model_path)
305
+ self._config = AutoConfig.from_pretrained(model_path)
306
+ self._model = TFAutoModelForSequenceClassification.from_pretrained(
307
+ model_path,
308
+ from_pt=False
309
+ )
310
+
311
+ def _get_model_family(self):
312
+ model_family = ''.join(self._model.name[2:].split('_')[:2])
313
+ return model_family
314
+
315
+ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
316
+ do_lower_case = False
317
+ if 'uncased' in model_name.lower():
318
+ do_lower_case = True
319
+ tokenizer_kwargs.update({'do_lower_case': do_lower_case})
320
+
321
+ self._tokenizer = AutoTokenizer.from_pretrained(
322
+ model_name, **tokenizer_kwargs)
323
+ self._config = AutoConfig.from_pretrained(model_name)
324
+
325
+ temporary_path = self._get_temporary_path()
326
+ make_dir(temporary_path)
327
+
328
+ # TensorFlow model
329
+ try:
330
+ self._model = TFAutoModelForSequenceClassification.from_pretrained(
331
+ model_name,
332
+ from_pt=False
333
+ )
334
+
335
+ # PyTorch model
336
+ except TypeError:
337
+ try:
338
+ self._model = \
339
+ TFAutoModelForSequenceClassification.from_pretrained(
340
+ model_name,
341
+ from_pt=True
342
+ )
343
+
344
+ # Loading a TF model from a PyTorch checkpoint is not supported
345
+ # when using a model identifier name
346
+ except OSError:
347
+ model = AutoModel.from_pretrained(model_name)
348
+ model.save_pretrained(temporary_path)
349
+ self._model = \
350
+ TFAutoModelForSequenceClassification.from_pretrained(
351
+ temporary_path,
352
+ from_pt=True
353
+ )
354
+
355
+ # Clean the model's last layer if the provided properties are different
356
+ clean_last_layer = False
357
+ for key, value in model_kwargs.items():
358
+ if not hasattr(self._model.config, key):
359
+ clean_last_layer = True
360
+ break
361
+
362
+ if getattr(self._model.config, key) != value:
363
+ clean_last_layer = True
364
+ break
365
+
366
+ if clean_last_layer:
367
+ try:
368
+ getattr(self._model, self._get_model_family()
369
+ ).save_pretrained(temporary_path)
370
+ self._model = self._model.__class__.from_pretrained(
371
+ temporary_path,
372
+ from_pt=False,
373
+ **model_kwargs
374
+ )
375
+
376
+ # The model is itself the main layer
377
+ except AttributeError:
378
+ # TensorFlow model
379
+ try:
380
+ self._model = self._model.__class__.from_pretrained(
381
+ model_name,
382
+ from_pt=False,
383
+ **model_kwargs
384
+ )
385
+
386
+ # PyTorch Model
387
+ except (OSError, TypeError):
388
+ model = AutoModel.from_pretrained(model_name)
389
+ model.save_pretrained(temporary_path)
390
+ self._model = self._model.__class__.from_pretrained(
391
+ temporary_path,
392
+ from_pt=True,
393
+ **model_kwargs
394
+ )
395
+
396
+ remove_dir(temporary_path)
397
+ assert self._tokenizer and self._model
ernie/helper.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from tensorflow import data, TensorShape, int64, int32
5
+ from math import exp
6
+ from os import makedirs
7
+ from shutil import rmtree, move, copytree
8
+ from huggingface_hub import hf_hub_download
9
+ import os
10
+
11
+
12
+ def get_features(tokenizer, sentences, labels):
13
+ features = []
14
+ for i, sentence in enumerate(sentences):
15
+ inputs = tokenizer.encode_plus(
16
+ sentence,
17
+ add_special_tokens=True,
18
+ max_length=tokenizer.model_max_length
19
+ )
20
+ input_ids, token_type_ids = \
21
+ inputs['input_ids'], inputs['token_type_ids']
22
+ padding_length = tokenizer.model_max_length - len(input_ids)
23
+
24
+ if tokenizer.padding_side == 'right':
25
+ attention_mask = [1] * len(input_ids) + [0] * padding_length
26
+ input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
27
+ token_type_ids = token_type_ids + \
28
+ [tokenizer.pad_token_type_id] * padding_length
29
+ else:
30
+ attention_mask = [0] * padding_length + [1] * len(input_ids)
31
+ input_ids = [tokenizer.pad_token_id] * padding_length + input_ids
32
+ token_type_ids = \
33
+ [tokenizer.pad_token_type_id] * padding_length + token_type_ids
34
+
35
+ assert tokenizer.model_max_length \
36
+ == len(attention_mask) \
37
+ == len(input_ids) \
38
+ == len(token_type_ids)
39
+
40
+ feature = {
41
+ 'input_ids': input_ids,
42
+ 'attention_mask': attention_mask,
43
+ 'token_type_ids': token_type_ids,
44
+ 'label': int(labels[i])
45
+ }
46
+
47
+ features.append(feature)
48
+
49
+ def gen():
50
+ for feature in features:
51
+ yield (
52
+ {
53
+ 'input_ids': feature['input_ids'],
54
+ 'attention_mask': feature['attention_mask'],
55
+ 'token_type_ids': feature['token_type_ids'],
56
+ },
57
+ feature['label'],
58
+ )
59
+
60
+ dataset = data.Dataset.from_generator(
61
+ gen,
62
+ ({
63
+ 'input_ids': int32,
64
+ 'attention_mask': int32,
65
+ 'token_type_ids': int32
66
+ }, int64),
67
+ (
68
+ {
69
+ 'input_ids': TensorShape([None]),
70
+ 'attention_mask': TensorShape([None]),
71
+ 'token_type_ids': TensorShape([None]),
72
+ },
73
+ TensorShape([]),
74
+ ),
75
+ )
76
+
77
+ return dataset
78
+
79
+
80
+ def softmax(values):
81
+ exps = [exp(value) for value in values]
82
+ exps_sum = sum(exp_value for exp_value in exps)
83
+ return tuple(map(lambda x: x / exps_sum, exps))
84
+
85
+
86
+ def make_dir(path):
87
+ try:
88
+ makedirs(path)
89
+ except FileExistsError:
90
+ pass
91
+
92
+
93
+ def remove_dir(path):
94
+ rmtree(path)
95
+
96
+
97
+ def copy_dir(source_path, target_path):
98
+ copytree(source_path, target_path)
99
+
100
+
101
+ def move_dir(source_path, target_path):
102
+ move(source_path, target_path)
103
+
104
+ def download_from_hub(repo_id, filename, revision=None, cache_dir=None):
105
+ try:
106
+ hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=cache_dir)
107
+ except Exception as exp:
108
+ raise exp
109
+
110
+
111
+ if cache_dir is not None:
112
+
113
+ files = os.listdir(cache_dir)
114
+
115
+ for f in files:
116
+ if '.lock' in f:
117
+ name = f[0:-5]
118
+
119
+ os.rename(cache_dir+name, cache_dir+filename)
120
+ os.remove(cache_dir+name+'.lock')
121
+ os.remove(cache_dir+name+'.json')
ernie/models.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ class Models:
6
+ BertBaseUncased = 'bert-base-uncased'
7
+ BertBaseCased = 'bert-base-cased'
8
+ BertLargeUncased = 'bert-large-uncased'
9
+ BertLargeCased = 'bert-large-cased'
10
+
11
+ RobertaBaseCased = 'roberta-base'
12
+ RobertaLargeCased = 'roberta-large'
13
+
14
+ XLNetBaseCased = 'xlnet-base-cased'
15
+ XLNetLargeCased = 'xlnet-large-cased'
16
+
17
+ DistilBertBaseUncased = 'distilbert-base-uncased'
18
+ DistilBertBaseMultilingualCased = 'distilbert-base-multilingual-cased'
19
+
20
+ AlbertBaseCased = 'albert-base-v1'
21
+ AlbertLargeCased = 'albert-large-v1'
22
+ AlbertXLargeCased = 'albert-xlarge-v1'
23
+ AlbertXXLargeCased = 'albert-xxlarge-v1'
24
+
25
+ AlbertBaseCased2 = 'albert-base-v2'
26
+ AlbertLargeCased2 = 'albert-large-v2'
27
+ AlbertXLargeCased2 = 'albert-xlarge-v2'
28
+ AlbertXXLargeCased2 = 'albert-xxlarge-v2'
29
+
30
+
31
+ class ModelsByFamily:
32
+ Bert = set([Models.BertBaseUncased, Models.BertBaseCased,
33
+ Models.BertLargeUncased, Models.BertLargeCased])
34
+ Roberta = set([Models.RobertaBaseCased, Models.RobertaLargeCased])
35
+ XLNet = set([Models.XLNetBaseCased, Models.XLNetLargeCased])
36
+ DistilBert = set([Models.DistilBertBaseUncased,
37
+ Models.DistilBertBaseMultilingualCased])
38
+ Albert = set([
39
+ Models.AlbertBaseCased,
40
+ Models.AlbertLargeCased,
41
+ Models.AlbertXLargeCased,
42
+ Models.AlbertXXLargeCased,
43
+ Models.AlbertBaseCased2,
44
+ Models.AlbertLargeCased2,
45
+ Models.AlbertXLargeCased2,
46
+ Models.AlbertXXLargeCased2
47
+ ])
48
+ Supported = set([
49
+ getattr(Models, model_type) for model_type
50
+ in filter(lambda x: x[:2] != '__', Models.__dict__.keys())
51
+ ])
ernie/split_strategies.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import re
5
+
6
+
7
+ class RegexExpressions:
8
+ split_by_dot = re.compile(r'[^.]+(?:\.\s*)?')
9
+ split_by_semicolon = re.compile(r'[^;]+(?:\;\s*)?')
10
+ split_by_colon = re.compile(r'[^:]+(?:\:\s*)?')
11
+ split_by_comma = re.compile(r'[^,]+(?:\,\s*)?')
12
+
13
+ url = re.compile(
14
+ r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
15
+ r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
16
+ )
17
+ domain = re.compile(r'\w+\.\w+')
18
+
19
+
20
+ class SplitStrategy:
21
+ def __init__(
22
+ self,
23
+ split_patterns,
24
+ remove_patterns=None,
25
+ group_splits=True,
26
+ remove_too_short_groups=True
27
+ ):
28
+ if not isinstance(split_patterns, list):
29
+ self.split_patterns = [split_patterns]
30
+ else:
31
+ self.split_patterns = split_patterns
32
+
33
+ if remove_patterns is not None \
34
+ and not isinstance(remove_patterns, list):
35
+ self.remove_patterns = [remove_patterns]
36
+ else:
37
+ self.remove_patterns = remove_patterns
38
+
39
+ self.group_splits = group_splits
40
+ self.remove_too_short_groups = remove_too_short_groups
41
+
42
+ def split(self, text, tokenizer, split_patterns=None):
43
+ if split_patterns is None:
44
+ if self.split_patterns is None:
45
+ return [text]
46
+ split_patterns = self.split_patterns
47
+
48
+ def len_in_tokens(text_):
49
+ no_tokens = len(tokenizer.encode(text_, add_special_tokens=False))
50
+ return no_tokens
51
+
52
+ no_special_tokens = len(tokenizer.encode('', add_special_tokens=True))
53
+ max_tokens = tokenizer.max_len - no_special_tokens
54
+
55
+ if self.remove_patterns is not None:
56
+ for remove_pattern in self.remove_patterns:
57
+ text = re.sub(remove_pattern, '', text).strip()
58
+
59
+ if len_in_tokens(text) <= max_tokens:
60
+ return [text]
61
+
62
+ selected_splits = []
63
+ splits = map(lambda x: x.strip(), re.findall(split_patterns[0], text))
64
+
65
+ aggregated_splits = ''
66
+ for split in splits:
67
+ if len_in_tokens(split) > max_tokens:
68
+ if len(split_patterns) > 1:
69
+ sub_splits = self.split(
70
+ split, tokenizer, split_patterns[1:])
71
+ selected_splits.extend(sub_splits)
72
+ else:
73
+ selected_splits.append(split)
74
+
75
+ else:
76
+ if not self.group_splits:
77
+ selected_splits.append(split)
78
+ else:
79
+ new_aggregated_splits = \
80
+ f'{aggregated_splits} {split}'.strip()
81
+ if len_in_tokens(new_aggregated_splits) <= max_tokens:
82
+ aggregated_splits = new_aggregated_splits
83
+ else:
84
+ selected_splits.append(aggregated_splits)
85
+ aggregated_splits = split
86
+
87
+ if aggregated_splits:
88
+ selected_splits.append(aggregated_splits)
89
+
90
+ remove_too_short_groups = len(selected_splits) > 1 \
91
+ and self.group_splits \
92
+ and self.remove_too_short_groups
93
+
94
+ if not remove_too_short_groups:
95
+ final_splits = selected_splits
96
+ else:
97
+ final_splits = []
98
+ min_length = tokenizer.max_len / 2
99
+ for split in selected_splits:
100
+ if len_in_tokens(split) >= min_length:
101
+ final_splits.append(split)
102
+
103
+ return final_splits
104
+
105
+
106
+ class SplitStrategies:
107
+ SentencesWithoutUrls = SplitStrategy(split_patterns=[
108
+ RegexExpressions.split_by_dot,
109
+ RegexExpressions.split_by_semicolon,
110
+ RegexExpressions.split_by_colon,
111
+ RegexExpressions.split_by_comma
112
+ ],
113
+ remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
114
+ remove_too_short_groups=False,
115
+ group_splits=False)
116
+
117
+ GroupedSentencesWithoutUrls = SplitStrategy(split_patterns=[
118
+ RegexExpressions.split_by_dot,
119
+ RegexExpressions.split_by_semicolon,
120
+ RegexExpressions.split_by_colon,
121
+ RegexExpressions.split_by_comma
122
+ ],
123
+ remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
124
+ remove_too_short_groups=True,
125
+ group_splits=True)