add update_config
Browse files- pipeline.py +9 -1
pipeline.py
CHANGED
@@ -62,7 +62,15 @@ class QASRL_Pipeline(Text2TextGenerationPipeline):
|
|
62 |
self.data_args.use_bilateral_predicate_marker = True
|
63 |
if "append_verb_form" not in vars(self.data_args):
|
64 |
self.data_args.append_verb_form = True
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
def _sanitize_parameters(self, **kwargs):
|
68 |
preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {} # super()._sanitize_parameters(**kwargs)
|
|
|
62 |
self.data_args.use_bilateral_predicate_marker = True
|
63 |
if "append_verb_form" not in vars(self.data_args):
|
64 |
self.data_args.append_verb_form = True
|
65 |
+
self._update_config(**kwargs)
|
66 |
+
|
67 |
+
def _update_config(self, **kwargs):
|
68 |
+
" Update self.model.config with initialization parameters and necessary defaults. "
|
69 |
+
# set default values that will always override model.config, but can overriden by __init__ kwargs
|
70 |
+
kwargs["max_length"] = kwargs.get("max_length", 80)
|
71 |
+
# override model.config with kwargs
|
72 |
+
for k,v in kwargs.items():
|
73 |
+
self.model.config.__dict__[k] = v
|
74 |
|
75 |
def _sanitize_parameters(self, **kwargs):
|
76 |
preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {} # super()._sanitize_parameters(**kwargs)
|