Joshua Lochner commited on
Commit
bce5ce9
1 Parent(s): fca2a61

Create `get_model_tokenizer` helper method for loading model and tokenizer

Browse files
Files changed (1) hide show
  1. src/model.py +33 -6
src/model.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from functools import lru_cache
2
  import pickle
3
  import os
@@ -12,7 +14,8 @@ class ModelArguments:
12
  """
13
 
14
  model_name_or_path: str = field(
15
- default='google/t5-v1_1-small', # t5-small
 
16
  metadata={
17
  'help': 'Path to pretrained model or model identifier from huggingface.co/models'
18
  }
@@ -20,11 +23,11 @@ class ModelArguments:
20
  # config_name: Optional[str] = field( # TODO remove?
21
  # default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
22
  # )
23
- tokenizer_name: Optional[str] = field(
24
- default=None, metadata={
25
- 'help': 'Pretrained tokenizer name or path if not the same as model_name'
26
- }
27
- )
28
  cache_dir: Optional[str] = field(
29
  default=None,
30
  metadata={
@@ -71,3 +74,27 @@ def get_classifier_vectorizer(classifier_args):
71
  vectorizer = pickle.load(fp)
72
 
73
  return classifier, vectorizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
2
+ from shared import CustomTokens, device
3
  from functools import lru_cache
4
  import pickle
5
  import os
 
14
  """
15
 
16
  model_name_or_path: str = field(
17
+ default=None,
18
+ # default='google/t5-v1_1-small', # t5-small
19
  metadata={
20
  'help': 'Path to pretrained model or model identifier from huggingface.co/models'
21
  }
 
23
  # config_name: Optional[str] = field( # TODO remove?
24
  # default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
25
  # )
26
+ # tokenizer_name: Optional[str] = field(
27
+ # default=None, metadata={
28
+ # 'help': 'Pretrained tokenizer name or path if not the same as model_name'
29
+ # }
30
+ # )
31
  cache_dir: Optional[str] = field(
32
  default=None,
33
  metadata={
 
74
  vectorizer = pickle.load(fp)
75
 
76
  return classifier, vectorizer
77
+
78
+
79
+ @lru_cache(maxsize=None)
80
+ def get_model_tokenizer(model_name_or_path, cache_dir=None):
81
+ if model_name_or_path is None:
82
+ raise ValueError('Invalid model_name_or_path.')
83
+
84
+ # Load pretrained model and tokenizer
85
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, cache_dir=cache_dir)
86
+ model.to(device())
87
+
88
+ tokenizer = AutoTokenizer.from_pretrained(
89
+ model_name_or_path, max_length=model.config.d_model, cache_dir=cache_dir)
90
+
91
+ # Ensure model and tokenizer contain the custom tokens
92
+ CustomTokens.add_custom_tokens(tokenizer)
93
+ model.resize_token_embeddings(len(tokenizer))
94
+
95
+ # TODO add this back: means that different models will have different training data
96
+ # Currently we only send 512 tokens to the model each time...
97
+ # Adjust based on dimensions of model
98
+ # tokenizer.model_max_length = model.config.d_model
99
+
100
+ return model, tokenizer