Joshua Lochner committed
Commit 9ced7bd • 1 Parent(s): 25f1183

Remove unused methods and improve caching

- src/model.py +22 -69
src/model.py CHANGED
@@ -1,7 +1,6 @@
+from functools import lru_cache
 import pickle
 import os
-from shared import CustomTokens
-from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
 from dataclasses import dataclass, field
 from typing import Optional
 
@@ -15,28 +14,34 @@ class ModelArguments:
     model_name_or_path: str = field(
         default='google/t5-v1_1-small',  # t5-small
         metadata={
-            'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
+            'help': 'Path to pretrained model or model identifier from huggingface.co/models'
+        }
     )
     # config_name: Optional[str] = field(  # TODO remove?
     #     default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
     # )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
+        default=None, metadata={
+            'help': 'Pretrained tokenizer name or path if not the same as model_name'
+        }
     )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={
-            'help': 'Where to store the pretrained models downloaded from huggingface.co'},
+            'help': 'Where to store the pretrained models downloaded from huggingface.co'
+        },
     )
     use_fast_tokenizer: bool = field(  # TODO remove?
         default=True,
         metadata={
-            'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'},
+            'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
+        },
     )
     model_revision: str = field(  # TODO remove?
         default='main',
         metadata={
-            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
+            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
+        },
     )
     use_auth_token: bool = field(
         default=False,
@@ -53,68 +58,16 @@ class ModelArguments:
     )
 
 
-def get_model(model_args, use_cache=True):
-    name = model_args.model_name_or_path
-
-    cached_path = f'models/{name}'
-
-    if use_cache and os.path.exists(os.path.join(cached_path, 'pytorch_model.bin')):
-        name = cached_path
-
-    config = AutoConfig.from_pretrained(
-        name,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
-
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-        name,
-        from_tf='.ckpt' in name,
-        config=config,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
-
-    return model
-
-
-def get_tokenizer(model_args, use_cache=True):
-    name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
-
-    cached_path = f'models/{name}'
-
-    if use_cache and os.path.exists(os.path.join(cached_path, 'tokenizer.json')):
-        name = cached_path
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        name,
-        cache_dir=model_args.cache_dir,
-        use_fast=model_args.use_fast_tokenizer,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
-
-    CustomTokens.add_custom_tokens(tokenizer)
-
-    return tokenizer
-
-
-CLASSIFIER_CACHE = {}
-def get_classifier_vectorizer(classifier_args, use_cache=True):
-    classifier_path = os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file)
-    if use_cache and classifier_path in CLASSIFIER_CACHE:
-        classifier = CLASSIFIER_CACHE[classifier_path]
-    else:
-        with open(classifier_path, 'rb') as fp:
-            classifier = CLASSIFIER_CACHE[classifier_path] = pickle.load(fp)
+@lru_cache
+def get_classifier_vectorizer(classifier_args):
+    classifier_path = os.path.join(
+        classifier_args.classifier_dir, classifier_args.classifier_file)
+    with open(classifier_path, 'rb') as fp:
+        classifier = pickle.load(fp)
 
-    vectorizer_path = os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file)
-    if use_cache and vectorizer_path in CLASSIFIER_CACHE:
-        vectorizer = CLASSIFIER_CACHE[vectorizer_path]
-    else:
-        with open(vectorizer_path, 'rb') as fp:
-            vectorizer = CLASSIFIER_CACHE[vectorizer_path] = pickle.load(fp)
+    vectorizer_path = os.path.join(
+        classifier_args.classifier_dir, classifier_args.vectorizer_file)
+    with open(vectorizer_path, 'rb') as fp:
+        vectorizer = pickle.load(fp)
 
     return classifier, vectorizer
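For context on the field(metadata={'help': ...}) pattern above: dataclasses written this way are typically fed to transformers.HfArgumentParser, which turns each field into a command-line flag and uses the 'help' metadata as that flag's help text. A minimal sketch under that assumption; the diff itself does not show how ModelArguments is actually parsed in this repo:

# Assumes ModelArguments is parsed via HfArgumentParser, as is conventional
# in Hugging Face training scripts; this commit does not show the parsing code.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default='google/t5-v1_1-small',
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
    )

parser = HfArgumentParser(ModelArguments)
# Reads sys.argv, e.g.: python train.py --model_name_or_path google/t5-v1_1-base
(model_args,) = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path)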
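A note on the caching change itself: functools.lru_cache keys its cache on the function's arguments, so the decorated get_classifier_vectorizer now requires classifier_args to be hashable, and the bare @lru_cache form (no parentheses) needs Python 3.8+. A minimal sketch of that requirement; ClassifierArguments below is a hypothetical stand-in, not a class defined in this diff:

# Sketch only: 'ClassifierArguments' is a hypothetical stand-in for the
# argument object callers pass in; the repo's real class may differ.
from dataclasses import dataclass
from functools import lru_cache

@dataclass(frozen=True)  # frozen=True generates __hash__, so instances can be cache keys
class ClassifierArguments:
    classifier_dir: str = 'classifiers'
    classifier_file: str = 'classifier.pickle'
    vectorizer_file: str = 'vectorizer.pickle'

@lru_cache  # bare decorator form requires Python >= 3.8
def load_once(args):
    print('loading from disk...')  # executes only on a cache miss
    return (args.classifier_file, args.vectorizer_file)

args = ClassifierArguments()
load_once(args)                   # miss: prints, loads, caches
load_once(ClassifierArguments())  # hit: equal frozen instances share a cache entry
load_once.cache_clear()           # rough stand-in for the removed use_cache=False path

Compared with the removed module-level CLASSIFIER_CACHE dict, lru_cache also bounds the cache (maxsize defaults to 128) and exposes cache_info() for inspecting hits and misses.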