bwang0911 commited on
Commit
88ee741
·
verified ·
1 Parent(s): ac0b6c5

fix: add missing lora adaptations (#8)

Browse files

- fix: add missing lora adaptations (d37db3b7651d0109b04fa81b4a29783eec7cda12)

Files changed (1) hide show
  1. custom_st.py +9 -0
custom_st.py CHANGED
@@ -56,6 +56,15 @@ class Transformer(nn.Module):
56
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
57
  self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
58
 
 
 
 
 
 
 
 
 
 
59
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
60
  tokenizer_args["model_max_length"] = max_seq_length
61
  self.tokenizer = AutoTokenizer.from_pretrained(
 
56
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
57
  self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
58
 
59
+ self._lora_adaptations = config.lora_adaptations
60
+ if (
61
+ not isinstance(self._lora_adaptations, list)
62
+ or len(self._lora_adaptations) < 1
63
+ ):
64
+ raise ValueError(
65
+ f"`lora_adaptations` must be a list and contain at least one element"
66
+ )
67
+
68
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
69
  tokenizer_args["model_max_length"] = max_seq_length
70
  self.tokenizer = AutoTokenizer.from_pretrained(