LightChen2333 commited on
Commit
b2c8fd0
·
1 Parent(s): cabed47

Upload model_manager.py

Browse files
Files changed (1) hide show
  1. common/model_manager.py +2 -3
common/model_manager.py CHANGED
@@ -2,7 +2,7 @@
2
  Author: Qiguang Chen
3
  Date: 2023-01-11 10:39:26
4
  LastEditors: Qiguang Chen
5
- LastEditTime: 2023-02-08 01:02:20
6
  Description: manage all process of model training and prediction.
7
 
8
  '''
@@ -287,8 +287,7 @@ class ModelManager(object):
287
  return outputs, res
288
 
289
  def load(self):
290
- # self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"), map_location=torch.device(self.device))
291
- self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"), map_location=torch.device("cpu"))
292
  if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
293
  self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
294
  os.path.join(self.config.base["model_dir"], "tokenizer.json"))
 
2
  Author: Qiguang Chen
3
  Date: 2023-01-11 10:39:26
4
  LastEditors: Qiguang Chen
5
+ LastEditTime: 2023-02-08 00:57:09
6
  Description: manage all process of model training and prediction.
7
 
8
  '''
 
287
  return outputs, res
288
 
289
  def load(self):
290
+ self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"), map_location=torch.device(self.device))
 
291
  if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
292
  self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
293
  os.path.join(self.config.base["model_dir"], "tokenizer.json"))