jupyterjazz
commited on
Commit
•
42a68bc
1
Parent(s):
4a29e2c
Update custom_st.py
Browse files- custom_st.py +5 -2
custom_st.py
CHANGED
@@ -51,8 +51,8 @@ class Transformer(nn.Module):
|
|
51 |
if config_args is None:
|
52 |
config_args = {}
|
53 |
|
|
|
54 |
self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
55 |
-
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
|
56 |
|
57 |
self._lora_adaptations = self.config.lora_adaptations
|
58 |
if (
|
@@ -65,7 +65,10 @@ class Transformer(nn.Module):
|
|
65 |
self._adaptation_map = {
|
66 |
name: idx for idx, name in enumerate(self._lora_adaptations)
|
67 |
}
|
68 |
-
|
|
|
|
|
|
|
69 |
|
70 |
if max_seq_length is not None and "model_max_length" not in tokenizer_args:
|
71 |
tokenizer_args["model_max_length"] = max_seq_length
|
|
|
51 |
if config_args is None:
|
52 |
config_args = {}
|
53 |
|
54 |
+
|
55 |
self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
|
|
56 |
|
57 |
self._lora_adaptations = self.config.lora_adaptations
|
58 |
if (
|
|
|
65 |
self._adaptation_map = {
|
66 |
name: idx for idx, name in enumerate(self._lora_adaptations)
|
67 |
}
|
68 |
+
|
69 |
+
self.default_task = model_args.pop('default_task', None)
|
70 |
+
|
71 |
+
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
|
72 |
|
73 |
if max_seq_length is not None and "model_max_length" not in tokenizer_args:
|
74 |
tokenizer_args["model_max_length"] = max_seq_length
|