Joshua Lochner committed
Commit d7a594b · 1 Parent(s): dffef09

Download classifier and vectorizer if not present

Files changed (1)
  1. src/model.py +20 -3
src/model.py CHANGED
@@ -1,5 +1,7 @@
+from huggingface_hub import hf_hub_download
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from shared import CustomTokens, device
+from errors import ClassifierLoadError, ModelLoadError
 from functools import lru_cache
 import pickle
 import os
@@ -29,7 +31,7 @@ class ModelArguments:
     # }
     # )
     cache_dir: Optional[str] = field(
-        default=None,
+        default='models',
         metadata={
             'help': 'Where to store the pretrained models downloaded from huggingface.co'
         },
@@ -63,13 +65,27 @@ class ModelArguments:
 
 @lru_cache(maxsize=None)
 def get_classifier_vectorizer(classifier_args):
+    # Classifier
     classifier_path = os.path.join(
         classifier_args.classifier_dir, classifier_args.classifier_file)
+    if not os.path.exists(classifier_path):
+        hf_hub_download(repo_id=classifier_args.classifier_model,
+                        filename=classifier_args.classifier_file,
+                        cache_dir=classifier_args.classifier_dir,
+                        force_filename=classifier_args.classifier_file,
+                        )
     with open(classifier_path, 'rb') as fp:
         classifier = pickle.load(fp)
 
+    # Vectorizer
     vectorizer_path = os.path.join(
         classifier_args.classifier_dir, classifier_args.vectorizer_file)
+    if not os.path.exists(vectorizer_path):
+        hf_hub_download(repo_id=classifier_args.classifier_model,
+                        filename=classifier_args.vectorizer_file,
+                        cache_dir=classifier_args.classifier_dir,
+                        force_filename=classifier_args.vectorizer_file,
+                        )
     with open(vectorizer_path, 'rb') as fp:
         vectorizer = pickle.load(fp)
 
@@ -79,10 +95,11 @@ def get_classifier_vectorizer(classifier_args):
 @lru_cache(maxsize=None)
 def get_model_tokenizer(model_name_or_path, cache_dir=None):
     if model_name_or_path is None:
-        raise ValueError('Invalid model_name_or_path.')
+        raise ModelLoadError('Invalid model_name_or_path.')
 
     # Load pretrained model and tokenizer
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        model_name_or_path, cache_dir=cache_dir)
     model.to(device())
 
     tokenizer = AutoTokenizer.from_pretrained(
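
For context, the pattern this commit introduces in get_classifier_vectorizer is: build the expected local path, call hf_hub_download only when that file is missing, then unpickle it. Below is a minimal, self-contained sketch of the same lazy-download idea; the helper name, repo id and filenames are placeholders rather than the Space's real configuration, and it avoids the force_filename argument used in the diff, which later huggingface_hub releases deprecated.

# Minimal sketch of the lazy-download pattern (placeholder names, not the
# Space's real repo id or filenames).
import os
import pickle

from huggingface_hub import hf_hub_download


def load_pickled_asset(repo_id, filename, local_dir='models'):
    """Download `filename` from `repo_id` on the Hub if it is not already
    present in `local_dir`, then unpickle and return it."""
    local_path = os.path.join(local_dir, filename)
    if not os.path.exists(local_path):
        os.makedirs(local_dir, exist_ok=True)
        # hf_hub_download returns the path of the file it fetched into the cache.
        cached_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(cached_path, 'rb') as src, open(local_path, 'wb') as dst:
            dst.write(src.read())
    with open(local_path, 'rb') as fp:
        return pickle.load(fp)


# Hypothetical usage, mirroring get_classifier_vectorizer:
# classifier = load_pickled_asset('<user>/<classifier-repo>', 'classifier.pickle')
# vectorizer = load_pickled_asset('<user>/<classifier-repo>', 'vectorizer.pickle')

As in the diff, wrapping the loader in functools.lru_cache keeps the unpickled objects in memory, so the existence check and disk read happen only once per process.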