|
import yaml |
|
|
|
|
|
class Params:
    """Hyper-parameter and path configuration container.

    All settings live as plain instance attributes with defaults assigned in
    ``__init__``. The class offers a small fluent API:

    - ``load(args)``       — read a YAML config file and apply it,
    - ``save(json_path)``  — dump the current state as YAML,
    - ``state_dict()`` / ``load_state_dict(d)`` — plain-dict round-tripping,
    - ``init_data_paths()`` — derive dataset file paths from the settings.

    NOTE: ``init_data_paths`` reads ``self.framework``, ``self.language`` and
    ``self.data_directory``, which are *not* initialized in ``__init__`` —
    they are expected to arrive via the loaded YAML config (``load`` /
    ``load_state_dict``). Calling ``init_data_paths`` before they are set
    raises ``AttributeError``.
    """

    def __init__(self):
        # --- graph / task configuration ---
        self.graph_mode = "sequential"        # "sequential" | "node-centric" | "labeled-edge"
        self.accumulation_steps = 1           # gradient accumulation steps
        self.activation = "relu"
        self.predict_intensity = False
        # --- optimization ---
        self.batch_size = 32
        self.beta_2 = 0.98                    # Adam beta_2
        self.blank_weight = 1.0
        # --- embeddings ---
        self.char_embedding = True
        self.char_embedding_size = 128
        # --- decoder optimizer ---
        self.decoder_delay_steps = 0
        self.decoder_learning_rate = 6e-4
        self.decoder_weight_decay = 1.2e-6
        # --- dropout rates ---
        self.dropout_anchor = 0.5
        self.dropout_edge_label = 0.5
        self.dropout_edge_presence = 0.5
        self.dropout_label = 0.5
        self.dropout_transformer = 0.5
        self.dropout_transformer_attention = 0.1
        self.dropout_word = 0.1
        # --- encoder (pretrained LM) ---
        self.encoder = "xlm-roberta-base"
        self.encoder_delay_steps = 2000       # steps before encoder starts updating
        self.encoder_freeze_embedding = True
        self.encoder_learning_rate = 6e-5
        self.encoder_weight_decay = 1e-2
        self.lr_decay_multiplier = 100
        # --- training schedule ---
        self.epochs = 100
        self.focal = True                     # use focal loss
        self.freeze_bert = False
        self.group_ops = False
        # --- model sizes ---
        self.hidden_size_ff = 4 * 768         # feed-forward dim (4x base hidden size)
        self.hidden_size_anchor = 128
        self.hidden_size_edge_label = 256
        self.hidden_size_edge_presence = 512
        self.layerwise_lr_decay = 1.0
        self.n_attention_heads = 8
        self.n_layers = 3
        self.query_length = 4
        self.pre_norm = True
        self.warmup_steps = 6000

    def init_data_paths(self):
        """Derive dataset file paths from the current settings.

        Maps ``graph_mode`` to a processed-data subdirectory and the
        ``(framework, language)`` pair to a dataset subdirectory, then builds
        the train/dev/test ``.mrp`` paths and the raw ``.json`` paths.

        Raises:
            KeyError: if ``graph_mode`` or ``(framework, language)`` is not a
                recognized combination.
            AttributeError: if ``framework``, ``language`` or
                ``data_directory`` have not been set (see class docstring).

        Returns:
            self, for fluent chaining.
        """
        directory_1 = {
            "sequential": "node_centric_mrp",
            "node-centric": "node_centric_mrp",
            "labeled-edge": "labeled_edge_mrp"
        }[self.graph_mode]
        directory_2 = {
            ("darmstadt", "en"): "darmstadt_unis",
            ("mpqa", "en"): "mpqa",
            ("multibooked", "ca"): "multibooked_ca",
            ("multibooked", "eu"): "multibooked_eu",
            ("norec", "no"): "norec",
            ("opener", "en"): "opener_en",
            ("opener", "es"): "opener_es",
        }[(self.framework, self.language)]

        self.training_data = f"{self.data_directory}/{directory_1}/{directory_2}/train.mrp"
        self.validation_data = f"{self.data_directory}/{directory_1}/{directory_2}/dev.mrp"
        self.test_data = f"{self.data_directory}/{directory_1}/{directory_2}/test.mrp"

        self.raw_training_data = f"{self.data_directory}/raw/{directory_2}/train.json"
        self.raw_validation_data = f"{self.data_directory}/raw/{directory_2}/dev.json"

        return self

    def load_state_dict(self, d):
        """Set every key of *d* as an attribute on this object; return self."""
        for k, v in d.items():
            setattr(self, k, v)
        return self

    def state_dict(self):
        """Return all parameters as a plain dict.

        Iterates the instance ``__dict__`` directly instead of the previous
        ``dir(self)`` scan: ``dir`` would include non-callable class-level
        attributes that are absent from ``__dict__`` (raising KeyError) and
        needed two ``getattr`` calls per name. Dunder keys are still skipped
        for equivalence with the original filter.
        """
        return {k: v for k, v in self.__dict__.items() if not k.startswith("__")}

    def load(self, args):
        """Load a YAML config from ``args.config``, apply it, and derive paths.

        Uses ``yaml.safe_load`` (no arbitrary object construction).

        Returns:
            self, for fluent chaining (consistent with the other mutators).
        """
        with open(args.config, "r", encoding="utf-8") as f:
            params = yaml.safe_load(f)
        self.load_state_dict(params)
        self.init_data_paths()
        return self

    def save(self, json_path):
        """Dump the current parameter state as YAML to *json_path*."""
        with open(json_path, "w", encoding="utf-8") as f:
            d = self.state_dict()
            yaml.dump(d, f)
|
|