jupyterjazz commited on
Commit
42a68bc
1 Parent(s): 4a29e2c

Update custom_st.py

Browse files
Files changed (1) hide show
  1. custom_st.py +5 -2
custom_st.py CHANGED
@@ -51,8 +51,8 @@ class Transformer(nn.Module):
51
  if config_args is None:
52
  config_args = {}
53
 
 
54
  self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
55
- self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
56
 
57
  self._lora_adaptations = self.config.lora_adaptations
58
  if (
@@ -65,7 +65,10 @@ class Transformer(nn.Module):
65
  self._adaptation_map = {
66
  name: idx for idx, name in enumerate(self._lora_adaptations)
67
  }
68
- self._default_task = None
 
 
 
69
 
70
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
71
  tokenizer_args["model_max_length"] = max_seq_length
 
51
  if config_args is None:
52
  config_args = {}
53
 
54
+
55
  self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
 
56
 
57
  self._lora_adaptations = self.config.lora_adaptations
58
  if (
 
65
  self._adaptation_map = {
66
  name: idx for idx, name in enumerate(self._lora_adaptations)
67
  }
68
+
69
+ self.default_task = model_args.pop('default_task', None)
70
+
71
+ self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
72
 
73
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
74
  tokenizer_args["model_max_length"] = max_seq_length