Joshua Lochner committed on
Commit 9ced7bd
1 Parent(s): 25f1183

Remove unused methods and improve caching

Files changed (1)
  1. src/model.py +22 -69
src/model.py CHANGED
@@ -1,7 +1,6 @@
+from functools import lru_cache
 import pickle
 import os
-from shared import CustomTokens
-from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
 from dataclasses import dataclass, field
 from typing import Optional
 
@@ -15,28 +14,34 @@ class ModelArguments:
     model_name_or_path: str = field(
         default='google/t5-v1_1-small',  # t5-small
         metadata={
-            'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
+            'help': 'Path to pretrained model or model identifier from huggingface.co/models'
+        }
     )
     # config_name: Optional[str] = field(  # TODO remove?
     #     default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
     # )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
+        default=None, metadata={
+            'help': 'Pretrained tokenizer name or path if not the same as model_name'
+        }
     )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={
-            'help': 'Where to store the pretrained models downloaded from huggingface.co'},
+            'help': 'Where to store the pretrained models downloaded from huggingface.co'
+        },
     )
     use_fast_tokenizer: bool = field(  # TODO remove?
         default=True,
         metadata={
-            'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'},
+            'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
+        },
     )
     model_revision: str = field(  # TODO remove?
         default='main',
         metadata={
-            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
+            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
+        },
     )
     use_auth_token: bool = field(
         default=False,
@@ -53,68 +58,16 @@ class ModelArguments:
     )
 
 
-def get_model(model_args, use_cache=True):
-    name = model_args.model_name_or_path
-    cached_path = f'models/{name}'
-
-    # Model created after tokenizer:
-    if use_cache and os.path.exists(os.path.join(cached_path, 'pytorch_model.bin')):
-        name = cached_path
-
-    config = AutoConfig.from_pretrained(
-        name,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
-
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-        name,
-        from_tf='.ckpt' in name,
-        config=config,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
-
-    return model
-
-
-def get_tokenizer(model_args, use_cache=True):
-    name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
-
-    cached_path = f'models/{name}'
-
-    if use_cache and os.path.exists(os.path.join(cached_path, 'tokenizer.json')):
-        name = cached_path
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        name,
-        cache_dir=model_args.cache_dir,
-        use_fast=model_args.use_fast_tokenizer,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
-
-    CustomTokens.add_custom_tokens(tokenizer)
-
-    return tokenizer
-
-
-CLASSIFIER_CACHE = {}
-def get_classifier_vectorizer(classifier_args, use_cache=True):
-    classifier_path = os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file)
-    if use_cache and classifier_path in CLASSIFIER_CACHE:
-        classifier = CLASSIFIER_CACHE[classifier_path]
-    else:
-        with open(classifier_path, 'rb') as fp:
-            classifier = CLASSIFIER_CACHE[classifier_path] = pickle.load(fp)
+@lru_cache
+def get_classifier_vectorizer(classifier_args):
+    classifier_path = os.path.join(
+        classifier_args.classifier_dir, classifier_args.classifier_file)
+    with open(classifier_path, 'rb') as fp:
+        classifier = pickle.load(fp)
 
-    vectorizer_path = os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file)
-    if use_cache and vectorizer_path in CLASSIFIER_CACHE:
-        vectorizer = CLASSIFIER_CACHE[vectorizer_path]
-    else:
-        with open(vectorizer_path, 'rb') as fp:
-            vectorizer = CLASSIFIER_CACHE[vectorizer_path] = pickle.load(fp)
+    vectorizer_path = os.path.join(
+        classifier_args.classifier_dir, classifier_args.vectorizer_file)
+    with open(vectorizer_path, 'rb') as fp:
+        vectorizer = pickle.load(fp)
 
     return classifier, vectorizer
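
For context on the caching change: functools.lru_cache memoizes a function by hashing its call arguments, so the rewritten get_classifier_vectorizer unpickles the classifier and vectorizer only on the first call for a given classifier_args; later calls return the cached tuple. This requires classifier_args to be hashable. Below is a minimal, self-contained sketch of that behaviour; the ClassifierArguments NamedTuple, its default paths, and the pickle file names are hypothetical stand-ins, since the real arguments class is not part of this diff.

# Minimal sketch of the lru_cache behaviour introduced above.
# ClassifierArguments is a hypothetical stand-in (a NamedTuple, which is
# hashable); the real arguments class is not shown in this diff.
import os
import pickle
from functools import lru_cache
from typing import NamedTuple


class ClassifierArguments(NamedTuple):
    classifier_dir: str = 'classifiers'         # hypothetical defaults
    classifier_file: str = 'classifier.pickle'
    vectorizer_file: str = 'vectorizer.pickle'


@lru_cache  # bare decorator form requires Python 3.8+
def get_classifier_vectorizer(classifier_args):
    # Same body as the committed version: unpickle both artifacts.
    classifier_path = os.path.join(
        classifier_args.classifier_dir, classifier_args.classifier_file)
    with open(classifier_path, 'rb') as fp:
        classifier = pickle.load(fp)

    vectorizer_path = os.path.join(
        classifier_args.classifier_dir, classifier_args.vectorizer_file)
    with open(vectorizer_path, 'rb') as fp:
        vectorizer = pickle.load(fp)

    return classifier, vectorizer


args = ClassifierArguments()
a = get_classifier_vectorizer(args)  # first call: reads both pickle files
b = get_classifier_vectorizer(args)  # cache hit: no file I/O
assert a is b  # a cache hit returns the identical tuple object

One trade-off worth noting: the removed CLASSIFIER_CACHE dict was keyed by file path, while lru_cache keys on the whole classifier_args object and raises TypeError at call time if that object is unhashable (for example, an instance of a plain mutable dataclass).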